espnet2.asr.encoder.branchformer_encoder.BranchformerEncoderLayer
class espnet2.asr.encoder.branchformer_encoder.BranchformerEncoderLayer(size: int, attn: Module | None, cgmlp: Module | None, dropout_rate: float, merge_method: str, cgmlp_weight: float = 0.5, attn_branch_drop_rate: float = 0.0, stochastic_depth_rate: float = 0.0)
Bases: Module
Branchformer encoder layer module.
This class implements a single layer of the Branchformer encoder, which utilizes both multi-headed self-attention and Convolutional Gating MLP (CGMLP) branches to capture local and global context in speech recognition tasks. The output of the two branches can be merged using different methods such as concatenation, learned average, or fixed average.
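To make the merge options concrete, the following is a minimal, hypothetical sketch of the merging idea only, not the ESPnet implementation. It assumes both branch outputs `x_att` and `x_mlp` have shape (batch, time, size); the function name and the `merge_proj` argument are illustrative.

```python
import torch


def merge_branches(
    x_att: torch.Tensor,
    x_mlp: torch.Tensor,
    merge_proj: torch.nn.Module,
    merge_method: str,
    cgmlp_weight: float = 0.5,
) -> torch.Tensor:
    """Sketch of combining the attention and CGMLP branch outputs."""
    if merge_method == "concat":
        # Concatenate along the feature dimension; here merge_proj is assumed
        # to map 2 * size back to size.
        return merge_proj(torch.cat([x_att, x_mlp], dim=-1))
    if merge_method == "fixed_ave":
        # Fixed weighted sum of the two branches; merge_proj is assumed to be
        # a size -> size projection. 'learned_ave' would instead compute the
        # weights per frame from learned scores followed by a softmax.
        return merge_proj((1.0 - cgmlp_weight) * x_att + cgmlp_weight * x_mlp)
    raise ValueError(f"unknown merge method: {merge_method}")
```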
- Parameters:
- size (int) – The model dimension.
- attn (Optional[torch.nn.Module]) – The self-attention module to use.
- cgmlp (Optional[torch.nn.Module]) – The CGMLP module to use.
- dropout_rate (float) – The dropout probability for the layers.
- merge_method (str) – The method to merge outputs from branches. Options: ‘concat’, ‘learned_ave’, or ‘fixed_ave’.
- cgmlp_weight (float) – Weight of the CGMLP branch (0 to 1) used in ‘fixed_ave’ merge method.
- attn_branch_drop_rate (float) – Dropout probability for the attention branch used in ‘learned_ave’ merge method.
- stochastic_depth_rate (float) – Probability of applying stochastic depth to the layer.
size
The model dimension.
- Type: int
attn
The self-attention module.
- Type: Optional[torch.nn.Module]
cgmlp
The CGMLP module.
- Type: Optional[torch.nn.Module]
merge_method
The method used to merge outputs.
- Type: str
cgmlp_weight
Weight of the CGMLP branch.
- Type: float
attn_branch_drop_rate
Dropout rate for the attention branch.
- Type: float
stochastic_depth_rate
Stochastic depth probability.
- Type: float
use_two_branches
Flag indicating if both branches are used.
- Type: bool
norm_mha
Layer normalization for the MHA module.
- Type: LayerNorm
norm_mlp
Layer normalization for the MLP module.
- Type: LayerNorm
norm_final
Layer normalization for the final output.
- Type: LayerNorm
dropout
Dropout layer.
- Type: Dropout
merge_proj
Projection layer for merging outputs.
- Type: torch.nn.Module
Raises: ValueError – If an unknown merge method is provided or if cgmlp_weight is not in the range [0, 1].
Examples
>>> layer = BranchformerEncoderLayer(
... size=256,
... attn=MultiHeadedAttention(4, 256),
... cgmlp=ConvolutionalGatingMLP(256, 2048, 31),
... dropout_rate=0.1,
... merge_method='learned_ave',
... cgmlp_weight=0.5,
... attn_branch_drop_rate=0.2,
... stochastic_depth_rate=0.1
... )
>>> x_input = torch.randn(32, 10, 256) # (batch_size, seq_len, size)
>>> mask = torch.ones(32, 1, 10) # (batch_size, 1, seq_len)
>>> output, output_mask = layer(x_input, mask)
NOTE
This implementation includes support for stochastic depth, which can be beneficial for regularizing deep networks by randomly skipping layers during training.
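As a rough illustration of the stochastic-depth idea, here is a sketch under the assumption that the layer's contribution is added to a residual stream; the function and variable names are hypothetical and do not mirror the module's internals.

```python
import torch


def stochastic_depth_residual(
    residual: torch.Tensor,
    layer_out: torch.Tensor,
    p: float,
    training: bool,
) -> torch.Tensor:
    """Sketch: randomly skip a layer's contribution during training."""
    if training and p > 0.0:
        # With probability p, skip this layer entirely and pass the residual
        # through unchanged.
        if torch.rand(()).item() < p:
            return residual
        # When the layer is kept, rescale its output by 1 / (1 - p) so that
        # the expected contribution matches inference, where nothing is skipped.
        return residual + layer_out / (1.0 - p)
    return residual + layer_out
```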
forward(x_input, mask, cache=None)
Compute encoded features.
- Parameters:
- x_input (Union[Tuple, torch.Tensor]) – Input tensor w/ or w/o pos emb.
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
- w/o pos emb: Tensor (#batch, time, size).
- mask (torch.Tensor) – Mask tensor for the input (#batch, 1, time).
- cache (torch.Tensor) – Cache tensor of the input (#batch, time - 1, size).
- Returns: Output tensor (#batch, time, size) and mask tensor (#batch, time).
- Return type: Tuple[torch.Tensor, torch.Tensor]
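A hypothetical call with a relative positional embedding, following the shape conventions above. It assumes `layer` was constructed as in the Examples section but with an attention module that consumes a positional embedding (e.g. ESPnet's RelPositionMultiHeadedAttention), and that the layer then returns the (output, pos_emb) pair together with the mask; this is a sketch, not verified against a specific ESPnet version.

```python
import torch

# `layer` is assumed to be a BranchformerEncoderLayer built with a
# relative-position-aware attention module (see Examples above).
x = torch.randn(32, 10, 256)       # (#batch, time, size)
pos_emb = torch.randn(1, 10, 256)  # (1, time, size)
mask = torch.ones(32, 1, 10)       # (#batch, 1, time)

(x_out, pos_emb_out), mask_out = layer((x, pos_emb), mask)
```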