espnet2.asr.encoder.beats_encoder.BeatsEncoder
class espnet2.asr.encoder.beats_encoder.BeatsEncoder(input_size: int, beats_ckpt_path: str | None = None, max_layer: int | None = None, downsampling_rate: int = 1, adapter_config: str = '', use_weighted_representation: bool = False, beats_config: BeatsConfig | None = None, specaug_config: Dict | None = None, add_positional_information: bool = False, max_positions: int | None = None)
Bases: AbsEncoder
BEATs: Audio Pre-Training with Acoustic Tokenizers.
This class implements the BEATs model for audio pre-training and fine-tuning using acoustic tokenizers. It can handle various configurations, including the use of pretrained weights and the application of SpecAugment.
fbank_mean
Mean of the filter banks.
- Type: float
fbank_std
Standard deviation of the filter banks.
- Type: float
max_layer
Maximum layer to propagate input through.
- Type: Optional[int]
beats_ckpt_path
Path to a pretrained Beats checkpoint.
- Type: Optional[str]
loaded_state_dict_
Loaded state dictionary from the checkpoint.
- Type: Optional[Dict]
specaug
SpecAugment instance if config provided.
- Type: Optional[SpecAug]
_output_size
Size of the output features.
- Type: int
embed
Embedding dimension.
- Type: int
input_patch_size
Size of input patches for the model.
- Type: int
post_extract_proj
Projection layer after feature extraction.
- Type: Optional[nn.Linear]
patch_embedding
Convolutional layer for patch embedding.
- Type: nn.Conv2d
dropout_input
Dropout layer for input features.
- Type: nn.Dropout
encoder
Transformer encoder module.
- Type: TransformerEncoder
layer_norm
Layer normalization module.
- Type: LayerNorm
use_weighted_representation
Flag to use weighted representations.
- Type: bool
layer_weights
Weights for layer representations if using weighted representations.
- Type: Optional[nn.Parameter]
downsample_conv
Downsampling convolutional layer.
- Type: Optional[nn.Conv1d]
conformer_adapter
Wav2Vec2 Conformer adapter module, built from adapter_config.
- Type: Optional[Wav2Vec2ConformerEncoder]
cross_embed_positions
Learned positional embeddings for cross-attention.
- Type: Optional[BartLearnedPositionalEmbedding]
Parameters:
- input_size (int) – The size of the input features.
- beats_ckpt_path (str, optional) – Path to a pretrained BEATs checkpoint. If beats_config is also provided and does not match the config stored in the checkpoint, loading may fail.
- max_layer (int, optional) – Maximum layer to propagate input through. If None, input is propagated through all layers.
- downsampling_rate (int, optional) – Downsampling rate for the encoder. Downsampling is applied only if the rate is greater than 1. Default is 1.
- adapter_config (str, optional) – Path to a config file for the Wav2Vec2 adapter. If empty (the default), no adapter is used.
- use_weighted_representation (bool, optional) – If True, combine the representations of the layers up to max_layer using learnable weights (see layer_weights). The weights are randomly initialized.
- beats_config (Optional[BeatsConfig], optional) – BeatsConfig object. If provided, it is used to try to override the config stored in the checkpoint.
- specaug_config (Optional[Dict], optional) – Dictionary of SpecAugment parameters. If provided, SpecAugment is applied.
- add_positional_information (bool, optional) – If True, add learned positional embeddings.
- max_positions (Optional[int], optional) – Maximum number of positions for the positional embeddings. Required if add_positional_information is True.
Raises: ImportError – If the transformers library is not available and either adapter_config or add_positional_information is set.
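The optional-dependency rule stated under Raises can be sketched as follows. This is a minimal illustration, not the ESPnet implementation: the helper `check_optional_dependencies` and its `transformers_available` override are hypothetical, and the availability test uses the standard-library `importlib.util.find_spec`.

```python
import importlib.util
from typing import Optional


def check_optional_dependencies(
    adapter_config: str = "",
    add_positional_information: bool = False,
    transformers_available: Optional[bool] = None,
) -> None:
    """Hypothetical helper mirroring the documented ImportError condition."""
    if transformers_available is None:
        # Detect whether the `transformers` package is importable,
        # without actually importing it.
        transformers_available = importlib.util.find_spec("transformers") is not None
    # Only the Wav2Vec2 adapter and the learned positional embeddings
    # depend on transformers; an empty adapter_config means "no adapter".
    needs_transformers = bool(adapter_config) or add_positional_information
    if needs_transformers and not transformers_available:
        raise ImportError(
            "transformers is required when adapter_config or "
            "add_positional_information is set"
        )


# No optional features requested: never raises, whatever is installed.
check_optional_dependencies()

# Positional embeddings requested without transformers: raises ImportError.
try:
    check_optional_dependencies(
        add_positional_information=True, transformers_available=False
    )
except ImportError as err:
    print("ImportError:", err)
```

Passing the availability flag explicitly keeps the sketch testable regardless of which packages are installed.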
Examples:
>>> encoder = BeatsEncoder(input_size=128, beats_ckpt_path='path/to/ckpt')
>>> features = torch.randn(10, 16000) # 10 audio samples
>>> ilens = torch.tensor([16000] * 10) # Lengths of each sample
>>> audio_representation, output_lens, _ = encoder(features, ilens)
NOTE: This class is designed to be compatible with the ESPnet framework’s AbsEncoder interface.
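extract_features(source: Tensor, padding_mask: Tensor | None = None, max_layer: int | None = None) → Tuple[Tensor, Tensor | None]
Extract features from raw audio.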
This method converts the input audio into normalized filter-bank features, then applies patch embedding, layer normalization, and the Transformer encoder, followed by optional downsampling. The resulting features can be passed to downstream modules.
- Parameters:
- source (torch.Tensor) – A tensor of shape (B, T) representing the input audio, where B is the batch size and T is the number of time steps.
- padding_mask (Optional[torch.Tensor]) – An optional mask tensor of shape (B, T) marking the padded positions of the input. Default is None.
- max_layer (Optional[int]) – If specified, the maximum layer through which the input is propagated. If None, the input is propagated through all layers.
- Returns:
- torch.Tensor: The extracted features of shape (B, T’, C), where T’ is the number of output frames and C is the feature dimension.
- Optional[torch.Tensor]: The updated padding mask tensor after processing, or None if padding_mask was not provided.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
Examples:
>>> encoder = BeatsEncoder(...)
>>> audio_input = torch.randn(4, 16000) # Batch of 4 audio samples
>>> features, updated_mask = encoder.extract_features(audio_input)
NOTE: If SpecAugment is configured, it is applied to the filter-bank features during training, before they enter the patch embedding.
- Raises:
- ValueError – If the input tensor is not of the expected shape, or if any of the layers in the encoder are misconfigured.
forward(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor | None]
Wrapper for compatibility with ESPnet’s AbsEncoder interface.
This method processes the input tensor and computes audio representations by applying the Beats encoder. It manages padding and length adjustments for batch processing.
- Parameters:
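- xs_pad (torch.Tensor) – Input tensor of shape (B, T, D), where B is the batch size, T is the sequence length, and D is the feature dimension. The input is passed on to extract_features, which operates on raw audio of shape (B, T), so D is expected to be 1.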
- ilens (torch.Tensor) – Tensor of shape (B,) containing the lengths of each sequence in the batch.
- prev_states (torch.Tensor , optional) – Not used in this implementation. Defaults to None.
- Returns:
- audio_representation (torch.Tensor): The output audio representation tensor of shape (B, T, D).
- output_lens (torch.Tensor): Tensor of shape (B,) containing the lengths of the output sequences.
- masks (Optional[torch.Tensor]): Currently set to None.
- Return type: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
NOTE: Building the padding mask without a reference tensor can be costly, because the mask helper may create a tensor of size maxlen x maxlen. The implementation therefore unsqueezes and squeezes tensors to avoid this cost.
Examples:
>>> encoder = BeatsEncoder(...)
>>> input_tensor = torch.randn(4, 16000) # raw waveforms, as in the class-level example
>>> input_lengths = torch.tensor([16000, 14000, 12000, 10000])
>>> audio_rep, output_lengths, masks = encoder.forward(input_tensor,
... input_lengths)
>>> print(audio_rep.shape) # (4, T', D), where T' is the number of output frames
>>> print(output_lengths) # number of valid output frames per sequence
forward_padding_mask(features: Tensor, padding_mask: Tensor) → Tensor
Generate a forward padding mask based on input features.
This method adapts the provided padding mask to the time resolution of the input features. If L is not a multiple of T, the trailing L mod T entries are removed; the mask is then reshaped to (B, T, L // T), and a frame is marked as padded only if all of the positions it covers are padded.
- Parameters:
- features (torch.Tensor) – A tensor representing input features with shape (B, T, C), where B is the batch size, T is the sequence length, and C is the number of features per time step.
- padding_mask (torch.Tensor) – A tensor representing the original padding mask with shape (B, L), where L is the length of the sequence before any adjustments.
- Returns: A boolean tensor of shape (B, T) in which True marks padded frames and False marks valid frames.
- Return type: torch.Tensor
Examples:
>>> encoder = BeatsEncoder(...)
>>> features = torch.randn(4, 10, 512) # Batch of 4, 10 time steps
>>> lengths = torch.tensor([4, 5, 3, 6])
>>> padding_mask = torch.arange(10)[None, :] >= lengths[:, None] # True marks padding
>>> mask = encoder.forward_padding_mask(features, padding_mask)
>>> print(mask.shape)
torch.Size([4, 10])
NOTE: The padding mask is expected to have shape (B, L), where L is typically the maximum sequence length in the batch. The function truncates the padding mask if necessary to fit the features tensor.
- Raises:
- ValueError – If the padding_mask does not have the expected shape, or if the dimensions are incompatible.
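The truncate, reshape, and reduce steps of forward_padding_mask can be sketched in plain Python as below. The sketch assumes, as in the original BEATs code, that True marks padding and that a frame counts as padded only when every position it covers is padded; `downsample_pad_mask` is a hypothetical name, and lists stand in for tensors.

```python
from typing import List


def downsample_pad_mask(pad_mask: List[List[bool]], num_frames: int) -> List[List[bool]]:
    """Reduce a (B, L) padding mask to (B, num_frames); True marks padding."""
    result = []
    for row in pad_mask:
        # Drop the trailing L mod T entries so the length divides evenly.
        extra = len(row) % num_frames
        if extra > 0:
            row = row[:-extra]
        step = len(row) // num_frames
        # Group consecutive entries per frame; a frame is padded only if
        # all of the positions it covers are padded.
        result.append(
            [all(row[i * step:(i + 1) * step]) for i in range(num_frames)]
        )
    return result


# 10 input positions mapped onto 3 frames: the last entry is dropped,
# and the remaining 9 positions are grouped as 3 + 3 + 3.
mask = [[False] * 5 + [True] * 5]
print(downsample_pad_mask(mask, num_frames=3))  # [[False, False, True]]
```

Using an all-reduction means a frame that mixes real and padded samples is kept as valid, so partially filled frames are not discarded.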
output_size() → int
Get the output size of the BeatsEncoder.
This function retrieves the output size of the encoder, which is determined during the initialization based on the configuration provided to the Beats model.
- Returns: The output size of the encoder, typically equal to the encoder embedding dimension defined in the configuration.
- Return type: int
Examples:
>>> encoder = BeatsEncoder(input_size=256)
>>> size = encoder.output_size()
>>> print(size) # assuming the encoder embedding dimension is set to 768
768
NOTE: The output size determines the feature dimension that subsequent layers of the model receive.
preprocess(source: Tensor) → Tensor
Preprocess raw audio into feature representations.
This method takes raw audio waveforms and converts them into filter bank features suitable for input into the BEATs model. Each waveform is processed to extract Mel filter bank features, which are then normalized using pre-defined mean and standard deviation values.
- Parameters: source (torch.Tensor) – A tensor of shape (B, T) where B is the batch size and T is the number of time steps (samples) in each waveform.
- Returns: A tensor of shape (B, T’, F), where T’ is the number of frames obtained from the original audio and F is the number of Mel filter bank coefficients (128).
- Return type: torch.Tensor
Examples:
>>> encoder = BeatsEncoder(...)
>>> raw_audio = torch.randn(2, 16000) # Example: 2 audio samples
>>> features = encoder.preprocess(raw_audio)
>>> print(features.shape) # (2, T', 128)
NOTE:
- The input waveforms are expected to be float32; during processing they are scaled to the int16 value range before filter-bank extraction.
- The filter bank extraction uses torchaudio’s Kaldi-compatible fbank function with a frame length of 25 ms and a frame shift of 10 ms.
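The framing arithmetic behind T’ and the int16-range scaling can be sketched as follows. This is an illustrative sketch, not the encoder’s code: it assumes a 16 kHz sample rate, Kaldi’s default framing in which partial frames at the end are dropped, and a scale factor of 2**15; both helper names are hypothetical.

```python
from typing import List


def scale_to_int16_range(waveform: List[float]) -> List[float]:
    """Multiply float samples in [-1, 1] by 2**15 (the values stay floats)."""
    return [x * 2**15 for x in waveform]


def num_fbank_frames(num_samples: int, sample_rate: int = 16000,
                     frame_length_ms: float = 25.0,
                     frame_shift_ms: float = 10.0) -> int:
    """Frame count under Kaldi-style framing (incomplete trailing frames dropped)."""
    window = int(sample_rate * frame_length_ms / 1000)  # 400 samples at 16 kHz
    shift = int(sample_rate * frame_shift_ms / 1000)    # 160 samples at 16 kHz
    if num_samples < window:
        return 0
    return 1 + (num_samples - window) // shift


print(num_fbank_frames(16000))            # 98 frames for 1 s of 16 kHz audio
print(scale_to_int16_range([0.5, -1.0]))  # [16384.0, -32768.0]
```

Under these assumptions, the (2, 16000) example input yields features of shape (2, 98, 128).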
reload_pretrained_parameters()
Initialize the Beats model parameters.
This method is intended to be called last in the initialization procedure. It performs the following steps:
- Initializes the Beats encoder parameters.
- If a pretrained checkpoint is provided, loads the weights from the checkpoint to override the initialized parameters.
The initialization includes:
- Applying Xavier normal initialization to the post-extraction projection layer (if it exists).
- Applying Xavier normal initialization to the patch embedding layer.
- Calling the custom weight initialization for the encoder layers.
If a pretrained model state is loaded, it also logs any missing or unexpected keys between the loaded model and the custom model.
- Raises:
- RuntimeError – If the pretrained weights do not match the model architecture.
Examples:
>>> encoder = BeatsEncoder(...)
>>> encoder.reload_pretrained_parameters()
# This will initialize the parameters and load the pretrained
# weights if available.
NOTE: Ensure that this method is called after all model layers have been initialized to avoid any inconsistencies.
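In PyTorch, `Module.load_state_dict(state_dict, strict=False)` reports missing and unexpected keys; the set comparison behind such a report can be sketched as below. Whether this encoder loads its checkpoint non-strictly is an assumption here, `diff_state_dict_keys` is a hypothetical helper, and the key names in the usage lines are illustrative only.

```python
import logging
from typing import List, Tuple

logger = logging.getLogger("beats_reload_sketch")


def diff_state_dict_keys(model_keys: List[str],
                         checkpoint_keys: List[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys between a model and a checkpoint."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    # Parameters the model defines but the checkpoint lacks keep their
    # freshly initialized values.
    missing = sorted(model_set - ckpt_set)
    # Checkpoint entries with no matching model parameter are ignored.
    unexpected = sorted(ckpt_set - model_set)
    if missing:
        logger.warning("Missing keys: %s", missing)
    if unexpected:
        logger.warning("Unexpected keys: %s", unexpected)
    return missing, unexpected


# Illustrative key names only.
missing, unexpected = diff_state_dict_keys(
    ["patch_embedding.weight", "layer_norm.weight", "downsample_conv.weight"],
    ["patch_embedding.weight", "layer_norm.weight", "label_embedding.weight"],
)
print(missing)     # ['downsample_conv.weight']
print(unexpected)  # ['label_embedding.weight']
```

Initializing first and then overriding from the checkpoint means that layers added on top of BEATs, such as a downsampling convolution, keep their fresh initialization when the checkpoint does not provide them.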