espnet2.asr.encoder.beats_encoder.MultiheadAttention
class espnet2.asr.encoder.beats_encoder.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8, has_relative_attention_bias=False, num_buckets=32, max_distance=128, gru_rel_pos=False, rescale_init=False)
Bases: Module
Multi-headed attention mechanism.
This module implements the multi-headed attention mechanism as described in the paper “Attention Is All You Need”. It allows the model to focus on different parts of the input sequence when generating the output.
- Parameters:
- embed_dim (int) – Total dimension of the model.
- num_heads (int) – Number of attention heads.
- kdim (int , optional) – Total dimension of the keys. Defaults to embed_dim.
- vdim (int , optional) – Total dimension of the values. Defaults to embed_dim.
- dropout (float , optional) – Dropout probability. Defaults to 0.0.
- bias (bool , optional) – Whether to include bias in the linear projections. Defaults to True.
- add_bias_kv (bool , optional) – If True, adds bias to the key and value. Defaults to False.
- add_zero_attn (bool , optional) – If True, adds a new attention head that attends to zero vectors. Defaults to False.
- self_attention (bool , optional) – If True, enables self-attention mode. Defaults to False.
- encoder_decoder_attention (bool , optional) – If True, enables attention from encoder to decoder. Defaults to False.
- q_noise (float , optional) – Amount of quantization noise. Defaults to 0.0.
- qn_block_size (int , optional) – Size of the blocks for quantization noise. Defaults to 8.
- has_relative_attention_bias (bool , optional) – If True, enables relative attention bias. Defaults to False.
- num_buckets (int , optional) – Number of buckets for relative attention. Defaults to 32.
- max_distance (int , optional) – Maximum distance for relative attention. Defaults to 128.
- gru_rel_pos (bool , optional) – If True, enables GRU-based relative position encoding. Defaults to False.
- rescale_init (bool , optional) – If True, enables rescaling initialization for the weights. Defaults to False.
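For orientation, the following is a minimal, self-contained sketch of the core scaled dot-product computation, softmax(Q K^T / sqrt(d_k)) V per head, that multi-headed attention performs. It is illustrative only: the actual module additionally applies learned query/key/value and output projections, dropout, and the optional relative position bias configured above.
>>> import math, torch
>>> def naive_multihead_attention(query, key, value, num_heads):
...     # shapes follow this module's (seq_len, batch, embed_dim) convention
...     tgt_len, bsz, embed_dim = query.shape
...     src_len, head_dim = key.shape[0], embed_dim // num_heads
...     def split(x, n):  # (n, bsz, embed_dim) -> (bsz * num_heads, n, head_dim)
...         return x.view(n, bsz * num_heads, head_dim).transpose(0, 1)
...     q, k, v = split(query, tgt_len), split(key, src_len), split(value, src_len)
...     scores = q @ k.transpose(1, 2) / math.sqrt(head_dim)  # (bsz*num_heads, tgt_len, src_len)
...     attn_weights = torch.softmax(scores, dim=-1)
...     attn = (attn_weights @ v).transpose(0, 1).reshape(tgt_len, bsz, embed_dim)
...     return attn, attn_weights
>>> attn, weights = naive_multihead_attention(
...     torch.rand(10, 32, 512), torch.rand(10, 32, 512), torch.rand(10, 32, 512), num_heads=8)
>>> attn.shape, weights.shape
(torch.Size([10, 32, 512]), torch.Size([256, 10, 10]))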
forward(query, key, value, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False, attn_mask=None, before_softmax=False, need_head_weights=False, position_bias=None):
Performs the forward pass of the multi-head attention.
############# Examples
>>> attention = MultiheadAttention(embed_dim=512, num_heads=8)
>>> query = torch.rand(10, 32, 512) # (sequence_length, batch_size, embed_dim)
>>> key = torch.rand(10, 32, 512)
>>> value = torch.rand(10, 32, 512)
>>> output, attn_weights, _ = attention(query, key, value)
######### NOTE Ensure that the embed_dim is divisible by num_heads to avoid errors during computation.
- Raises: AssertionError – If embed_dim is not divisible by num_heads.
apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int)
Apply a sparse mask to the attention weights.
This method is intended to be a placeholder for potential future implementations of sparse masking techniques. Currently, it does not modify the attention weights.
- Parameters:
- attn_weights (torch.Tensor) – The raw attention weights with shape (bsz * num_heads, tgt_len, src_len).
- tgt_len (int) – The length of the target sequence.
- src_len (int) – The length of the source sequence.
- bsz (int) – The batch size.
- Returns: The (unchanged) attention weights, with the same shape as the input attn_weights.
- Return type: torch.Tensor
######### NOTE This function is a no-op and returns the input weights as is. It can be extended in the future to implement actual sparse masking logic.
############# Examples
>>> attention = MultiheadAttention(embed_dim=512, num_heads=8)
>>> attn_weights = torch.rand(2 * 8, 5, 10)  # (bsz * num_heads, tgt_len, src_len)
>>> sparse_masked_weights = attention.apply_sparse_mask(attn_weights, 5, 10, 2)
>>> assert torch.equal(attn_weights, sparse_masked_weights)
compute_bias(query_length, key_length)
Compute relative position bias.
This method calculates the relative position bias used in multi-headed attention mechanisms. It generates a bias tensor based on the relative positions of the query and key sequences. The bias is computed using the relative position buckets and the learned relative attention bias parameters.
- Parameters:
- query_length (int) – The length of the query sequence.
- key_length (int) – The length of the key sequence.
- Returns: A tensor of shape (num_heads, query_length, key_length) representing the computed relative position bias.
- Return type: torch.Tensor
############# Examples
>>> attention = MultiheadAttention(embed_dim=512, num_heads=8)
>>> bias = attention.compute_bias(query_length=10, key_length=15)
>>> print(bias.shape)
torch.Size([8, 10, 15])
######### NOTE This method requires that self.relative_attention_bias is initialized with the appropriate parameters.
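For illustration, the simplified sketch below produces a bias of the documented (num_heads, query_length, key_length) shape from a learned per-bucket embedding. The bucketing here is a plain shift-and-clamp for brevity; the actual implementation maps relative positions into num_buckets buckets (log-spaced beyond a threshold, up to max_distance).
>>> import torch
>>> import torch.nn as nn
>>> num_heads, num_buckets = 8, 32
>>> relative_attention_bias = nn.Embedding(num_buckets, num_heads)  # stand-in for the learned bias table
>>> def simple_compute_bias(query_length, key_length):
...     # relative position of each key with respect to each query
...     relative_position = torch.arange(key_length)[None, :] - torch.arange(query_length)[:, None]
...     # naive bucketing by shifting and clamping (the real code uses log-spaced buckets)
...     buckets = torch.clamp(relative_position + num_buckets // 2, 0, num_buckets - 1)
...     return relative_attention_bias(buckets).permute(2, 0, 1)  # (num_heads, q_len, k_len)
>>> simple_compute_bias(10, 15).shape
torch.Size([8, 10, 15])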
forward(query, key: Tensor | None, value: Tensor | None, key_padding_mask: Tensor | None = None, incremental_state: Dict[str, Dict[str, Tensor | None]] | None = None, need_weights: bool = True, static_kv: bool = False, attn_mask: Tensor | None = None, before_softmax: bool = False, need_head_weights: bool = False, position_bias: Tensor | None = None) → Tuple[Tensor, Tensor | None, Tensor | None]
Performs the forward pass of the multi-headed attention.
Input tensors follow the time-first convention used in the examples above: query has shape (tgt_len, batch, embed_dim), while key and value have shape (src_len, batch, embed_dim).
Parameters:
- query (torch.Tensor) – Query tensor of shape (tgt_len, batch, embed_dim).
- key (torch.Tensor , optional) – Key tensor of shape (src_len, batch, embed_dim).
- value (torch.Tensor , optional) – Value tensor of shape (src_len, batch, embed_dim).
- key_padding_mask (torch.Tensor , optional) – Mask of shape (batch, src_len) marking padded key positions to exclude from attention. Defaults to None.
- incremental_state (Dict , optional) – Cached key/value state for incremental decoding. Defaults to None.
- need_weights (bool , optional) – If True, also return the attention weights, averaged over heads. Defaults to True.
- static_kv (bool , optional) – If True, reuse the cached key and value instead of recomputing them. Defaults to False.
- attn_mask (torch.Tensor , optional) – Mask that blocks attention to certain positions, e.g. for causal attention. Defaults to None.
- before_softmax (bool , optional) – If True, return the raw attention scores and values before the softmax. Defaults to False.
- need_head_weights (bool , optional) – If True, return the attention weights for each head; implies need_weights. Defaults to False.
- position_bias (torch.Tensor , optional) – Precomputed relative position bias to add to the attention scores. Defaults to None.
Returns:
- attn (torch.Tensor): Attention output of shape (tgt_len, batch, embed_dim).
- attn_weights (Optional[torch.Tensor]): Attention weights, returned when need_weights or need_head_weights is True.
- position_bias (Optional[torch.Tensor]): The relative position bias used for this call, if any.
Return type: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
######### NOTE In key_padding_mask, padded positions are indicated by True (or 1) and are excluded from attention.
############# Examples
>>> attention = MultiheadAttention(embed_dim=512, num_heads=8)
>>> query = torch.rand(10, 32, 512)  # (tgt_len, batch, embed_dim)
>>> key = torch.rand(12, 32, 512)    # (src_len, batch, embed_dim)
>>> value = torch.rand(12, 32, 512)
>>> attn, attn_weights, position_bias = attention(query, key, value)
>>> attn.shape
torch.Size([10, 32, 512])
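Padded time steps can be excluded from attention via key_padding_mask. The sketch below assumes the convention described above, i.e. a (batch, src_len) boolean mask with True at padded positions:
>>> x = torch.rand(10, 32, 512)  # self-attention: query = key = value
>>> key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
>>> key_padding_mask[:, -3:] = True  # treat the last 3 frames of each sequence as padding
>>> attn, attn_weights, _ = attention(x, x, x, key_padding_mask=key_padding_mask)
>>> attn.shape
torch.Size([10, 32, 512])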
- Raises: AssertionError – If the shapes of query, key, or value do not match the expected (seq_len, batch, embed_dim) layout.
reset_parameters()
Initialize the parameters of the attention module.
This method initializes the weights of the MultiheadAttention module and its components using a scaled Xavier uniform distribution. It is designed to improve the convergence behavior of the model during training.
The initialization process includes the following steps:
- Initializes the weights of the query, key, and value projection layers (k_proj, v_proj, q_proj) using Xavier uniform initialization.
- Initializes the output projection layer (out_proj) weights using Xavier uniform initialization.
- Initializes the bias terms (bias_k and bias_v) to zero, if they are defined.
- If relative attention bias is used, initializes the corresponding weights.
A log message is emitted to indicate that the parameters have been initialized.
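As a rough illustration of the initialization pattern listed above (not the verbatim ESPnet code), and assuming the projection attributes named in that list:
>>> import math
>>> import torch.nn as nn
>>> attention = MultiheadAttention(embed_dim=512, num_heads=8)
>>> gain = 1 / math.sqrt(2)  # scaled gain when q, k, and v share the same dimension; exact gain may differ
>>> for proj in (attention.k_proj, attention.v_proj, attention.q_proj):
...     _ = nn.init.xavier_uniform_(proj.weight, gain=gain)
>>> _ = nn.init.xavier_uniform_(attention.out_proj.weight)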
############# Examples
>>> attention_layer = MultiheadAttention(embed_dim=512, num_heads=8)
>>> attention_layer.reset_parameters()
######### NOTE This method should be called after creating an instance of the MultiheadAttention class to ensure that the parameters are set before training.
- Raises: RuntimeError – If the embedding dimension is not divisible by the number of attention heads.