espnet2.asr_transducer.decoder.blocks.mega.MEGA
class espnet2.asr_transducer.decoder.blocks.mega.MEGA(size: int = 512, num_heads: int = 4, qk_size: int = 128, v_size: int = 1024, activation: ~torch.nn.modules.module.Module = ReLU(), normalization: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.normalization.LayerNorm'>, rel_pos_bias_type: str = 'simple', max_positions: int = 2048, truncation_length: int | None = None, chunk_size: int = -1, dropout_rate: float = 0.0, att_dropout_rate: float = 0.0, ema_dropout_rate: float = 0.0)
Bases: Module
Moving Average Equipped Gated Attention (MEGA) block definition.
Based/modified from https://github.com/facebookresearch/mega/blob/main/fairseq/modules/moving_average_gated_attention.py
Most variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/mega/modeling_mega.py.
multihead_damped_ema
Multi-head damped exponential moving average module.
rel_pos_bias
Relative position bias module based on the specified type.
proj_v
Linear projection layer for value tensor.
proj_mx
Linear projection layer for combined inputs.
proj_h
Linear projection layer for output tensor.
qk_weight
Learnable parameter for query/key weights.
qk_bias
Learnable parameter for query/key biases.
scaling
Scaling factor for query tensor.
activation
Activation function applied within the module.
normalization
Normalization function applied at the output.
dropout
Dropout layer for inner modules.
att_dropout
Dropout layer for the attention module.
ema_dropout
Dropout layer for the EMA module.
- Parameters:
- size (int) – Input/Output size.
- num_heads (int) – Number of EMA heads.
- qk_size (int) – Shared query and key size for attention module.
- v_size (int) – Value size for attention module.
- activation (torch.nn.Module) – Activation function type.
- normalization (torch.nn.Module) – Normalization module.
- rel_pos_bias_type (str) – Type of relative position bias in attention module.
- max_positions (int) – Maximum number of positions for RelativePositionBias.
- truncation_length (Optional[int]) – Maximum length for truncation in EMA module.
- chunk_size (int) – Chunk size for attention computation (-1 = full context).
- dropout_rate (float) – Dropout rate for inner modules.
- att_dropout_rate (float) – Dropout rate for the attention module.
- ema_dropout_rate (float) – Dropout rate for the EMA module.
########### Examples
>>> mega = MEGA(size=512, num_heads=4)
>>> input_tensor = torch.randn(10, 32, 512) # (L, B, size)
>>> output, state = mega(input_tensor)
- Raises: ValueError – If rel_pos_bias_type is not ‘rotary’ or ‘simple’.
Construct a MEGA object.
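A hedged configuration sketch, assuming only the constructor arguments listed above (the bias type, dropout rates, and tensor shapes are illustrative):
>>> import torch
>>> from espnet2.asr_transducer.decoder.blocks.mega import MEGA
>>> mega = MEGA(
...     size=512,
...     num_heads=4,
...     qk_size=128,
...     v_size=1024,
...     rel_pos_bias_type="rotary",
...     dropout_rate=0.1,
...     att_dropout_rate=0.1,
... )
>>> x = torch.randn(20, 8, 512)  # (L, B, size)
>>> out, state = mega(x)  # out: (L, B, size)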
forward(x: Tensor, mask: Tensor | None = None, attn_mask: Tensor | None = None, state: Dict[str, Tensor | None] | None = None) → Tuple[Tensor, Dict[str, Tensor | None] | None]
Compute Moving Average Equipped Gated Attention.
- Parameters:
- x – MEGA input sequences. (L, B, size)
- mask – Input mask applied to the attention scores, if provided.
- attn_mask – Attention mask applied to the attention scores, if provided.
- state – Decoder hidden states from a previous call, if any.
- Returns:
- x – MEGA output sequences. (L, B, size)
- state – Updated decoder hidden states.
########### Examples
>>> mega = MEGA(size=512, num_heads=4)
>>> input_tensor = torch.randn(10, 32, 512) # (L, B, size)
>>> output, state = mega(input_tensor)
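A hedged sketch of stateful use, assuming the returned state dict can be fed back through the state argument on the next call (this mirrors the signature rather than a verified decoding recipe; shapes are illustrative):
>>> mega = MEGA(size=512, num_heads=4)
>>> x = torch.randn(1, 1, 512)  # single step: (L=1, B=1, size)
>>> out, state = mega(x, state=None)
>>> next_x = torch.randn(1, 1, 512)
>>> out, state = mega(next_x, state=state)  # carry the previous hidden states forward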
reset_parameters(val: int = 0.0, std: int = 0.02) → None
Reset module parameters.
This method initializes the weights and biases of the various linear layers and parameters within the MEGA module. The initialization is done using a normal distribution for weights and a constant value for biases.
- Parameters:
- val – Initialization value for biases. Default is 0.0.
- std – Standard deviation for the normal distribution used to initialize weights. Default is 0.02.
########### Examples
>>> mega = MEGA()
>>> mega.reset_parameters(val=0.1, std=0.01)
>>> # Resets parameters with biases initialized to 0.1 and weights
>>> # drawn from a normal distribution with standard deviation 0.01.
NOTE
This method should be called to reinitialize the parameters if needed, such as when retraining or experimenting with different initialization strategies.
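A hedged sketch of the equivalent initialization pattern for a single linear layer (illustrative only; the module applies the normal/constant scheme described above to its own projections and parameters):
>>> import torch.nn as nn
>>> layer = nn.Linear(512, 1024)
>>> _ = nn.init.normal_(layer.weight, mean=0.0, std=0.02)  # weights drawn from N(0, std=0.02)
>>> _ = nn.init.constant_(layer.bias, 0.0)  # biases set to the constant val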
softmax_attention(query: Tensor, key: Tensor, mask: Tensor | None = None, attn_mask: Tensor | None = None) → Tensor
Compute attention weights with the softmax function.
- Parameters:
- query – Query tensor.
- key – Key tensor.
- mask – Input mask applied to the attention scores, if provided.
- attn_mask – Attention mask applied to the attention scores, if provided.
- Returns: Attention weights.
########### Examples
>>> mega = MEGA(size=512, num_heads=4)
>>> query = torch.randn(2, 1, 10, 128)
>>> key = torch.randn(2, 1, 10, 128)
>>> attn_weights = mega.softmax_attention(query, key)
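A minimal sketch of the underlying computation, assuming the standard formulation (dot-product scores plus a relative position bias, followed by a softmax over key positions); the module's actual chunking, masking, and query scaling are omitted:
>>> bias = torch.zeros(10, 10)  # placeholder for the relative position bias term
>>> scores = torch.matmul(query, key.transpose(-2, -1)) + bias  # (..., L_q, L_k)
>>> weights = torch.softmax(scores, dim=-1)  # each row sums to 1 over key positions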
NOTE
The MEGA module is designed for efficient attention computation with moving averages, making it suitable for tasks in natural language processing and other sequential data applications.