espnet2.asr.encoder.contextual_block_transformer_encoder.ContextualBlockTransformerEncoder
class espnet2.asr.encoder.contextual_block_transformer_encoder.ContextualBlockTransformerEncoder(input_size: int, output_size: int = 256, attention_heads: int = 4, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, input_layer: str | None = 'conv2d', pos_enc_class=<class 'espnet.nets.pytorch_backend.transformer.embedding.StreamPositionalEncoding'>, normalize_before: bool = True, concat_after: bool = False, positionwise_layer_type: str = 'linear', positionwise_conv_kernel_size: int = 1, padding_idx: int = -1, block_size: int = 40, hop_size: int = 16, look_ahead: int = 16, init_average: bool = True, ctx_pos_enc: bool = True)
Bases: AbsEncoder
Contextual Block Transformer encoder module.
This class implements a transformer encoder that utilizes contextual block processing as described in Tsunoo et al. “Transformer ASR with contextual block processing” (https://arxiv.org/abs/1910.07204).
output_size
The dimension of the output from the encoder.
- Type: int
pos_enc
Positional encoding layer used to add positional information to the input embeddings.
embed
Input layer that transforms the input features into the appropriate embedding size.
encoders
A series of contextual block encoder layers.
normalize_before
Indicates if layer normalization should be applied before the first encoder block.
- Type: bool
after_norm
Layer normalization applied after the encoder layers.
block_size
Size of the blocks for contextual processing.
- Type: int
hop_size
The step size to move between blocks during processing.
- Type: int
look_ahead
The size of the look-ahead window for processing.
- Type: int
init_average
Indicates if the initial context vector should be an average of the block.
- Type: bool
ctx_pos_enc
Whether to apply positional encoding to the context vectors.
- Type: bool
Parameters:
- input_size (int) – Input dimension.
- output_size (int , optional) – Dimension of attention (default: 256).
- attention_heads (int , optional) – Number of heads for multi-head attention (default: 4).
- linear_units (int , optional) – Number of units in position-wise feedforward layer (default: 2048).
- num_blocks (int , optional) – Number of encoder blocks (default: 6).
- dropout_rate (float , optional) – Dropout rate (default: 0.1).
- positional_dropout_rate (float , optional) – Dropout rate after adding positional encoding (default: 0.1).
- attention_dropout_rate (float , optional) – Dropout rate in attention (default: 0.0).
- input_layer (str , optional) – Type of input layer (default: “conv2d”).
- pos_enc_class – Class for positional encoding (default: StreamPositionalEncoding).
- normalize_before (bool , optional) – Use layer normalization before the first block (default: True).
- concat_after (bool , optional) – Concatenate input and output of attention layer (default: False).
- positionwise_layer_type (str , optional) – Type of position-wise layer (“linear” or “conv1d”, default: “linear”).
- positionwise_conv_kernel_size (int , optional) – Kernel size for position-wise conv1d layer (default: 1).
- padding_idx (int , optional) – Padding index for input_layer=embed (default: -1).
- block_size (int , optional) – Block size for contextual processing (default: 40); a sketch of how block_size, hop_size, and look_ahead partition an utterance follows this parameter list.
- hop_size (int , optional) – Hop size for block processing (default: 16).
- look_ahead (int , optional) – Look-ahead size for block processing (default: 16).
- init_average (bool , optional) – Use average as initial context (default: True).
- ctx_pos_enc (bool , optional) – Use positional encoding for context vectors (default: True).
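As referenced above, the interplay of block_size, hop_size, and look_ahead can be illustrated with a small, self-contained calculation. This is only a sketch of overlapping fixed-size windows under the stated defaults, not the encoder's exact internal indexing; the helper block_starts is hypothetical.

# Illustrative only: overlapping windows of length `block_size` advanced by
# `hop_size`; the trailing `look_ahead` frames of each window act as future
# context. The actual segmentation is performed inside the encoder.
def block_starts(num_frames: int, block_size: int = 40, hop_size: int = 16):
    starts = [0]
    while starts[-1] + block_size < num_frames:
        starts.append(starts[-1] + hop_size)
    return starts

# With the defaults, a 100-frame utterance yields blocks starting at
# frames 0, 16, 32, 48, and 64, each spanning up to 40 frames.
print(block_starts(100))  # [0, 16, 32, 48, 64]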
Examples

encoder = ContextualBlockTransformerEncoder(
    input_size=80,
    output_size=256,
    attention_heads=4,
    linear_units=2048,
    num_blocks=6,
    dropout_rate=0.1,
    block_size=40,
    hop_size=16,
    look_ahead=16,
)

# Forward pass
xs_pad = torch.randn(10, 100, 80)  # (Batch, Length, Dimension)
ilens = torch.tensor([100] * 10)   # Input lengths
output, olens, _ = encoder(xs_pad, ilens)
NOTE: This encoder is specifically designed for automatic speech recognition (ASR) tasks using transformer architectures with a focus on contextual block processing.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None, is_final=True, infer_mode=False) → Tuple[Tensor, Tensor, Tensor | None]
Encode the input tensor, dispatching to forward_train or forward_infer depending on infer_mode.
- Parameters:
- xs_pad – input tensor (B, L, D)
- ilens – input length (B)
- prev_states – Not used currently.
- is_final – Whether this is the final call of a streaming inference session (only relevant when infer_mode is True; see forward_infer).
- infer_mode – Whether the encoder is used for inference. This distinguishes forward_train (train and validate) from forward_infer (decode); see the example below.
- Returns: Output tensor (B, L', D), output lengths (B), and an optional next-states tensor.
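A minimal usage sketch of the dispatch behavior, assuming an 80-dimensional feature front end (the shapes and lengths below are illustrative, not required by the API):

import torch

encoder = ContextualBlockTransformerEncoder(input_size=80, output_size=256)

xs_pad = torch.randn(4, 120, 80)          # (Batch, Length, Dim)
ilens = torch.tensor([120, 110, 96, 80])  # valid frames per utterance

# Training / validation path: full utterances, infer_mode left at False.
encoder.train()
ys, olens, _ = encoder(xs_pad, ilens)

# Inference path: batch size must be 1; is_final and prev_states allow the
# encoder to be driven chunk by chunk during decoding (see forward_infer).
encoder.eval()
with torch.no_grad():
    ys1, olens1, states = encoder(xs_pad[:1], ilens[:1], is_final=True, infer_mode=True)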
forward_infer(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None, is_final: bool = True) → Tuple[Tensor, Tensor, Tensor | None]
Perform inference on input tensor using the encoder.
This method processes the input tensor during inference, handling the previous states and the internal buffers required for block processing. It is designed for a batch size of one and is intended for scenarios where the model generates outputs (e.g., decoding); a chunk-by-chunk streaming sketch is given at the end of this entry.
- Parameters:
- xs_pad – Input tensor of shape (B, L, D), where B is the batch size, L is the sequence length, and D is the feature dimension.
- ilens – Tensor containing the lengths of each input sequence in the batch (B).
- prev_states – Optional tensor containing the previous states for maintaining context across calls. If None, the function initializes new state variables.
- is_final – Boolean indicating if this is the final inference step. If True, the function finalizes any state management and returns the output tensor. Otherwise, it prepares for the next inference step.
- Returns: A tuple containing:
  - A tensor of shape (B, L_out, D) representing the output from the encoder, where L_out is the output sequence length.
  - A tensor containing the output lengths (B).
  - An optional dictionary with next states for continued processing, or None if this was the final step.
Examples
>>> encoder = ContextualBlockTransformerEncoder(...)
>>> xs_pad = torch.randn(1, 50, 256) # Example input
>>> ilens = torch.tensor([50]) # Lengths of input
>>> output, output_lengths, next_states = encoder.forward_infer(xs_pad, ilens)
NOTE: This method assumes a batch size of one. If the input tensor has a different batch size, an assertion error will be raised.
- Raises: AssertionError – If the batch size of the input tensor is not 1.
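For streaming recognition, forward_infer is typically called once per incoming feature chunk, feeding the returned next states back in as prev_states. The following is a hedged sketch under assumed chunk sizes; the chunking itself is not prescribed by the API:

import torch

encoder = ContextualBlockTransformerEncoder(input_size=80, output_size=256)
encoder.eval()

# Pretend these chunks arrive one at a time from a streaming front end.
chunks = torch.randn(1, 200, 80).split(40, dim=1)  # five 40-frame chunks

states = None
outputs = []
with torch.no_grad():
    for i, chunk in enumerate(chunks):
        is_final = i == len(chunks) - 1
        ilens = torch.tensor([chunk.size(1)])
        ys, olens, states = encoder.forward_infer(
            chunk, ilens, prev_states=states, is_final=is_final
        )
        if ys.size(1) > 0:  # early chunks may be buffered internally
            outputs.append(ys)

full_output = torch.cat(outputs, dim=1)  # (1, L_out, output_size)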
forward_train(xs_pad: Tensor, ilens: Tensor, prev_states: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor | None]
Perform forward pass for training and validation.
This method processes the input tensor during training, applying contextual block processing and attention mechanisms to produce the output tensor. The function also generates masks for the input sequence to manage padding effectively.
- Parameters:
- xs_pad – Input tensor of shape (B, L, D) where B is the batch size, L is the sequence length, and D is the feature dimension.
- ilens – Input lengths tensor of shape (B) indicating the actual lengths of each sequence in the batch.
- prev_states – (Optional) A dictionary containing the previous states for stateful processing. Currently, it is not used.
- Returns: A tuple containing:
  - The output tensor after processing (B, L', D), where L' is the length of the output sequence.
  - A tensor containing the lengths of the output sequences (B).
  - An optional tensor for future state management, currently set to None.
Examples
>>> encoder = ContextualBlockTransformerEncoder(...)
>>> xs_pad = torch.randn(2, 50, 256) # Example input
>>> ilens = torch.tensor([50, 40]) # Example input lengths
>>> output, olens, _ = encoder.forward_train(xs_pad, ilens)
NOTE: The method is designed to handle both training and validation modes based on the state of the encoder.
- Raises: ValueError – If the input tensor dimensions do not match expected shapes or if an invalid input_layer type is provided during encoder initialization.
output_size
Get the output size of the encoder.
This method returns the output size of the encoder, which is set by the output_size argument when the ContextualBlockTransformerEncoder is initialized and equals the attention dimension used throughout the encoder.
- Returns: The output size of the encoder.
- Return type: int
Examples
>>> encoder = ContextualBlockTransformerEncoder(input_size=128)
>>> encoder.output_size()
256 # Assuming the default output_size is 256