espnet2.asr_transducer.encoder.blocks.ebranchformer.EBranchformer
class espnet2.asr_transducer.encoder.blocks.ebranchformer.EBranchformer(block_size: int, linear_size: int, self_att: torch.nn.Module, feed_forward: torch.nn.Module, feed_forward_macaron: torch.nn.Module, conv_mod: torch.nn.Module, depthwise_conv_mod: torch.nn.Module, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0)
Bases: Module
E-Branchformer block for Transducer encoder.
This class implements the E-Branchformer block used in the Transducer encoder. It combines a global self-attention branch with a local convolutional gating branch, wrapped in macaron-style feed-forward layers, so each block captures both long-range and local dependencies in the input sequences.
Reference: https://arxiv.org/pdf/2210.00077.pdf
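At a high level, the block applies a macaron-style feed-forward, runs two parallel branches (global self-attention and local convolutional spatial gating), merges them with a depthwise convolution, and finishes with a second feed-forward and a final normalization. The following is a minimal, hedged sketch of that data flow written against the attributes documented below; the exact call signatures (e.g., whether conv_mod and depthwise_conv_mod also take the mask) and projection details are assumptions for illustration, not the exact ESPnet implementation:

>>> import torch
>>> def ebranchformer_flow(blk, x, pos_enc, mask):
...     # First macaron feed-forward with a half-step residual connection.
...     x = x + 0.5 * blk.dropout(blk.feed_forward_macaron(blk.norm_feed_forward_macaron(x)))
...     # Global branch: self-attention over the (masked) sequence.
...     x_att = blk.norm_self_att(x)
...     x_att = blk.dropout(blk.self_att(x_att, x_att, x_att, pos_enc, mask))
...     # Local branch: convolutional spatial gating (cgMLP-style).
...     x_conv = blk.norm_mlp(x)
...     x_conv = blk.dropout(blk.conv_mod(x_conv))
...     # Merge both branches with a depthwise convolution, then a residual add.
...     x = x + blk.dropout(blk.depthwise_conv_mod(torch.cat([x_att, x_conv], dim=-1)))
...     # Second macaron feed-forward, then final normalization.
...     x = x + 0.5 * blk.dropout(blk.feed_forward(blk.norm_feed_forward(x)))
...     return blk.norm_final(x)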
self_att
Instance of the self-attention module.
- Type: torch.nn.Module
feed_forward
Instance of the feed-forward module.
- Type: torch.nn.Module
feed_forward_macaron
Instance of the macaron feed-forward module.
- Type: torch.nn.Module
conv_mod
Instance of the ConvolutionalSpatialGatingUnit module.
- Type: torch.nn.Module
depthwise_conv_mod
Instance of the DepthwiseConvolution module.
- Type: torch.nn.Module
norm_self_att
Normalization layer for self-attention.
- Type: torch.nn.Module
norm_feed_forward
Normalization layer for feed-forward network.
- Type: torch.nn.Module
norm_feed_forward_macaron
Normalization layer for macaron feed-forward.
- Type: torch.nn.Module
norm_mlp
Normalization layer for MLP.
- Type: torch.nn.Module
norm_final
Final normalization layer.
- Type: torch.nn.Module
dropout
Dropout layer for regularization.
- Type: torch.nn.Dropout
block_size
Input/output size of the block.
- Type: int
linear_size
Hidden size of the linear layers.
- Type: int
cache
Cache for self-attention and convolution.
- Type: Optional[List[torch.Tensor]]
Parameters:
- block_size (int) – Input/output size.
- linear_size (int) – Linear layers’ hidden size.
- self_att (torch.nn.Module) – Self-attention module instance.
- feed_forward (torch.nn.Module) – Feed-forward module instance.
- feed_forward_macaron (torch.nn.Module) – Feed-forward module instance for macaron network.
- conv_mod (torch.nn.Module) – ConvolutionalSpatialGatingUnit module instance.
- depthwise_conv_mod (torch.nn.Module) – DepthwiseConvolution module instance.
- norm_class (torch.nn.Module, optional) – Normalization class (default: LayerNorm).
- norm_args (Dict, optional) – Normalization module arguments (default: {}).
- dropout_rate (float, optional) – Dropout rate (default: 0.0).
Examples:

Create an E-Branchformer instance:

>>> e_branchformer = EBranchformer(
...     block_size=128,
...     linear_size=256,
...     self_att=my_self_att_module,
...     feed_forward=my_feed_forward_module,
...     feed_forward_macaron=my_macaron_feed_forward_module,
...     conv_mod=my_conv_module,
...     depthwise_conv_mod=my_depthwise_conv_module,
...     norm_class=torch.nn.LayerNorm,
...     norm_args={'eps': 1e-5},
...     dropout_rate=0.1,
... )

Forward pass through the E-Branchformer:

>>> output, mask, pos_enc = e_branchformer(input_tensor, pos_enc_tensor, mask_tensor)

Reset the streaming cache:

>>> e_branchformer.reset_streaming_cache(left_context=10, device=torch.device('cuda'))
NOTE: The module requires specific instances of self-attention, feed-forward, convolution, and normalization modules. These modules should be defined prior to instantiating the EBranchformer class.
Construct an E-Branchformer object.
chunk_forward(x: Tensor, pos_enc: Tensor, mask: Tensor, left_context: int = 0) → Tuple[Tensor, Tensor]
Encode a chunk of the input sequence.
This method processes a chunk of the input sequence through the E-Branchformer architecture, incorporating self-attention and feed-forward mechanisms while considering the specified left context for attention.
- Parameters:
- x – E-Branchformer input sequences. Shape: (B, T, D_block), where B is the batch size, T is the sequence length, and D_block is the dimensionality of the input features.
- pos_enc – Positional embedding sequences. Shape: (B, 2 * (T - 1), D_block).
- mask – Source mask. Shape: (B, T_2), used to prevent attention to certain positions in the input.
- left_context – Number of previous frames the attention module can see in the current chunk. Defaults to 0.
- Returns:
- x: E-Branchformer output sequences. Shape: (B, T, D_block).
- pos_enc: Positional embedding sequences. Shape: (B, 2 * (T - 1), D_block).
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples:
>>> e_branchformer = EBranchformer(block_size=128, linear_size=256,
... self_att=my_self_att,
... feed_forward=my_feed_forward,
... feed_forward_macaron=my_feed_forward_macaron,
... conv_mod=my_conv_mod,
... depthwise_conv_mod=my_depthwise_conv_mod)
>>> x = torch.randn(10, 20, 128) # Batch of 10, sequence length 20
>>> pos_enc = torch.randn(10, 38, 128)  # 2 * (T - 1) = 38
>>> mask = torch.ones(10, 20) # Full attention mask
>>> output, pos_enc_out = e_branchformer.chunk_forward(x, pos_enc, mask)
NOTE: This method maintains a cache for self-attention and convolutional layers to optimize processing of sequential data.
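As a usage sketch, chunk_forward pairs with reset_streaming_cache in a streaming loop. The snippet below is illustrative only: stream_of_chunks and make_pos_enc are hypothetical placeholders, and the chunk size and left context are arbitrary values:

>>> chunk_size, left_context = 16, 8  # illustrative values
>>> device = torch.device('cpu')
>>> # Initialize the cache once before the first chunk.
>>> e_branchformer.reset_streaming_cache(left_context, device)
>>> for chunk in stream_of_chunks:  # hypothetical iterable of (B, chunk_size, D_block) tensors
...     pos_enc = make_pos_enc(chunk)  # hypothetical helper -> (B, 2 * (chunk_size - 1), D_block)
...     mask = torch.ones(chunk.size(0), chunk.size(1), device=device)
...     chunk, pos_enc = e_branchformer.chunk_forward(
...         chunk, pos_enc, mask, left_context=left_context)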
forward(x: Tensor, pos_enc: Tensor, mask: Tensor, chunk_mask: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor]
Encode input sequences using the E-Branchformer module.
The forward method processes input sequences through the macaron feed-forward, self-attention, and convolutional gating layers to produce the output sequences. It supports masking for the attention mechanism and uses residual connections for better gradient flow.
- Parameters:
- x (torch.Tensor) – E-Branchformer input sequences of shape (B, T, D_block), where B is the batch size, T is the sequence length, and D_block is the dimensionality of the block.
- pos_enc (torch.Tensor) – Positional embedding sequences of shape (B, 2 * (T - 1), D_block), representing the position of each token in the sequence.
- mask (torch.Tensor) – Source mask of shape (B, T) to specify which tokens should be attended to.
- chunk_mask (Optional[torch.Tensor]) – Optional chunk mask of shape (T_2, T_2) to control attention within chunks.
- Returns: A tuple containing:
  - x (torch.Tensor): E-Branchformer output sequences of shape (B, T, D_block).
  - mask (torch.Tensor): Source mask of shape (B, T).
  - pos_enc (torch.Tensor): Positional embedding sequences of shape (B, 2 * (T - 1), D_block).
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Examples:
>>> model = EBranchformer(block_size=256, linear_size=128,
... self_att=self_attention_module,
... feed_forward=feed_forward_module,
... feed_forward_macaron=feed_forward_macaron_module,
... conv_mod=conv_module,
... depthwise_conv_mod=depthwise_conv_module)
>>> input_tensor = torch.randn(32, 10, 256) # Batch of 32, 10 time steps
>>> pos_enc = torch.randn(32, 18, 256)  # 2 * (T - 1) = 18
>>> mask = torch.ones(32, 10) # Full attention
>>> output, mask_out, pos_enc_out = model.forward(input_tensor, pos_enc, mask)
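When training with chunk-wise attention, chunk_mask is typically a square boolean matrix restricting which frames may attend to which. Below is a minimal sketch of building a block-diagonal chunk mask; the helper is illustrative and not part of ESPnet's API, it assumes True means "attention allowed", and ESPnet's own mask construction may additionally permit left context across chunk boundaries:

>>> import torch
>>> def make_chunk_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
...     # Each frame attends only to frames inside the same chunk.
...     idx = torch.arange(seq_len)
...     return (idx.unsqueeze(0) // chunk_size) == (idx.unsqueeze(1) // chunk_size)
>>> chunk_mask = make_chunk_mask(10, chunk_size=5)  # matches T=10 in the example above
>>> output, mask_out, pos_enc_out = model.forward(input_tensor, pos_enc, mask,
...                                               chunk_mask=chunk_mask)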
reset_streaming_cache(left_context: int, device: device) → None
Initialize/Reset self-attention and convolution modules cache for streaming.
This method initializes or resets the cache for the self-attention and convolution modules to facilitate streaming processing. The cache is used to store intermediate results from previous frames, allowing the model to maintain context over longer sequences without having to recompute past information.
- Parameters:
- left_context – Number of previous frames the attention module can see in current chunk. This parameter controls how much of the past context is considered during attention calculations.
- device – Device to use for cache tensor. This allows the cache to be created on the appropriate hardware (CPU or GPU) for efficient processing.
Examples:
>>> model = EBranchformer(...)
>>> model.reset_streaming_cache(left_context=5, device=torch.device('cuda'))
NOTE: This method should be called before processing a new input chunk to ensure that the cache is properly initialized.