espnet2.asr_transducer.decoder.blocks.mega.MEGA
class espnet2.asr_transducer.decoder.blocks.mega.MEGA(size: int = 512, num_heads: int = 4, qk_size: int = 128, v_size: int = 1024, activation: ~torch.nn.modules.module.Module = ReLU(), normalization: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.normalization.LayerNorm'>, rel_pos_bias_type: str = 'simple', max_positions: int = 2048, truncation_length: int | None = None, chunk_size: int = -1, dropout_rate: float = 0.0, att_dropout_rate: float = 0.0, ema_dropout_rate: float = 0.0)
Bases: Module
MEGA module.
- Parameters:
- size – Input/Output size.
- num_heads – Number of EMA heads.
- qk_size – Shared query and key size for attention module.
- v_size – Value size for attention module.
- activation – Activation function type.
- normalization – Normalization module.
- rel_pos_bias_type – Type of relative position bias in attention module.
- max_positions – Maximum number of positions for RelativePositionBias.
- truncation_length – Maximum length for truncation in EMA module.
- chunk_size – Chunk size for attention computation (-1 = full context).
- dropout_rate – Dropout rate for inner modules.
- att_dropout_rate – Dropout rate for the attention module.
- ema_dropout_rate – Dropout rate for the EMA module.
Construct a MEGA object.
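A minimal construction sketch (PyTorch), assuming the import path shown above; the sizes below are illustrative, not recommended settings:

```python
import torch

from espnet2.asr_transducer.decoder.blocks.mega import MEGA

# Build a MEGA block. All values below are illustrative; the documented
# defaults are size=512, num_heads=4, qk_size=128, v_size=1024.
mega = MEGA(
    size=256,
    num_heads=4,
    qk_size=64,
    v_size=512,
    chunk_size=-1,       # -1 = attention over the full context
    dropout_rate=0.1,
)
mega.reset_parameters()  # (re-)initialize weights (val=0.0, std=0.02 by default)
```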
forward(x: Tensor, mask: Tensor | None = None, attn_mask: Tensor | None = None, state: Dict[str, Tensor | None] | None = None) → Tuple[Tensor, Dict[str, Tensor | None] | None]
Compute moving average equipped gated attention.
- Parameters:
- x – MEGA input sequences. (L, B, size)
- mask – MEGA input sequence masks. (B, 1, L)
- attn_mask – MEGA attention mask. (1, L, L)
- state – Decoder hidden states.
- Returns:
- x – MEGA output sequences. (B, L, size)
- state – Decoder hidden states.
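A hedged call sketch for forward, following the shapes documented above (the mask conventions are defined by the wrapping decoder, so both masks are left as None here):

```python
import torch

from espnet2.asr_transducer.decoder.blocks.mega import MEGA

L, B, size = 20, 2, 256  # sequence length, batch size, model size

mega = MEGA(size=size, num_heads=4, qk_size=64, v_size=512)

x = torch.randn(L, B, size)  # input sequences, (L, B, size)

# mask (B, 1, L) and attn_mask (1, L, L) are optional; None means no padding
# and full (unmasked) attention.
out, state = mega(x, mask=None, attn_mask=None, state=None)
# out: MEGA output sequences, documented as (B, L, size)
# state: decoder hidden states, used for incremental decoding
```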
reset_parameters(val: float = 0.0, std: float = 0.02) → None
Reset module parameters.
- Parameters:
- val – Initialization value.
- std – Standard deviation.
softmax_attention(query: Tensor, key: Tensor, mask: Tensor | None = None, attn_mask: Tensor | None = None) → Tensor
Compute attention weights with softmax.
- Parameters:
- query – Query tensor. (B, 1, L, D)
- key – Key tensor. (B, 1, L, D)
- mask – Sequence mask. (B, 1, L)
- attn_mask – Attention mask. (1, L, L)
- Returns: Attention weights. (B, 1, L, L)
- Return type: attn_weights
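A shape-only sketch of softmax_attention, assuming the method can be called directly on standalone tensors; D is the shared qk_size and the inputs are random, purely to illustrate the documented shapes:

```python
import torch

from espnet2.asr_transducer.decoder.blocks.mega import MEGA

B, L, D = 2, 20, 64  # batch size, sequence length, shared query/key size (qk_size)

mega = MEGA(size=256, qk_size=D, v_size=512)

query = torch.randn(B, 1, L, D)
key = torch.randn(B, 1, L, D)

attn_weights = mega.softmax_attention(query, key, mask=None, attn_mask=None)
# attn_weights: (B, 1, L, L), one normalized weight row per query position
```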
