espnet2.speechlm.module.transformer.ResidualAttentionBlock
class espnet2.speechlm.module.transformer.ResidualAttentionBlock(n_state: int, n_head: int, cross_attention: bool = False, causal: bool = False)
Bases: Module
A residual attention block that combines multi-head attention and a feed-forward
network with residual connections and layer normalization.
This block is a building block for Transformer-like architectures and supports both self-attention and cross-attention. Layer normalization and residual connections are applied around each sub-layer to improve training stability and model performance.
attn
The multi-head self-attention layer.
- Type: MultiHeadAttention
attn_ln
Layer normalization applied to the input of the self-attention layer.
- Type: LayerNorm
cross_attn
The multi-head cross-attention layer, if cross_attention is enabled.
- Type: Optional[MultiHeadAttention]
cross_attn_ln
Layer normalization applied to the input of the cross-attention layer, if cross_attention is enabled.
- Type: Optional[LayerNorm]
mlp
A feed-forward network consisting of two linear layers with a GELU activation in between (see the sketch below).
- Type: Sequential
mlp_ln
Layer normalization applied to the input of the MLP.
- Type: LayerNorm
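Taken together, these attributes describe a Whisper-style pre-norm residual layout. The following is a minimal, hedged sketch of that layout built from stock torch.nn modules (the nn.MultiheadAttention stand-in and the 4 * n_state hidden width are assumptions for illustration, not the ESPnet implementation):
>>> import torch
>>> from torch import nn
>>> n_state, n_head = 256, 8
>>> attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
>>> attn_ln = nn.LayerNorm(n_state)
>>> mlp = nn.Sequential(nn.Linear(n_state, 4 * n_state), nn.GELU(), nn.Linear(4 * n_state, n_state))
>>> mlp_ln = nn.LayerNorm(n_state)
>>> x = torch.randn(2, 16, n_state)
>>> h = attn_ln(x)
>>> x = x + attn(h, h, h, need_weights=False)[0]  # residual self-attention
>>> x = x + mlp(mlp_ln(x))                        # residual feed-forward
>>> x.shape
torch.Size([2, 16, 256])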
Parameters:
- n_state (int) – The dimensionality of the input and output features.
- n_head (int) – The number of attention heads.
- cross_attention (bool , optional) – Whether to enable cross-attention. Defaults to False.
- causal (bool , optional) – Whether to enable causal attention. Defaults to False.
Returns: The output tensor after applying the attention and feed-forward layers.
Return type: Tensor
####### Examples
>>> block = ResidualAttentionBlock(n_state=256, n_head=8)
>>> input_tensor = torch.randn(10, 32, 256) # (batch_size, seq_len, n_state)
>>> output = block(input_tensor)
>>> print(output.shape)
torch.Size([10, 32, 256])
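When the block is constructed with cross_attention=True, a source tensor xa (for example an encoder output, which may have a different sequence length) is passed at call time; a hedged usage sketch:
>>> cross_block = ResidualAttentionBlock(n_state=256, n_head=8, cross_attention=True)
>>> x = torch.randn(10, 32, 256)   # (batch_size, seq_len, n_state)
>>> xa = torch.randn(10, 48, 256)  # cross-attention source, e.g. encoder output
>>> cross_block(x, xa=xa).shape
torch.Size([10, 32, 256])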
NOTE
This module is designed to be used as a building block for Transformer architectures and should be integrated within a larger model for practical applications.
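As a hedged sketch of such integration (the stack depth and shapes below are illustrative assumptions), several blocks can simply be stacked and applied in sequence:
>>> import torch.nn as nn
>>> blocks = nn.ModuleList(
...     [ResidualAttentionBlock(n_state=256, n_head=8) for _ in range(4)]
... )
>>> h = torch.randn(2, 32, 256)
>>> for block in blocks:
...     h = block(h)
>>> h.shape
torch.Size([2, 32, 256])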
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x: Tensor, xa: Tensor | None = None, mask: Tensor | None = None, kv_cache: dict | None = None)
Apply the residual attention block to an input tensor.
The input is passed through layer-normalized multi-head self-attention, an optional cross-attention step (used when the block was constructed with cross_attention=True and xa is provided), and a feed-forward MLP, each wrapped in a residual connection. Attention is causal or non-causal according to the block's configuration.
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, seq_len, n_state).
- xa (Tensor , optional) – Source tensor for cross-attention, e.g. an encoder output. Defaults to None.
- mask (Tensor , optional) – Attention mask applied in self-attention. Defaults to None.
- kv_cache (dict , optional) – Cache of key/value tensors, e.g. for incremental decoding. Defaults to None.
Returns: The output tensor after the attention and MLP layers, with the same shape as x.
Return type: Tensor
####### Examples
>>> block = ResidualAttentionBlock(n_state=256, n_head=8)
>>> input_tensor = torch.rand(10, 20, 256) # (batch_size, seq_len, n_state)
>>> output_tensor = block(input_tensor)
>>> output_tensor.shape
torch.Size([10, 20, 256])
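A block built with causal=True is called the same way; as documented above, the output keeps the input shape (a hedged sketch):
>>> causal_block = ResidualAttentionBlock(n_state=256, n_head=8, causal=True)
>>> causal_block(torch.rand(4, 50, 256)).shape
torch.Size([4, 50, 256])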
- Raises: ValueError – If the cross-attention mechanism is used but no input tensor for cross-attention is provided.
NOTE
The input tensors must have a feature size equal to n_state in the last dimension.