espnet2.asr_transducer.encoder.blocks.branchformer.Branchformer
class espnet2.asr_transducer.encoder.blocks.branchformer.Branchformer(block_size: int, linear_size: int, self_att: torch.nn.Module, conv_mod: torch.nn.Module, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0)
Bases: Module
Branchformer block for Transducer encoder.
This class implements the Branchformer block, which encodes input sequences with two parallel branches: a self-attention branch that captures global context, and an MLP branch with convolutional gating (cgMLP) that captures local context. The branch outputs are concatenated and merged by a linear projection, with layer normalization and dropout applied for regularization.
Reference: https://arxiv.org/pdf/2207.02971.pdf
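The per-block computation can be sketched as follows. This is a minimal sketch built from the attributes documented below, not the actual ESPnet implementation; in particular, the call signatures of self_att and conv_mod are assumptions.

```python
import torch

def branchformer_block(block, x, pos_enc, mask):
    # Attention branch: pre-norm, then self-attention with relative
    # positional encodings (global context). The exact self_att call
    # signature depends on the module passed at construction time.
    x_att = block.norm_self_att(x)
    x_att = block.dropout(block.self_att(x_att, x_att, x_att, pos_enc, mask))

    # cgMLP branch: pre-norm, channel up-projection to linear_size,
    # convolutional spatial gating, then projection back to block_size
    # (local context).
    x_conv = block.norm_mlp(x)
    x_conv = block.channel_proj1(x_conv)  # (B, T, linear_size)
    x_conv = block.conv_mod(x_conv)       # convolutional gating
    x_conv = block.dropout(block.channel_proj2(x_conv))

    # Concatenate both branches, merge, and add the residual connection.
    x = x + block.dropout(block.merge_proj(torch.cat([x_att, x_conv], dim=-1)))

    return block.norm_final(x), mask, pos_enc
```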
self_att
The self-attention module instance.
- Type: torch.nn.Module
conv_mod
The convolution module instance.
- Type: torch.nn.Module
channel_proj1
A sequential layer for channel projection.
- Type: torch.nn.Sequential
channel_proj2
A linear layer for projecting back to the original block size.
- Type: torch.nn.Linear
merge_proj
A linear layer for merging outputs from attention and convolution.
- Type: torch.nn.Linear
norm_self_att
Normalization layer for self-attention.
- Type: torch.nn.Module
norm_mlp
Normalization layer for the MLP.
- Type: torch.nn.Module
norm_final
Final normalization layer.
- Type: torch.nn.Module
dropout
Dropout layer for regularization.
- Type: torch.nn.Dropout
block_size
Input/output size.
- Type: int
linear_size
Linear layers’ hidden size.
- Type: int
cache
Cache for storing intermediate results during streaming.
- Type: Optional[List[torch.Tensor]]
Parameters:
- block_size (int) – Input/output size.
- linear_size (int) – Linear layers’ hidden size.
- self_att (torch.nn.Module) – Self-attention module instance.
- conv_mod (torch.nn.Module) – Convolution module instance.
- norm_class (torch.nn.Module, optional) – Normalization class. Defaults to torch.nn.LayerNorm.
- norm_args (Dict, optional) – Normalization module arguments. Defaults to {}.
- dropout_rate (float, optional) – Dropout rate. Defaults to 0.0.
######### Examples
```python
# Initialize a Branchformer module; my_self_attention_module and
# my_convolution_module are placeholders for compatible ESPnet modules.
branchformer = Branchformer(
    block_size=256, linear_size=512,
    self_att=my_self_attention_module, conv_mod=my_convolution_module,
    norm_class=torch.nn.LayerNorm, norm_args={"eps": 1e-6}, dropout_rate=0.1,
)

# Forward pass
output, mask, pos_enc = branchformer(x, pos_enc, mask)

# Reset cache for streaming
branchformer.reset_streaming_cache(left_context=10, device=torch.device("cuda"))
```
- Raises: ValueError – If the input tensor dimensions do not match the expected sizes.
Construct a Branchformer object.
chunk_forward(x: torch.Tensor, pos_enc: torch.Tensor, mask: torch.Tensor, left_context: int = 0) -> Tuple[torch.Tensor, torch.Tensor]
Encode a chunk of the input sequence.
- Parameters:
- x – Branchformer input sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask – Source mask. (B, T_2)
- left_context – Number of previous frames the attention module can see in the current chunk.
- Returns:
- x – Branchformer output sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- Return type: Tuple[torch.Tensor, torch.Tensor]
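A minimal streaming sketch, assuming the caller prepares fixed-size chunks with matching positional embeddings and masks (chunk preparation and shapes are illustrative, not taken from ESPnet):

```python
import torch

# branchformer is a constructed Branchformer instance; chunks is assumed
# to be an iterable of (x, pos_enc, mask) tensors prepared by the caller.
left_context = 10
branchformer.reset_streaming_cache(left_context, device=torch.device("cpu"))

for x_chunk, pos_enc_chunk, mask_chunk in chunks:
    x_chunk, pos_enc_chunk = branchformer.chunk_forward(
        x_chunk, pos_enc_chunk, mask_chunk, left_context=left_context
    )
    # x_chunk holds the encoded chunk; the attention and convolution
    # caches carry left_context frames of history into the next call.
```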
forward(x: torch.Tensor, pos_enc: torch.Tensor, mask: torch.Tensor, chunk_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Encode input sequences.
- Parameters:
- x – Branchformer input sequences. (B, T, D_block)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask – Source mask. (B, T)
- chunk_mask (Optional) – Chunk mask. (T_2, T_2)
- Returns:
- x – Branchformer output sequences. (B, T, D_block)
- mask – Source mask. (B, T)
- pos_enc – Positional embedding sequences. (B, 2 * (T - 1), D_block)
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
######### Examples
Example usage of the forward method; my_self_attention_module and my_convolution_module are placeholders for compatible ESPnet modules (e.g. a relative-position self-attention module and a convolutional spatial gating module).

```python
branchformer = Branchformer(
    block_size=256, linear_size=512,
    self_att=my_self_attention_module, conv_mod=my_convolution_module,
)

x = torch.randn(32, 10, 256)        # (B, T, D_block)
pos_enc = torch.randn(32, 18, 256)  # (B, 2 * (T - 1), D_block)
mask = torch.ones(32, 10)           # (B, T)

output, mask_out, pos_enc_out = branchformer.forward(x, pos_enc, mask)
```
reset_streaming_cache(left_context: int, device: torch.device) -> None
Initialize or reset the self-attention and convolution modules' cache for streaming.
This method resets the internal cache used by the self-attention and convolution modules. It creates new tensors to hold the cached values for both attention and convolution, which are essential for processing streaming data.
- Parameters:
- left_context – Number of previous frames the attention module can see in the current chunk. This defines how much context is available for attention calculations.
- device – Device to use for cache tensor. This specifies where the cache tensors should be allocated (e.g., CPU or GPU).
######### Examples
>>> model = Branchformer(...)
>>> model.reset_streaming_cache(left_context=10, device=torch.device('cuda'))
NOTE
This method should be called whenever the input context changes or when starting a new streaming session to ensure the cache is appropriately initialized.