espnet2.asr.decoder.transducer_decoder.TransducerDecoder
class espnet2.asr.decoder.transducer_decoder.TransducerDecoder(vocab_size: int, rnn_type: str = 'lstm', num_layers: int = 1, hidden_size: int = 320, dropout: float = 0.0, dropout_embed: float = 0.0, embed_pad: int = 0)
Bases: AbsDecoder
TransducerDecoder is an (RNN-)Transducer decoder module, also called the prediction network. It embeds input label sequences and passes them through stacked recurrent layers to produce decoder output sequences. The recurrent layers are either LSTM or GRU.
Attributes:
embed
Embedding layer for input label sequences.
- Type: torch.nn.Embedding
dropout_embed
Dropout layer applied to the embeddings.
- Type: torch.nn.Dropout
decoder
List of RNN layers (LSTM or GRU).
- Type: torch.nn.ModuleList
dropout_dec
List of dropout layers applied to the RNN outputs.
- Type: torch.nn.ModuleList
dlayers
Number of decoder layers.
- Type: int
dunits
Number of decoder units per layer.
- Type: int
dtype
Type of RNN (‘lstm’ or ‘gru’).
- Type: str
odim
Size of the output vocabulary.
- Type: int
ignore_id
ID of labels to be ignored (e.g., padding in target label sequences); set to -1.
- Type: int
blank_id
ID of the blank symbol; equal to embed_pad.
- Type: int
device
The device (CPU or GPU) on which the model resides.
- Type: torch.device
Parameters:
- vocab_size (int) – Size of the output vocabulary.
- rnn_type (str) – Type of RNN to use (‘lstm’ or ‘gru’). Default is ‘lstm’.
- num_layers (int) – Number of decoder layers. Default is 1.
- hidden_size (int) – Number of units in each decoder layer. Default is 320.
- dropout (float) – Dropout rate for the decoder layers. Default is 0.0.
- dropout_embed (float) – Dropout rate for the embedding layer. Default is 0.0.
- embed_pad (int) – Padding index for the embedding layer, also used as the blank symbol ID. Default is 0.
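To make the attribute and parameter lists concrete, here is a minimal sketch of how such a prediction network can be assembled in PyTorch. `MiniTransducerDecoder` is a hypothetical, simplified class, not ESPnet's implementation. It assumes that the embedding size equals `hidden_size`, that `embed_pad` is passed to the embedding as `padding_idx`, and that each decoder layer is a separate single-layer, batch-first RNN followed by its own dropout.

```python
import torch


class MiniTransducerDecoder(torch.nn.Module):
    """Hypothetical, simplified prediction network (not ESPnet's TransducerDecoder)."""

    def __init__(self, vocab_size: int, rnn_type: str = "lstm", num_layers: int = 1,
                 hidden_size: int = 320, dropout: float = 0.0,
                 dropout_embed: float = 0.0, embed_pad: int = 0):
        super().__init__()
        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")
        # Assumption: embedding size == hidden_size, padding index == embed_pad.
        self.embed = torch.nn.Embedding(vocab_size, hidden_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=dropout_embed)
        rnn_cls = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
        # One single-layer, batch-first RNN and one dropout module per decoder layer.
        self.decoder = torch.nn.ModuleList(
            [rnn_cls(hidden_size, hidden_size, 1, batch_first=True) for _ in range(num_layers)]
        )
        self.dropout_dec = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout) for _ in range(num_layers)]
        )
        self.dlayers, self.dunits, self.blank_id = num_layers, hidden_size, embed_pad

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        x = self.dropout_embed(self.embed(labels))  # (B, L) -> (B, L, hidden_size)
        for rnn, drop in zip(self.decoder, self.dropout_dec):
            x, _ = rnn(x)  # no state given, so each layer starts from zeros
            x = drop(x)
        return x  # (B, L, hidden_size)


dec = MiniTransducerDecoder(vocab_size=50, rnn_type="gru", num_layers=2, hidden_size=16)
out = dec(torch.randint(0, 50, (4, 7)))
print(out.shape)  # torch.Size([4, 7, 16])
```

Keeping one RNN module per layer, rather than a single multi-layer RNN, is what lets a dropout module sit between every pair of layers and lets each layer's state be addressed separately.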
Methods:
set_device(device)
Set the device to be used for the decoder.
init_state(batch_size)
Initialize the hidden states of the decoder.
rnn_forward(sequence, state)
Perform a forward pass through the RNN layers.
forward(labels)
Process input label sequences to produce decoder outputs.
score(hyp, cache)
Compute decoder output and hidden states for a single hypothesis.
batch_score(hyps, dec_states, cache, use_lm)
Compute decoder outputs for a batch of hypotheses.
select_state(states, idx)
Retrieve the hidden state for a specified index.
create_batch_states(states, new_states, check_list=None)
Create batch hidden states from new states.
Examples:
>>> import torch
>>> from espnet2.asr.decoder.transducer_decoder import TransducerDecoder
>>> from espnet2.asr.transducer.beam_search_transducer import Hypothesis
>>> # Instantiate a TransducerDecoder
>>> decoder = TransducerDecoder(
...     vocab_size=1000, rnn_type="lstm", num_layers=2, hidden_size=256,
...     dropout=0.1, dropout_embed=0.1, embed_pad=0,
... )
>>> # Forward pass on a batch of 32 label sequences of length 10
>>> labels = torch.randint(0, 1000, (32, 10))
>>> outputs = decoder(labels)
>>> # Initialize states for a batch
>>> init_states = decoder.init_state(batch_size=32)
>>> # Score a single hypothesis; its state must have batch size 1
>>> hyp = Hypothesis(score=0.0, yseq=[1, 2, 3], dec_state=decoder.init_state(1))
>>> dec_out, new_state, label = decoder.score(hyp, cache={})
Note: The decoder requires input sequences to be properly padded and tokenized according to the model's vocabulary.
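Because `embed_pad` is the embedding's padding index, padded positions contribute nothing to the decoder input. The sketch below assumes the embedding is built as `torch.nn.Embedding(vocab_size, hidden_size, padding_idx=embed_pad)` (sizes here are illustrative). It shows that padded positions map to all-zero vectors and that the padding row receives no gradient.

```python
import torch

vocab_size, hidden_size, embed_pad = 10, 4, 0
# Assumption: the decoder embedding is created with padding_idx=embed_pad.
embed = torch.nn.Embedding(vocab_size, hidden_size, padding_idx=embed_pad)

# Two label sequences right-padded with embed_pad (0) to length 5.
labels = torch.tensor([[3, 5, 2, 0, 0],
                       [7, 1, 0, 0, 0]])
vectors = embed(labels)  # (2, 5, 4)

# Positions holding the padding index map to all-zero vectors.
pad_mask = labels == embed_pad
print(vectors[pad_mask].abs().sum().item())  # 0.0

# The padding row receives no gradient, so it stays zero during training.
vectors.sum().backward()
print(embed.weight.grad[embed_pad].abs().sum().item())  # 0.0
```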
batch_score(hyps: List[Hypothesis] | List[ExtendedHypothesis], dec_states: Tuple[Tensor, Tensor | None], cache: Dict[str, Any], use_lm: bool) → Tuple[Tensor, Tuple[Tensor, Tensor], Tensor]
One-step forward hypotheses.
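This method performs a one-step forward pass through the decoder for a batch of hypotheses. Hypotheses whose label sequence is already stored in cache reuse the stored output and state; only the remaining hypotheses are run through the RNN, together in a single batched step, and their results are added to the cache.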
Parameters:
- hyps – A list of hypotheses to score, which can be either Hypothesis or ExtendedHypothesis instances.
- dec_states – Current decoder hidden states. This is a tuple containing two tensors: ((N, B, D_dec), (N, B, D_dec)), where N is the number of layers, B is the batch size, and D_dec is the number of decoder units.
- cache – A dictionary mapping label sequence strings to tuples of decoder output sequences and hidden states. This is used to store and retrieve previously computed results for efficiency.
- use_lm – A boolean indicating whether to compute label ID sequences for the language model (LM).
Returns:
- dec_out – Decoder output sequences for the batch. Shape: (B, D_dec), where B is the batch size and D_dec is the number of decoder units.
- dec_states – Updated decoder hidden states, containing the new states for each hypothesis: ((N, B, D_dec), (N, B, D_dec)).
- lm_labels – Label ID sequences for the language model. Shape: (B,). None if use_lm is False.
Return type: Tuple[Tensor, Tuple[Tensor, Tensor], Tensor]
Examples:
>>> from espnet2.asr.transducer.beam_search_transducer import Hypothesis
>>> decoder = TransducerDecoder(vocab_size=1000)
>>> # Each hypothesis carries its own (N, 1, D_dec) state
>>> hyps = [Hypothesis(score=0.0, yseq=[1, 2, 3], dec_state=decoder.init_state(1))]
>>> dec_states = decoder.init_state(batch_size=len(hyps))
>>> cache = {}
>>> dec_out, new_states, lm_labels = decoder.batch_score(
...     hyps, dec_states, cache, use_lm=True
... )
>>> print(dec_out.shape)  # (B, D_dec)
Note: The method assumes that all hypotheses in the input list have been initialized properly and that their sequences are valid.
- Raises: ValueError – If the provided hypotheses or states are invalid or do not match the expected dimensions.
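The caching behaviour of batch_score can be illustrated without PyTorch. In this sketch, `batch_step_with_cache`, `step_batch` and `fake_step` are hypothetical names, and the `"_".join(...)` key format is an assumption. The point is that only hypotheses missing from the cache are stepped, together in one batched call, and that results come back in the original hypothesis order.

```python
from typing import Any, Callable, Dict, List, Sequence


def batch_step_with_cache(
    yseqs: Sequence[List[int]],
    cache: Dict[str, Any],
    step_batch: Callable[[List[int]], List[Any]],
) -> List[Any]:
    """Hypothetical sketch of cache-aware batching (not ESPnet code)."""
    # Assumption: cache keys are label sequences joined with "_".
    keys = ["_".join(map(str, y)) for y in yseqs]
    results: List[Any] = [cache.get(k) for k in keys]

    # Indices of hypotheses that still need a decoder step.
    todo = [i for i, r in enumerate(results) if r is None]
    if todo:
        # One batched call on the last label of every uncached hypothesis.
        fresh = step_batch([yseqs[i][-1] for i in todo])
        for i, r in zip(todo, fresh):
            results[i] = r
            cache[keys[i]] = r
    return results  # same order as yseqs


calls = []


def fake_step(last_labels: List[int]) -> List[Any]:
    # Stand-in for one batched RNN step: returns an (output, state) pair per label.
    calls.append(list(last_labels))
    return [(f"out{lab}", lab * 10) for lab in last_labels]


cache: Dict[str, Any] = {}
batch_step_with_cache([[0, 1], [0, 2]], cache, fake_step)
out = batch_step_with_cache([[0, 1], [0, 2, 5]], cache, fake_step)
print(calls)  # [[1, 2], [5]]
print(out)    # [('out1', 10), ('out5', 50)]
```

In the second call, `[0, 1]` is served from the cache, so only the label `5` of the new hypothesis reaches the step function.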
create_batch_states(states: Tuple[Tensor, Tensor | None], new_states: List[Tuple[Tensor, Tensor | None]], check_list: List | None = None) → List[Tuple[Tensor, Tensor | None]]
Create decoder hidden states for a batch of hypotheses.
This method takes the current states and a list of per-hypothesis new states, and concatenates the new states along the batch dimension (dim=1) to form the hidden states for the whole batch.
- Parameters:
- states – Tuple containing the current decoder hidden states. The first element is of shape (N, B, D_dec) and the second (if LSTM) is of shape (N, B, D_dec) as well.
- new_states – List of per-hypothesis hidden states, each of the form ((N, 1, D_dec), (N, 1, D_dec)) for LSTM or ((N, 1, D_dec), None) for GRU, as returned by select_state.
- check_list – Optional list for additional state checks (default: None).
- Returns: Tuple of concatenated decoder hidden states:
  - The first element of shape (N, B, D_dec).
  - The second element of shape (N, B, D_dec) if LSTM, else None.
- Return type: Tuple[Tensor, Tensor | None]
Examples:
>>> import torch
>>> from espnet2.asr.decoder.transducer_decoder import TransducerDecoder
>>> decoder = TransducerDecoder(vocab_size=1000, num_layers=2)
>>> current_states = decoder.init_state(batch_size=2)
>>> # One ((N, 1, D_dec), (N, 1, D_dec)) state per hypothesis
>>> new_hyp_states = [(torch.ones(2, 1, 320), torch.ones(2, 1, 320)),
...                   (torch.zeros(2, 1, 320), torch.zeros(2, 1, 320))]
>>> batch_states = decoder.create_batch_states(current_states, new_hyp_states)
>>> len(batch_states)
2
>>> batch_states[0].shape
torch.Size([2, 2, 320])
Note: Ensure that the new_states list matches the expected format for the decoder type (LSTM or GRU).
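The concatenation performed by create_batch_states can be sketched with plain tensor operations. `concat_states` is a hypothetical helper, not ESPnet API. It assumes each per-hypothesis state is ((N, 1, D_dec), (N, 1, D_dec)) for LSTM or ((N, 1, D_dec), None) for GRU, and joins the states along the batch axis (dim=1).

```python
from typing import List, Optional, Tuple

import torch

State = Tuple[torch.Tensor, Optional[torch.Tensor]]


def concat_states(new_states: List[State], is_lstm: bool) -> State:
    """Hypothetical helper mirroring the described concatenation (not ESPnet code)."""
    # Each h (and c) is (N, 1, D_dec); dim=1 is the batch axis.
    h = torch.cat([s[0] for s in new_states], dim=1)
    # GRU states carry no cell state, so the second slot stays None.
    c = torch.cat([s[1] for s in new_states], dim=1) if is_lstm else None
    return h, c


N, D = 2, 8
lstm_states = [(torch.full((N, 1, D), float(i)), torch.full((N, 1, D), -float(i)))
               for i in range(3)]
h, c = concat_states(lstm_states, is_lstm=True)
print(h.shape, c.shape)     # torch.Size([2, 3, 8]) torch.Size([2, 3, 8])
print(h[0, :, 0].tolist())  # [0.0, 1.0, 2.0]

gru_states = [(torch.zeros(N, 1, D), None) for _ in range(3)]
h_gru, c_gru = concat_states(gru_states, is_lstm=False)
print(h_gru.shape, c_gru)   # torch.Size([2, 3, 8]) None
```

The hypothesis order in the list becomes the batch order, which is why each hypothesis's values appear at its own batch index.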
forward(labels: Tensor) → Tensor
Encode source label sequences.
This method processes the input label sequences through the decoder network, applying embedding and RNN transformations to generate the decoder output sequences.
- Parameters: labels – Label ID sequences. Shape (B, L), where B is the batch size and L is the sequence length.
- Returns: Decoder output sequences. Shape (B, L, D_dec), where D_dec is the number of decoder units (hidden_size).
- Return type: Tensor
Examples:
>>> import torch
>>> from espnet2.asr.decoder.transducer_decoder import TransducerDecoder
>>> decoder = TransducerDecoder(vocab_size=1000)
>>> input_labels = torch.randint(0, 1000, (32, 10))  # Batch of 32 sequences of length 10
>>> output = decoder(input_labels)
>>> output.shape
torch.Size([32, 10, 320])
Note: Ensure that the input labels are properly padded and contain valid IDs for the embedding layer configuration.
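Assuming unidirectional RNN layers that start from freshly initialized states, the decoder output at position l depends only on labels[:, :l + 1]. The sketch below checks this property with a hypothetical stand-in (an embedding plus a 2-layer `torch.nn.LSTM`), not with TransducerDecoder itself.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in: embedding followed by a unidirectional 2-layer LSTM.
embed = torch.nn.Embedding(20, 16, padding_idx=0)
rnn = torch.nn.LSTM(16, 16, num_layers=2, batch_first=True)


def predict(labels: torch.Tensor) -> torch.Tensor:
    out, _ = rnn(embed(labels))  # zero initial states by default
    return out  # (B, L, 16)


a = torch.tensor([[1, 4, 6, 9]])
b = torch.tensor([[1, 4, 3, 2]])  # same prefix [1, 4], different suffix
out_a, out_b = predict(a), predict(b)

# The first two positions only see the shared prefix, so they match;
# later positions also see the changed labels, so they differ.
print(torch.allclose(out_a[:, :2], out_b[:, :2]))  # True
print(torch.allclose(out_a[:, 2:], out_b[:, 2:]))  # False
```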
init_state(batch_size: int) → Tuple[Tensor, Tensor | None]
Initialize decoder states.
This method creates zero-filled hidden states for the decoder on the decoder's device (self.device). For an LSTM it returns a tuple of the hidden states and the cell states; for a GRU it returns the hidden states paired with None.
Parameters: batch_size – The number of sequences in a batch. This determines the size of the hidden states.
Returns:
- For LSTM: ((N, B, D_dec), (N, B, D_dec))
- For GRU: ((N, B, D_dec), None)
where N is the number of layers, B is the batch size, and D_dec is the number of decoder units per layer.
Return type: A tuple containing the initialized hidden states
Examples:
>>> decoder = TransducerDecoder(vocab_size=1000)  # LSTM, 1 layer, 320 units
>>> h_n, c_n = decoder.init_state(batch_size=32)
>>> h_n.shape
torch.Size([1, 32, 320])
>>> c_n.shape  # c_n is None for GRU
torch.Size([1, 32, 320])
Note: Call this method before the decoder is used to generate predictions, so that each batch of sequences starts from correctly initialized hidden states.
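A standalone sketch of the (N, B, D_dec) state layout for both RNN types. `make_init_state` is a hypothetical helper, not ESPnet API, and zero initialization is an assumption made for illustration.

```python
from typing import Optional, Tuple

import torch


def make_init_state(num_layers: int, batch_size: int, hidden_size: int,
                    rnn_type: str = "lstm", device: str = "cpu"
                    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical helper: zero-filled states of shape (N, B, D_dec)."""
    h_n = torch.zeros(num_layers, batch_size, hidden_size, device=device)
    if rnn_type == "lstm":
        # LSTMs also carry a cell state with the same shape.
        c_n = torch.zeros(num_layers, batch_size, hidden_size, device=device)
        return h_n, c_n
    # GRUs only have a hidden state; keep the tuple layout with None.
    return h_n, None


h, c = make_init_state(num_layers=2, batch_size=4, hidden_size=320)
print(h.shape, c.shape)  # torch.Size([2, 4, 320]) torch.Size([2, 4, 320])
h, c = make_init_state(2, 4, 320, rnn_type="gru")
print(h.shape, c)        # torch.Size([2, 4, 320]) None
```

Returning a 2-tuple in both cases lets calling code unpack `(h, c)` without branching on the RNN type.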
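rnn_forward(sequence: Tensor, state: Tuple[Tensor, Tensor | None]) → Tuple[Tensor, Tuple[Tensor, Tensor | None]]
Perform a forward pass through the RNN layers.
Parameters:
- sequence – RNN input sequences (embedded labels).
- state – Decoder hidden states: ((N, B, D_dec), (N, B, D_dec)) for LSTM, or ((N, B, D_dec), None) for GRU.
Returns: The RNN output sequences and the updated decoder hidden states, in the same layout as state.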
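One plausible way to thread stacked states through the decoder layers is sketched below. `layerwise_rnn_forward` is hypothetical, not ESPnet API. It assumes one single-layer RNN per decoder layer, where layer n reads and writes its own (1, B, D_dec) slice of the (N, B, D_dec) state.

```python
from typing import Optional, Tuple

import torch

State = Tuple[torch.Tensor, Optional[torch.Tensor]]


def layerwise_rnn_forward(layers: torch.nn.ModuleList, dropouts: torch.nn.ModuleList,
                          sequence: torch.Tensor, state: State, is_lstm: bool
                          ) -> Tuple[torch.Tensor, State]:
    """Hypothetical sketch (not ESPnet code): run one single-layer RNN per decoder layer."""
    h_prev, c_prev = state  # each (N, B, D_dec); c_prev is None for GRU
    h_next, c_next = [], []
    for n, (rnn, drop) in enumerate(zip(layers, dropouts)):
        if is_lstm:
            # Layer n uses its own (1, B, D_dec) slice of the stacked states.
            sequence, (h_l, c_l) = rnn(sequence, (h_prev[n:n + 1], c_prev[n:n + 1]))
            c_next.append(c_l)
        else:
            sequence, h_l = rnn(sequence, h_prev[n:n + 1])
        h_next.append(h_l)
        sequence = drop(sequence)  # dropout between layers
    new_c = torch.cat(c_next, dim=0) if is_lstm else None
    return sequence, (torch.cat(h_next, dim=0), new_c)


torch.manual_seed(0)
N, B, L, D = 2, 3, 5, 8
layers = torch.nn.ModuleList([torch.nn.LSTM(D, D, 1, batch_first=True) for _ in range(N)])
dropouts = torch.nn.ModuleList([torch.nn.Dropout(p=0.0) for _ in range(N)])
zero_state = (torch.zeros(N, B, D), torch.zeros(N, B, D))
out, (h, c) = layerwise_rnn_forward(layers, dropouts, torch.randn(B, L, D), zero_state, True)
print(out.shape, h.shape, c.shape)
# torch.Size([3, 5, 8]) torch.Size([2, 3, 8]) torch.Size([2, 3, 8])

gru_layers = torch.nn.ModuleList([torch.nn.GRU(D, D, 1, batch_first=True) for _ in range(N)])
out_g, (h_g, c_g) = layerwise_rnn_forward(
    gru_layers, dropouts, torch.randn(B, L, D), (torch.zeros(N, B, D), None), False
)
print(out_g.shape, c_g)  # torch.Size([3, 5, 8]) None
```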
score(hyp: Hypothesis, cache: Dict[str, Any]) → Tuple[Tensor, Tuple[Tensor, Tensor | None], Tensor]
Compute the score for a single hypothesis.
This method performs a one-step forward pass for the given hypothesis, using the hypothesis's decoder state and the last label of its sequence, and caches the resulting decoder output and hidden states under the hypothesis's label sequence. If that label sequence is already in the cache, the stored output and state are returned without recomputation.
Parameters:
- hyp – The hypothesis containing the label sequence and current decoder state.
- cache – A dictionary that stores pairs of (dec_out, state) for each label sequence to avoid redundant computations.
Returns:
- dec_out – Decoder output for the current label. Shape: (1, D_dec)
- new_state – Updated decoder hidden states after processing the input. Shape: ((N, 1, D_dec), (N, 1, D_dec))
- label – Label ID for the language model. Shape: (1,)
Return type: Tuple[Tensor, Tuple[Tensor, Tensor | None], Tensor]
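Examples:
>>> from espnet2.asr.transducer.beam_search_transducer import Hypothesis
>>> decoder = TransducerDecoder(vocab_size=1000)
>>> h_n, c_n = decoder.init_state(batch_size=1)
>>> hyp = Hypothesis(score=0.0, yseq=[2, 3, 4], dec_state=(h_n, c_n))
>>> cache = {}
>>> dec_out, new_state, label = decoder.score(hyp, cache)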
Note: This method assumes that the hypothesis has at least one label in its sequence.
- Raises: KeyError – If the hypothesis label sequence is not found in the cache and fails to generate a new output.
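Caching by label sequence works because stepping the RNN one label at a time from the saved state reproduces a full pass over the same sequence. This sketch checks that equivalence with a hypothetical stand-in (an embedding plus a 1-layer `torch.nn.LSTM`). `one_step` and the `"_".join(...)` key format are assumptions for illustration, not ESPnet API.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for the decoder: embedding + 1-layer LSTM (not ESPnet code).
embed = torch.nn.Embedding(20, 8, padding_idx=0)
rnn = torch.nn.LSTM(8, 8, num_layers=1, batch_first=True)


def one_step(last_label, state, cache, key):
    """Return (dec_out, new_state) for one label, reusing cache[key] if present."""
    if key not in cache:
        label = torch.full((1, 1), last_label, dtype=torch.long)  # (1, 1)
        out, new_state = rnn(embed(label), state)                 # out: (1, 1, D)
        cache[key] = (out[0], new_state)                          # out[0]: (1, D)
    return cache[key]


yseq = [0, 5, 7, 3]
state = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))  # (N, 1, D_dec)
cache, step_outs = {}, []
for t in range(len(yseq)):
    key = "_".join(map(str, yseq[: t + 1]))  # assumed key format
    dec_out, state = one_step(yseq[t], state, cache, key)
    step_outs.append(dec_out)

# Step-by-step decoding matches a single pass over the whole sequence.
full_out, _ = rnn(embed(torch.tensor([yseq])))
print(torch.allclose(torch.cat(step_outs, dim=0), full_out[0], atol=1e-5))  # True
```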
select_state(states: Tuple[Tensor, Tensor | None], idx: int) → Tuple[Tensor, Tensor | None]
Get specified ID state from decoder hidden states.
This method retrieves the decoder hidden state for a specified index from the provided decoder hidden states. It is particularly useful in scenarios where multiple hypotheses are being processed in parallel, and you need to extract the hidden state corresponding to a specific hypothesis.
- Parameters:
- states – Decoder hidden states. A tuple containing two tensors: ((N, B, D_dec), (N, B, D_dec)), where N is the number of layers, B is the batch size, and D_dec is the dimension of the decoder.
- idx – State ID to extract. This is the index of the hidden state that you wish to retrieve from the decoder states.
- Returns: A tuple containing the decoder hidden state for the given ID. The output will be in the shape: ((N, 1, D_dec), (N, 1, D_dec)) for LSTM, or ((N, 1, D_dec), None) for GRU.
Examples:
>>> decoder = TransducerDecoder(vocab_size=1000)
>>> states = decoder.init_state(batch_size=2)
>>> selected_state = decoder.select_state(states, idx=0)
>>> selected_state[0].shape  # (N, 1, D_dec)
torch.Size([1, 1, 320])
>>> selected_state[1].shape  # (N, 1, D_dec) for LSTM; None for GRU
torch.Size([1, 1, 320])
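Selecting a state slices along the batch axis while keeping that axis, so the result can be fed back into batched operations. `pick_state` is a hypothetical helper sketching this slicing, not ESPnet API; concatenating the slices along dim=1 restores the original batch.

```python
from typing import Optional, Tuple

import torch

State = Tuple[torch.Tensor, Optional[torch.Tensor]]


def pick_state(states: State, idx: int) -> State:
    """Hypothetical helper (not ESPnet code): slice one hypothesis out of (N, B, D_dec)."""
    h, c = states
    # idx:idx + 1 keeps the batch axis, giving (N, 1, D_dec) instead of (N, D_dec).
    return h[:, idx:idx + 1, :], (c[:, idx:idx + 1, :] if c is not None else None)


N, B, D = 2, 4, 6
h = torch.arange(N * B * D, dtype=torch.float32).view(N, B, D)
states = (h, h.clone())
picked = [pick_state(states, i) for i in range(B)]
print(picked[0][0].shape)  # torch.Size([2, 1, 6])

# Concatenating the slices back along dim=1 rebuilds the batched states.
rebuilt = torch.cat([p[0] for p in picked], dim=1)
print(torch.equal(rebuilt, h))  # True

# GRU case: there is no cell state to slice.
print(pick_state((h, None), 1)[1])  # None
```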
set_device(device: device)
Set GPU device to use.
This method updates the device attribute of the TransducerDecoder instance, which determines where newly created tensors (for example, the states returned by init_state) are allocated. It does not move the module's parameters; move them with decoder.to(device) so that parameters and states live on the same device.
- Parameters: device – A torch.device object representing the device to be used. This can be a CPU or a specific GPU device (e.g., torch.device("cuda:0") for the first GPU).
Examples:
>>> decoder = TransducerDecoder(vocab_size=100)
>>> decoder.set_device(torch.device("cuda:0"))
>>> print(decoder.device)
cuda:0
Note: Ensure that the specified device is available on the system. Use torch.cuda.is_available() to check if CUDA is supported.
- Raises: ValueError – If the provided device is not a valid torch.device.
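The sketch below shows why recording the device and moving the parameters are separate steps. It assumes that set_device only records the device used for newly created tensors. `DeviceAwareDecoder` is a hypothetical stand-in, not ESPnet's class, and it falls back to CPU when CUDA is unavailable.

```python
import torch


class DeviceAwareDecoder(torch.nn.Module):
    """Hypothetical stand-in (not ESPnet code) that stores a device attribute."""

    def __init__(self, hidden_size: int = 4):
        super().__init__()
        self.rnn = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.hidden_size = hidden_size
        self.device = next(self.parameters()).device  # where parameters live at init

    def set_device(self, device: torch.device) -> None:
        # Only records the device; parameters are not moved.
        self.device = device

    def init_state(self, batch_size: int):
        shape = (1, batch_size, self.hidden_size)
        return (torch.zeros(shape, device=self.device),
                torch.zeros(shape, device=self.device))


target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dec = DeviceAwareDecoder().to(target)  # moves the parameters
dec.set_device(target)                 # new states are created on the same device
h, c = dec.init_state(batch_size=2)
print(h.device == next(dec.parameters()).device)  # True
```

Calling only one of the two would leave states and parameters on different devices as soon as the model runs on a GPU.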