espnet2.asr.decoder.transformer_decoder.BaseTransformerDecoder
class espnet2.asr.decoder.transformer_decoder.BaseTransformerDecoder(vocab_size: int, encoder_output_size: int, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, input_layer: str = 'embed', use_output_layer: bool = True, pos_enc_class=<class 'espnet.nets.pytorch_backend.transformer.embedding.PositionalEncoding'>, normalize_before: bool = True)
Bases: AbsDecoder, BatchScorerInterface, MaskParallelScorerInterface
Base class of Transformer decoder module.
This class implements the shared structure of a Transformer decoder used in automatic speech recognition (ASR): the input embedding, the final layer normalization, the output projection, and the forward and scoring methods used for training and decoding. The stack of decoder layers itself is supplied by subclasses.
embed
Input embedding layer, which can be an embedding layer followed by positional encoding or a linear layer.
- Type: torch.nn.Sequential
after_norm
Layer normalization applied to the output of the last decoder block when normalize_before is True (pre-norm blocks leave their output unnormalized).
- Type: LayerNorm
output_layer
Output layer that maps decoder outputs to the vocabulary size, if use_output_layer is True.
- Type: torch.nn.Linear
_output_size_bf_softmax
Dimension of the output before applying softmax, set to the attention dimension.
- Type: int
decoders
Stack of decoder layers; left as None in this class and set by subclasses.
- Type: List[DecoderLayer]
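batch_ids
Cached index tensor over the batch, used by forward_partially_AR to select each sequence's output at its last valid position; None until first needed.
- Type: torch.Tensor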
Parameters:
- vocab_size (int) – The size of the output vocabulary.
- encoder_output_size (int) – The dimension of the encoder’s output.
- dropout_rate (float, optional) – Dropout rate for regularization. Default is 0.1.
- positional_dropout_rate (float, optional) – Dropout rate for positional encoding. Default is 0.1.
- input_layer (str, optional) – Type of input layer ('embed' or 'linear'). Default is 'embed'.
- use_output_layer (bool, optional) – Whether to project decoder outputs to the vocabulary with output_layer. Default is True.
- pos_enc_class (Type[PositionalEncoding], optional) – Class for positional encoding. Default is PositionalEncoding.
- normalize_before (bool, optional) – Whether layer normalization is applied before each sublayer (pre-norm); in that case after_norm normalizes the output of the last block. Default is True.
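The interplay between normalize_before and after_norm can be illustrated without PyTorch. The following is a minimal sketch, not ESPnet code: layer_norm, block and decoder_stack are hypothetical helpers, the norm has no learned scale or bias, and plain functions stand in for the attention and feed-forward sublayers.

```python
import math
from typing import Callable, List

Vector = List[float]
Sublayer = Callable[[Vector], Vector]


def layer_norm(x: Vector, eps: float = 1e-12) -> Vector:
    """Normalize to zero mean and unit variance (no learned scale or bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum used for the residual connection."""
    return [u + v for u, v in zip(a, b)]


def block(x: Vector, sublayer: Sublayer, normalize_before: bool) -> Vector:
    """One residual sub-block in pre-norm or post-norm arrangement."""
    if normalize_before:
        # pre-norm: normalize the input, apply the sublayer, add the residual
        return add(x, sublayer(layer_norm(x)))
    # post-norm: add the residual first, then normalize the sum
    return layer_norm(add(x, sublayer(x)))


def decoder_stack(x: Vector, sublayers: List[Sublayer], normalize_before: bool) -> Vector:
    """Run the sub-blocks in order, then apply a final norm for pre-norm stacks."""
    for sublayer in sublayers:
        x = block(x, sublayer, normalize_before)
    if normalize_before:
        # plays the role of after_norm: the residual stream is still unnormalized
        x = layer_norm(x)
    return x
```

In the pre-norm arrangement the residual stream is never normalized inside the stack, which is why a final normalization is needed; post-norm blocks already end with one.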
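Examples:
>>> import torch
>>> from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
>>> # BaseTransformerDecoder leaves `decoders` unset, so use a concrete subclass
>>> decoder = TransformerDecoder(
...     vocab_size=5000, encoder_output_size=256, dropout_rate=0.1, input_layer="embed"
... )
>>> hs_pad = torch.randn(2, 20, 256)  # encoder output (batch, maxlen_in, feat)
>>> hlens = torch.tensor([20, 15])  # valid encoder lengths
>>> ys_in_pad = torch.randint(0, 5000, (2, 7))  # input token IDs
>>> ys_in_lens = torch.tensor([7, 5])  # valid target lengths
>>> output, lengths = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
>>> output.shape
torch.Size([2, 7, 5000])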
- Raises: ValueError – If input_layer is not 'embed' or 'linear'.
NOTE: This class should be inherited by specific Transformer decoder implementations, which provide the concrete decoder layer configuration.
batch_score(ys: Tensor, states: List[Any], xs: Tensor, return_hs: bool = False) → Tuple[Tensor, List[Any]]
Score new token batch.
This method computes the scores for the next token based on the input tokens, the current states, and the encoder features. It can return hidden states if requested.
- Parameters:
- ys (torch.Tensor) – A tensor of shape (n_batch, ylen) containing the prefix tokens in int64 format.
- states (List[Any]) – A list of states for the prefix tokens, one per hypothesis; each entry is that hypothesis's decoder cache, or None before the first step.
- xs (torch.Tensor) – A tensor of shape (n_batch, xlen, n_feat) representing the encoder features that generated the prefix tokens.
- return_hs (bool, optional) – If True, the method also returns the decoder hidden states. Defaults to False.
- Returns: A tuple containing:
- A tensor of shape (n_batch, n_vocab) with the scores (log-probabilities) for the next token.
- A list of next-state lists for the prefix tokens.
- Return type: Tuple[torch.Tensor, List[Any]]
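Because batch_score keeps one state per hypothesis but runs every decoder layer on the whole batch at once, the states have to be transposed between a [batch][layer] and a [layer][batch] layout around the batched call. A minimal sketch of that bookkeeping, assuming each hypothesis state is a list with one cached entry per decoder layer; merge_states and split_states are hypothetical names, and plain lists stand in for the tensors that would be combined with torch.stack.

```python
from typing import Any, List, Optional


def merge_states(
    states: List[Optional[List[Any]]], n_layers: int
) -> Optional[List[List[Any]]]:
    """Turn per-hypothesis states [batch][layer] into per-layer states [layer][batch]."""
    if states[0] is None:
        # first decoding step: no cache exists yet
        return None
    n_batch = len(states)
    return [[states[b][i] for b in range(n_batch)] for i in range(n_layers)]


def split_states(batch_state: List[List[Any]], n_batch: int) -> List[List[Any]]:
    """Inverse transposition: per-layer [layer][batch] back to [batch][layer]."""
    n_layers = len(batch_state)
    return [[batch_state[i][b] for i in range(n_layers)] for b in range(n_batch)]
```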
Examples:
>>> import torch
>>> from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
>>> decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256)
>>> ys = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> states = [None, None]  # no cached states before the first step
>>> xs = torch.rand(2, 10, 256)
>>> scores, next_states = decoder.batch_score(ys, states, xs)
>>> scores.shape
torch.Size([2, 100])
NOTE: Ensure that the input tensors are properly shaped and that the states list matches the format expected by the decoder.
- Raises: ValueError – If the dimensions of the input tensors do not match the expected shapes.
batch_score_partially_AR(ys: Tensor, states: List[Any], xs: Tensor, yseq_lengths: Tensor) → Tuple[Tensor, List[Any]]
Score a batch of new tokens in a partially autoregressive manner.
This method evaluates a batch of token sequences and returns their scores along with the updated states. It is specifically designed for use in scenarios where the sequences are generated in a partially autoregressive manner.
- Parameters:
- ys (torch.Tensor) – Tensor of shape (n_mask * n_beam, ylen) containing the token sequences for which scores are to be computed. Each element is an integer representing a token ID.
- states (List[Any]) – A list of states for the scorer, where each state corresponds to a prefix of tokens in ys.
- xs (torch.Tensor) – Tensor of shape (n_batch, xlen, n_feat) representing the encoder features that are used to generate the scores for the sequences in ys.
- yseq_lengths (torch.Tensor) – Tensor of shape (n_mask * n_beam,) containing the lengths of the sequences in ys, used to create the appropriate attention masks.
- Returns: A tuple where the first element is a tensor of shape (n_mask * n_beam, n_vocab) containing the scores for the next token in the sequence, and the second element is a list of updated states for each sequence in ys.
- Return type: Tuple[torch.Tensor, List[Any]]
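Since the rows of ys are padded to a common ylen, the next-token score of each row has to be read at its own last real position rather than at the final column, which is what yseq_lengths makes possible. A framework-free sketch, assuming the per-position decoder outputs are already computed and that scores are log-softmax normalized; gather_last_valid and log_softmax are hypothetical helpers, not ESPnet functions.

```python
import math
from typing import List

Matrix = List[List[float]]


def gather_last_valid(hidden: List[Matrix], lengths: List[int]) -> Matrix:
    """Pick, for each padded sequence, the output at its last real position."""
    # hidden[b][t] is the output vector of sequence b at time step t
    return [hidden[b][lengths[b] - 1] for b in range(len(hidden))]


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one score vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]
```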
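Examples:
>>> import torch
>>> from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
>>> decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256)
>>> ys = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> states = [None, None]
>>> xs = torch.randn(2, 10, 256)  # example encoder features
>>> yseq_lengths = torch.tensor([3, 3])
>>> scores, updated_states = decoder.batch_score_partially_AR(
...     ys, states, xs, yseq_lengths
... )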
NOTE: This method requires that the lengths in yseq_lengths are consistent with the sequences in ys. It also assumes that appropriate padding has been applied to ys and xs.
forward(hs_pad: Tensor, hlens: Tensor, ys_in_pad: Tensor, ys_in_lens: Tensor, return_hs: bool = False, return_all_hs: bool = False) → Tuple[Tensor, Tensor]
Forward decoder.
This method takes the encoded memory and input token IDs to produce the output token scores before softmax. It can also return hidden states if specified.
- Parameters:
- hs_pad (torch.Tensor) – Encoded memory with shape (batch, maxlen_in, feat).
- hlens (torch.Tensor) – Lengths of the encoded sequences with shape (batch).
- ys_in_pad (torch.Tensor) – Input token IDs with shape (batch, maxlen_out). If input_layer is “embed”, it represents token IDs; otherwise, it should be a tensor of shape (batch, maxlen_out, #mels).
- ys_in_lens (torch.Tensor) – Lengths of the input sequences with shape (batch).
- return_hs (bool, optional) – Whether to return the last hidden output before the output layer. Defaults to False.
- return_all_hs (bool, optional) – Whether to return all hidden intermediate states. Defaults to False.
- Returns: A tuple containing:
- x (torch.Tensor): Decoded token scores before softmax, with shape (batch, maxlen_out, token) if use_output_layer is True, otherwise (batch, maxlen_out, attention_dim).
- olens (torch.Tensor): Lengths of the output sequences with shape (batch,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
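In forward, each target position must be kept from attending to padding and to later tokens, and source attention must skip the padded encoder frames indicated by hlens. A list-based sketch of how a padding mask and a causal mask combine into one (batch, maxlen_out, maxlen_out) target mask, assuming the convention that True means "may attend"; the helper names are hypothetical (ESPnet's transformer utilities provide make_pad_mask and subsequent_mask for the tensor versions).

```python
from typing import List

Mask = List[List[bool]]


def padding_mask(lengths: List[int], maxlen: int) -> Mask:
    """True where a position holds real data, False where it is padding."""
    return [[t < n for t in range(maxlen)] for n in lengths]


def causal_mask(size: int) -> Mask:
    """Lower-triangular mask: position i may attend to positions j <= i."""
    return [[j <= i for j in range(size)] for i in range(size)]


def target_mask(ys_in_lens: List[int], maxlen_out: int) -> List[Mask]:
    """Combine padding and causal masks into one (batch, maxlen_out, maxlen_out) mask."""
    pad = padding_mask(ys_in_lens, maxlen_out)
    causal = causal_mask(maxlen_out)
    return [
        [[pad[b][j] and causal[i][j] for j in range(maxlen_out)] for i in range(maxlen_out)]
        for b in range(len(ys_in_lens))
    ]
```

The memory mask for source attention is the padding mask built from hlens, with one row broadcast over all target positions.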
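Examples:
>>> import torch
>>> from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
>>> decoder = TransformerDecoder(vocab_size=100, encoder_output_size=512)
>>> hs_pad = torch.randn(2, 10, 512)  # encoder output (batch, maxlen_in, feat)
>>> hlens = torch.tensor([10, 8])  # valid lengths of the encoder output
>>> ys_in_pad = torch.tensor([[1, 2, 3], [4, 5, 6]])  # input token IDs
>>> ys_in_lens = torch.tensor([3, 3])  # lengths of the input sequences
>>> output_scores, output_lengths = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
>>> output_scores.shape
torch.Size([2, 3, 100])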
NOTE: Ensure that the input tensors are appropriately padded and have the dimensions described in the arguments.
- Raises: ValueError – If the shapes of the input tensors do not match the expected dimensions or if any input is malformed.
forward_one_step(tgt: Tensor, tgt_mask: Tensor, memory: Tensor, memory_mask: Tensor | None = None, *, cache: List[Tensor] | None = None, return_hs: bool = False) → Tuple[Tensor, List[Tensor]]
Forward one step.
- Parameters:
- tgt – input token IDs, int64 (batch, maxlen_out)
- tgt_mask – input token mask, (batch, maxlen_out, maxlen_out); a causal mask of shape (1, maxlen_out, maxlen_out) broadcasts over the batch. dtype=torch.uint8 before PyTorch 1.2, dtype=torch.bool in PyTorch 1.2 and later
- memory – encoded memory, float32 (batch, maxlen_in, feat)
- memory_mask – encoded memory mask (batch, 1, maxlen_in)
- cache – cached output list of (batch, max_time_out-1, size), one entry per decoder layer
- return_hs – whether to also return the decoder hidden state corresponding to ys
- Returns: the network output y and the updated cache, one entry per layer in self.decoders. y holds the next-token scores (log-probabilities when the output layer is used) with shape (batch, token).
- Return type: y, cache
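The cache argument exists because, in a causal decoder, outputs at earlier positions do not change when a token is appended, so each step only needs to compute the newest position and concatenate it onto the cached outputs; this is why each cache entry is listed with max_time_out-1 positions. A toy sketch of that idea, where a prefix-mean function stands in for a real decoder layer (both function names are hypothetical).

```python
from typing import List, Optional


def causal_layer_full(xs: List[float]) -> List[float]:
    """Toy causal layer: output t is the mean of inputs 0..t."""
    return [sum(xs[: t + 1]) / (t + 1) for t in range(len(xs))]


def causal_layer_step(xs: List[float], cache: Optional[List[float]]) -> List[float]:
    """Same layer, reusing cached outputs and computing only the newest position."""
    if cache is None:
        # first call: nothing cached, compute every position
        return causal_layer_full(xs)
    if len(cache) != len(xs) - 1:
        raise ValueError("cache must hold outputs for all earlier positions")
    newest = sum(xs) / len(xs)  # output for the last position only
    return cache + [newest]  # concatenate along the time axis
```

Incremental and full computation give identical results; the incremental version only produces one new output per step instead of recomputing every position.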
forward_partially_AR(tgt: Tensor, tgt_mask: Tensor, tgt_lengths: Tensor, memory: Tensor, cache: List[Tensor] | None = None) → Tuple[Tensor, List[Tensor]]
Forward one step in a partially autoregressive manner.
This method processes the input tokens and computes the output scores for the next tokens in a partially autoregressive manner, allowing for efficient decoding during beam search or similar scenarios.
- Parameters:
- tgt – Input token ids, int64 of shape (n_mask * n_beam, maxlen_out).
- tgt_mask – Input token mask; batch_score_partially_AR passes a causal mask of shape (1, maxlen_out, maxlen_out) that broadcasts over n_mask * n_beam. The data type should be torch.uint8 for PyTorch versions < 1.2 and torch.bool for PyTorch versions >= 1.2.
- tgt_lengths – Lengths of the input sequences, shape (n_mask * n_beam,).
- memory – Encoded memory from the encoder, float32 of shape (batch, maxlen_in, feat).
- cache – Cached output list for each decoder layer, which can be used to store previous hidden states and facilitate efficient decoding.
- Returns:
- y: Next-token scores (log-probabilities when the output layer is used), float32 of shape (n_mask * n_beam, vocab_size), matching the score shape returned by batch_score_partially_AR.
- cache: Updated cache containing hidden states for each decoder layer.
- Return type: Tuple[torch.Tensor, List[torch.Tensor]]
Examples:
>>> import torch
>>> from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
>>> from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
>>> decoder = TransformerDecoder(vocab_size=1000,
...                              encoder_output_size=512)
>>> tgt = torch.randint(0, 1000, (2, 10))  # 2 sequences of length 10
>>> tgt_mask = subsequent_mask(10).unsqueeze(0)  # causal mask, (1, 10, 10)
>>> tgt_lengths = torch.tensor([10, 10])
>>> memory = torch.randn(2, 15, 512)  # 2 encoder outputs of length 15
>>> output, updated_cache = decoder.forward_partially_AR(
...     tgt, tgt_mask, tgt_lengths, memory
... )
NOTE: This method is particularly useful when decoding needs to be efficient, such as during beam search or when generating sequences with partial context.
score(ys, state, x, return_hs=False)
Compute the score for a given input sequence.
This method performs a single step of scoring for the input sequence ys based on the provided state and encoder output x. It can optionally return the hidden state if requested.
- Parameters:
- ys (torch.Tensor) – Input token IDs, shape (maxlen_out,), of type int64.
- state (List[torch.Tensor] or None) – Cached decoder states from previous steps, one per decoder layer; None at the first step.
- x (torch.Tensor) – Encoded memory for a single utterance, shape (maxlen_in, feat), of type float32; a batch dimension is added internally.
- return_hs (bool, optional) – If True, the hidden state corresponding to the input tokens is also returned. Defaults to False.
- Returns:
- logp (torch.Tensor): Log probabilities of the next token, shape (vocab_size).
- state (List[torch.Tensor]): Updated state after scoring.
- Return type: Tuple[torch.Tensor, List[torch.Tensor]]
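score is the single-hypothesis entry point that a search algorithm calls repeatedly, feeding the returned state back in on the next call. A minimal greedy-decoding sketch of that loop, assuming a callable with the same (ys, state, x) -> (logp, state) shape as score; greedy_decode and the list-based inputs are hypothetical, whereas the real method takes torch tensors and returns a log-probability tensor.

```python
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical stand-in for decoder.score: (ys, state, x) -> (logp, new_state)
ScoreFn = Callable[[List[int], Optional[Any], Any], Tuple[List[float], Any]]


def greedy_decode(score: ScoreFn, x: Any, sos: int, eos: int, maxlen: int) -> List[int]:
    """Extend one hypothesis token by token, threading the decoder state."""
    ys = [sos]
    state = None  # no cache before the first step
    for _ in range(maxlen):
        logp, state = score(ys, state, x)
        # pick the highest-scoring next token
        best = max(range(len(logp)), key=lambda k: logp[k])
        ys.append(best)
        if best == eos:
            break
    return ys
```

A beam search follows the same pattern but keeps several hypotheses, each with its own state.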
Examples:
>>> import torch
>>> from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
>>> decoder = TransformerDecoder(vocab_size=5000,
...                              encoder_output_size=256)
>>> ys = torch.tensor([1, 2, 3])  # example token IDs
>>> state = None  # no cache before the first step
>>> x = torch.randn(10, 256)  # encoder output for one utterance
>>> logp, new_state = decoder.score(ys, state, x)
>>> logp.shape
torch.Size([5000])
NOTE: This method is typically used during decoding to compute the next token's probability given the previous tokens and the encoder output.