espnet2.asr.decoder.transformer_decoder.TransformerMDDecoder
class espnet2.asr.decoder.transformer_decoder.TransformerMDDecoder(vocab_size: int, encoder_output_size: int, attention_heads: int = 4, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, self_attention_dropout_rate: float = 0.0, src_attention_dropout_rate: float = 0.0, input_layer: str = 'embed', use_output_layer: bool = True, pos_enc_class=<class 'espnet.nets.pytorch_backend.transformer.embedding.PositionalEncoding'>, normalize_before: bool = True, concat_after: bool = False, use_speech_attn: bool = True)
Bases: BaseTransformerDecoder
Transformer decoder with multi-dimensional attention.
This class implements a Transformer decoder that integrates multi-dimensional attention capabilities. It extends BaseTransformerDecoder and can additionally attend over encoded speech features, improving performance on speech-related tasks.
use_speech_attn
Flag indicating whether to use speech attention in the decoding process.
Type: bool
Parameters:
- vocab_size (int) – Size of the vocabulary for output tokens.
- encoder_output_size (int) – Dimensionality of the encoder’s output.
- attention_heads (int) – Number of attention heads for multi-head attention (default: 4).
- linear_units (int) – Number of units in the position-wise feed forward layer (default: 2048).
- num_blocks (int) – Number of decoder blocks (default: 6).
- dropout_rate (float) – Dropout rate applied to the layers (default: 0.1).
- positional_dropout_rate (float) – Dropout rate applied to positional encoding (default: 0.1).
- self_attention_dropout_rate (float) – Dropout rate for self attention (default: 0.0).
- src_attention_dropout_rate (float) – Dropout rate for source attention (default: 0.0).
- input_layer (str) – Type of input layer; either “embed” or “linear” (default: “embed”).
- use_output_layer (bool) – Flag indicating whether to use an output layer (default: True).
- pos_enc_class – Class for positional encoding (default: PositionalEncoding).
- normalize_before (bool) – Flag indicating whether to apply layer normalization before the first block (default: True).
- concat_after (bool) – Flag indicating whether to concatenate attention layer’s input and output (default: False).
- use_speech_attn (bool) – Flag indicating whether to use speech attention (default: True).
Returns: Tuple containing the decoded token scores before softmax and the lengths of the output sequences.
Return type: Tuple[torch.Tensor, torch.Tensor]
############# Examples
>>> import torch
>>> decoder = TransformerMDDecoder(vocab_size=1000,
...                                encoder_output_size=512)
>>> hs_pad = torch.randn(32, 10, 512) # (batch, maxlen_in, feat)
>>> hlens = torch.tensor([10] * 32) # (batch)
>>> ys_in_pad = torch.randint(0, 1000, (32, 20)) # (batch, maxlen_out)
>>> ys_in_lens = torch.tensor([20] * 32) # (batch)
>>> output, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
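Speech attention can also be disabled to use the decoder on text-only encoder outputs (a minimal sketch reusing the tensors above; use_speech_attn is the constructor flag documented in Parameters):
>>> text_decoder = TransformerMDDecoder(vocab_size=1000,
...                                     encoder_output_size=512,
...                                     use_speech_attn=False)
>>> output, olens = text_decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)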
######### NOTE
This decoder is particularly useful for tasks that process both text and speech inputs, enabling better context understanding and more accurate outputs.
- Raises: ValueError – If the input_layer argument is not one of “embed” or “linear”.
batch_score(ys: Tensor, states: List[Any], xs: Tensor, speech: Tensor | None = None) → Tuple[Tensor, List[Any]]
Score new token batch.
- Parameters:
- ys (torch.Tensor) – A tensor of shape (n_batch, ylen) containing the prefix tokens as int64.
- states (List[Any]) – A list of scorer states corresponding to the prefix tokens.
- xs (torch.Tensor) – A tensor of shape (n_batch, xlen, n_feat) representing the encoder features that generate ys.
- speech (torch.Tensor, optional) – A tensor of shape (n_batch, s_len, n_feat) representing the encoded speech features. Defaults to None.
- Returns: A tuple containing:
  - A tensor of shape (n_batch, n_vocab) representing the batchified scores for the next token.
  - A list of next states for ys.
- Return type: Tuple[torch.Tensor, List[Any]]
############# Examples
>>> import torch
>>> decoder = TransformerMDDecoder(vocab_size=5000, encoder_output_size=256)
>>> ys = torch.randint(0, 5000, (2, 10)) # Example input
>>> states = [None, None] # Example states
>>> xs = torch.rand(2, 15, 256) # Example encoder features
>>> logp, next_states = decoder.batch_score(ys, states, xs)
>>> print(logp.shape) # Output: torch.Size([2, 5000])
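If the decoder was constructed with use_speech_attn=True (the default), encoded speech features can also be supplied at scoring time (a sketch; the speech tensor shape is an illustrative assumption):
>>> speech = torch.rand(2, 30, 256)  # hypothetical encoded speech features
>>> logp, next_states = decoder.batch_score(ys, states, xs, speech=speech)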
######### NOTE
This method performs batch scoring of the next tokens given the input sequences and their corresponding states. It supports optional speech features during scoring.
forward(hs_pad: Tensor, hlens: Tensor, ys_in_pad: Tensor, ys_in_lens: Tensor, speech: Tensor | None = None, speech_lens: Tensor | None = None, return_hs: bool = False) → Tuple[Tensor, Tensor]
Forward pass of the TransformerMDDecoder.
This method performs the forward pass of the TransformerMDDecoder, processing the encoded memory and input token sequences to produce decoded token scores. It can also utilize speech features if provided.
- Parameters:
- hs_pad (torch.Tensor) – Encoded memory, shape (batch, maxlen_in, feat).
- hlens (torch.Tensor) – Lengths of the encoder outputs, shape (batch).
- ys_in_pad (torch.Tensor) – Input token IDs of shape (batch, maxlen_out) if input_layer is “embed”; otherwise an input tensor of shape (batch, maxlen_out, #mels).
- ys_in_lens (torch.Tensor) – Lengths of the input sequences, shape (batch).
- speech (torch.Tensor, optional) – Encoded speech features, shape (batch, maxlen_in, feat). Defaults to None.
- speech_lens (torch.Tensor, optional) – Lengths of the speech sequences, shape (batch). Defaults to None.
- return_hs (bool , optional) – If True, return the last hidden state corresponding to the output. Defaults to False.
- Returns: A tuple containing:
  - x (torch.Tensor): Decoded token scores before softmax, shape (batch, maxlen_out, vocab_size) if use_output_layer is True.
  - olens (torch.Tensor): Lengths of the output sequences, shape (batch,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
############# Examples
>>> import torch
>>> hs_pad = torch.randn(32, 50, 256)  # Example memory tensor
>>> hlens = torch.randint(1, 51, (32,))
>>> ys_in_pad = torch.randint(0, 100, (32, 20)) # Example input IDs
>>> ys_in_lens = torch.randint(1, 21, (32,))
>>> decoder = TransformerMDDecoder(100, 256)
>>> output, output_lengths = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
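To exercise the speech attention branch, encoded speech features and their lengths can be passed as well (a sketch; the speech shapes are illustrative assumptions):
>>> speech = torch.randn(32, 60, 256)  # hypothetical encoded speech features
>>> speech_lens = torch.randint(1, 61, (32,))
>>> output, output_lengths = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens,
...                                  speech=speech, speech_lens=speech_lens)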
######### NOTE
If speech is provided, the decoder leverages the speech attention mechanism for enhanced decoding performance.
forward_one_step(tgt: Tensor, tgt_mask: Tensor, memory: Tensor, memory_mask: Tensor | None = None, *, speech: Tensor | None = None, speech_mask: Tensor | None = None, cache: List[Tensor] | None = None, return_hs: bool = False) → Tuple[Tensor, List[Tensor]]
Forward one step through the decoder.
This method performs a single step of decoding, where the model generates the next token based on the current input tokens, the encoded memory from the encoder, and optionally, speech features and their masks. The output is the predicted token and updated cache for future decoding steps.
- Parameters:
- tgt – Input token IDs, of shape (batch, maxlen_out), where each entry is an integer representing a token.
- tgt_mask – Input token mask of shape (batch, maxlen_out), with dtype torch.uint8 in PyTorch 1.2 and earlier, or dtype torch.bool in PyTorch 1.2+.
- memory – Encoded memory from the encoder, of shape (batch, maxlen_in, feat).
- memory_mask – (Optional) Mask for the memory, of shape (batch, 1, maxlen_in).
- speech – (Optional) Encoded speech features, of shape (batch, maxlen_in, feat).
- speech_mask – (Optional) Mask for the speech, of shape (batch, 1, maxlen_in).
- cache – (Optional) Cached output list from previous steps, of shape (batch, max_time_out-1, size).
- return_hs – Whether to return the hidden state corresponding to the input tokens, useful for debugging or further processing.
- Returns: A tuple containing:
  - y: Output token scores (log probabilities of the next token), of shape (batch, maxlen_out, token).
  - cache: Updated cache for future decoding steps.
- Return type: Tuple[torch.Tensor, List[torch.Tensor]]
############# Examples
>>> import torch
>>> from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
>>> decoder = TransformerMDDecoder(vocab_size=1000, encoder_output_size=512)
>>> tgt = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> tgt_mask = subsequent_mask(tgt.size(1)).unsqueeze(0)  # (1, 3, 3) causal bool mask
>>> memory = torch.rand(2, 10, 512)  # random encoder memory
>>> output, new_cache = decoder.forward_one_step(tgt, tgt_mask, memory)
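During incremental decoding, the method is typically called once per step with a growing causal mask and the cache returned by the previous call (a minimal greedy-decoding sketch; the start token 1 and the five-step bound are illustrative assumptions):
>>> ys = torch.tensor([[1]])  # hypothetical <sos> token, batch of 1
>>> mem = memory[:1]          # decode the first utterance only
>>> cache = None
>>> for _ in range(5):
...     mask = subsequent_mask(ys.size(1)).unsqueeze(0)
...     logp, cache = decoder.forward_one_step(ys, mask, mem, cache=cache)
...     ys = torch.cat([ys, logp.argmax(dim=-1, keepdim=True)], dim=1)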
######### NOTE
Ensure that the input tensor dimensions are compatible with the model’s expectations. The speech and speech_mask parameters are optional and should only be provided if speech features are used in the decoding process.
score(ys, state, x, speech=None)
Calculate the score for a given sequence and update the state.
This method computes the log probability of the next token in the sequence based on the provided input features and updates the internal state of the decoder.
- Parameters:
- ys (torch.Tensor) – The input token IDs of shape (n_tokens,).
- state (List[torch.Tensor]) – The cached state of the decoder.
- x (torch.Tensor) – The encoder output features of shape (input_length, feature_dim).
- speech (torch.Tensor, optional) – The encoded speech features of shape (speech_length, feature_dim). Defaults to None.
- Returns:
- logp (torch.Tensor): Log probabilities for the next token, of shape (vocab_size,).
- state (List[torch.Tensor]): The updated state of the decoder.
- Return type: Tuple[torch.Tensor, List[torch.Tensor]]
############# Examples
>>> import torch
>>> decoder = TransformerMDDecoder(vocab_size=1000, encoder_output_size=512)
>>> ys = torch.tensor([1, 2, 3])  # Example token IDs
>>> state = None                  # Initial (empty) decoder cache
>>> x = torch.randn(10, 512)      # Example encoder output
>>> logp, updated_state = decoder.score(ys, state, x)
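As with batch_score, encoded speech features can be provided when the decoder uses speech attention (a sketch; the speech shape is an illustrative assumption):
>>> speech = torch.randn(30, 512)  # hypothetical encoded speech features
>>> logp, updated_state = decoder.score(ys, state, x, speech=speech)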
######### NOTE
The method can handle optional speech features if provided, which is useful for tasks that involve speech input.