espnet2.speechlm.module.transformer.TransformerDecoder
class espnet2.speechlm.module.transformer.TransformerDecoder(n_ctx: int, n_state: int, n_head: int, n_layer: int, causal: bool = True, layer_class=<class 'espnet2.speechlm.module.transformer.ResidualAttentionBlock'>)
Bases: Module
TransformerDecoder implements a Transformer decoder-only architecture.
This class is part of the ESPnet framework and provides a straightforward implementation of a Transformer decoder. It supports stacked Transformer layers and positional embeddings, but includes neither an embedding table nor a language-model head. The attention mechanism relies on PyTorch’s built-in flash attention, so PyTorch 2.0.1 or later is recommended.
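Because the class ships without an embedding table or language-model head, it is typically wrapped by a model that supplies both. Below is a minimal sketch of such a wrapper; the `TinyLM` class and its attribute names are illustrative assumptions, not part of ESPnet:

```python
import torch
import torch.nn as nn
from espnet2.speechlm.module.transformer import TransformerDecoder

class TinyLM(nn.Module):
    """Hypothetical wrapper adding the embedding table and LM head
    that TransformerDecoder intentionally omits."""

    def __init__(self, vocab_size: int, n_ctx: int = 512,
                 n_state: int = 768, n_head: int = 12, n_layer: int = 6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, n_state)
        self.decoder = TransformerDecoder(n_ctx, n_state, n_head, n_layer)
        self.lm_head = nn.Linear(n_state, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch_size, seq_len) -> logits: (batch_size, seq_len, vocab_size)
        return self.lm_head(self.decoder(self.embed(tokens)))

lm = TinyLM(vocab_size=1000)
logits = lm(torch.randint(0, 1000, (2, 10)))  # shape: (2, 10, 1000)
```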
pos_emb
Positional embeddings for the input sequences.
- Type: nn.Embedding
blocks
List of residual attention blocks.
- Type: nn.ModuleList
ln
Layer normalization applied at the end.
- Type: LayerNorm
causal
Indicates if the decoder should operate in causal mode.
- Type: bool
Parameters:
- n_ctx (int) – The number of context tokens.
- n_state (int) – The size of the hidden state.
- n_head (int) – The number of attention heads.
- n_layer (int) – The number of stacked layers.
- causal (bool) – If True, enables causal attention (default: True).
- layer_class (type) – The class used for the attention layers (default: ResidualAttentionBlock).
Returns: The output tensor after passing through the decoder.
Return type: Tensor
Raises: ValueError – If causal attention is enabled and a mask is provided.
Examples
>>> decoder = TransformerDecoder(n_ctx=512, n_state=768, n_head=12,
... n_layer=6)
>>> x = torch.randn(1, 10, 768) # (batch_size, seq_len, n_state)
>>> output = decoder(x) # output shape: (1, 10, 768)
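When the decoder is built with causal=False, an explicit mask may be passed instead of the implicit causal one. The sketch below assumes the common additive-bias convention (0 where attention is allowed, -inf where it is blocked); verify the exact mask semantics against the attention-block implementation before relying on them:

```python
import torch

# Bidirectional mode: causal=False permits an explicit mask.
decoder = TransformerDecoder(n_ctx=512, n_state=768, n_head=12,
                             n_layer=6, causal=False)
x = torch.randn(1, 10, 768)

# Assumed convention: additive attention bias over the
# (seq_len, seq_len) score matrix. Here the last two positions
# (e.g. padding) are blocked for all queries.
mask = torch.zeros(10, 10)
mask[:, 8:] = float("-inf")

output = decoder(x, mask=mask)  # shape: (1, 10, 768)
# Passing a mask while causal=True raises ValueError instead.
```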
NOTE
This implementation intentionally remains simple and may not cover all configuration choices available in other libraries like HuggingFace.
forward(x: Tensor, mask: Tensor = None, kv_cache: dict = None)
Applies the forward pass of the TransformerDecoder.
This method processes the input tensor x through the Transformer decoder architecture, applying positional embeddings and a series of attention blocks. It supports causal and masked attention mechanisms.
- Parameters:
- x (Tensor) – The input tensor of shape (batch_size, seq_length, n_state).
- mask (torch.Tensor, optional) – A tensor used to mask out certain positions in the input. Not allowed when causal is True.
- kv_cache (Optional[dict], optional) – A cache of key-value pairs used by the attention layers to make incremental decoding more efficient. If provided, it should contain the pre-computed keys and values from earlier decoding steps.
- Returns: The output tensor after passing through the Transformer decoder layers, of shape (batch_size, seq_length, n_state).
- Return type: Tensor
- Raises: ValueError – If causal is True and mask is not None.
Examples
>>> decoder = TransformerDecoder(n_ctx=512, n_state=768, n_head=12, n_layer=6)
>>> input_tensor = torch.randn(2, 10, 768) # (batch_size, seq_length, n_state)
>>> output = decoder(input_tensor)
>>> output.shape
torch.Size([2, 10, 768])
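For step-by-step generation, kv_cache lets each call reuse the keys and values computed on earlier steps instead of re-encoding the whole prefix. The cache layout is managed internally by the attention blocks; the sketch below only assumes that an initially empty dict can be passed and is filled in place, which should be verified against the attention-block implementation:

```python
import torch

decoder = TransformerDecoder(n_ctx=512, n_state=768, n_head=12, n_layer=6)
kv_cache: dict = {}

# Prefill: run the prompt once so the blocks can populate the cache.
prompt = torch.randn(1, 8, 768)   # (batch_size, prompt_len, n_state)
out = decoder(prompt, kv_cache=kv_cache)

# Decode: feed one new frame at a time; cached keys/values spare us
# from recomputing attention over the already-seen prefix.
for _ in range(4):
    step = torch.randn(1, 1, 768)  # stand-in for the newest embedding
    out = decoder(step, kv_cache=kv_cache)
```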