espnet2.asr.decoder.rnn_decoder.RNNDecoder
class espnet2.asr.decoder.rnn_decoder.RNNDecoder(vocab_size: int, encoder_output_size: int, rnn_type: str = 'lstm', num_layers: int = 1, hidden_size: int = 320, sampling_probability: float = 0.0, dropout: float = 0.0, context_residual: bool = False, replace_sos: bool = False, num_encs: int = 1, att_conf: dict = {'aconv_chans': 10, 'aconv_filts': 100, 'adim': 320, 'aheads': 4, 'atype': 'location', 'awin': 5, 'han_conv_chans': -1, 'han_conv_filts': 100, 'han_dim': 320, 'han_heads': 4, 'han_mode': False, 'han_type': None, 'han_win': 5, 'num_att': 1, 'num_encs': 1})
Bases: AbsDecoder
RNNDecoder is a recurrent neural network (RNN) based decoder for automatic speech recognition (ASR). It is designed to convert encoded representations from an encoder into a sequence of output tokens using attention mechanisms and recurrent layers.
vocab_size
The size of the vocabulary.
- Type: int
encoder_output_size
The size of the encoder output.
- Type: int
rnn_type
The type of RNN to use, either ‘lstm’ or ‘gru’.
- Type: str
num_layers
The number of recurrent layers in the decoder.
- Type: int
hidden_size
The size of the hidden layers in the RNN.
- Type: int
sampling_probability
The probability of feeding the decoder's own previous prediction instead of the ground-truth token at each step during training (scheduled sampling).
- Type: float
dropout
The dropout rate for regularization.
- Type: float
context_residual
Whether to concatenate the attention context vector with the top decoder hidden state before the output projection.
- Type: bool
replace_sos
Whether to replace the start of sequence token.
- Type: bool
num_encs
The number of encoders to support.
- Type: int
att_list
A list of attention modules for decoding.
- Type: ModuleList
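The attributes above determine the sizes of the decoder's layers. The following hypothetical pure-Python helper (`DecoderDims` is not part of ESPnet) computes those sizes. It assumes that the token embedding has size hidden_size, that layer 0 consumes the embedding concatenated with the attention context, and that context_residual widens the output projection. This matches our reading of the ESPnet source, but verify it against your installed version.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class DecoderDims:
    """Hypothetical helper: layer sizes implied by the RNNDecoder arguments (assumed layout)."""

    vocab_size: int
    encoder_output_size: int
    hidden_size: int = 320
    num_layers: int = 1
    context_residual: bool = False

    def embedding_shape(self) -> Tuple[int, int]:
        # Token embedding table: one hidden_size vector per vocabulary entry.
        return (self.vocab_size, self.hidden_size)

    def cell_input_sizes(self) -> List[int]:
        # Layer 0 reads [embedding ; attention context]; each deeper layer
        # reads the hidden state of the layer below it.
        first = self.hidden_size + self.encoder_output_size
        return [first] + [self.hidden_size] * (self.num_layers - 1)

    def output_in_features(self) -> int:
        # With context_residual, the attention context is concatenated to the
        # top hidden state before the projection to vocab_size logits.
        if self.context_residual:
            return self.hidden_size + self.encoder_output_size
        return self.hidden_size


dims = DecoderDims(vocab_size=5000, encoder_output_size=256, hidden_size=512,
                   num_layers=2, context_residual=True)
print(dims.embedding_shape())     # (5000, 512)
print(dims.cell_input_sizes())    # [768, 512]
print(dims.output_in_features())  # 768
```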
Parameters:
- vocab_size (int) – Size of the vocabulary.
- encoder_output_size (int) – Size of the encoder output.
- rnn_type (str, optional) – Type of RNN (‘lstm’ or ‘gru’). Defaults to ‘lstm’.
- num_layers (int, optional) – Number of layers in the RNN. Defaults to 1.
- hidden_size (int, optional) – Size of hidden units. Defaults to 320.
- sampling_probability (float, optional) – Probability of scheduled sampling (feeding back the model's own prediction) during training. Defaults to 0.0.
- dropout (float, optional) – Dropout rate. Defaults to 0.0.
- context_residual (bool, optional) – Concatenate the attention context to the decoder output before the output layer. Defaults to False.
- replace_sos (bool, optional) – Replace the start-of-sequence token. Defaults to False.
- num_encs (int, optional) – Number of encoders. Defaults to 1.
- att_conf (dict, optional) – Configuration for attention. Defaults to the built-in configuration shown in the signature.
Returns: None
Raises: ValueError – If rnn_type is not ‘lstm’ or ‘gru’.
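sampling_probability controls scheduled sampling. Below is a minimal, framework-free sketch of the per-step decision. It assumes that the first step always uses the ground-truth token and that a sampled token is the greedy argmax of the previous step's logits, which is our reading of the ESPnet training loop. `choose_next_input` is a hypothetical name, not an ESPnet function.

```python
import random
from typing import Sequence


def choose_next_input(gold_token: int, prev_logits: Sequence[float],
                      sampling_probability: float, step: int,
                      rng: random.Random) -> int:
    """Scheduled sampling sketch: pick the token fed to the decoder at `step`."""
    # Step 0 always starts from the ground-truth (<sos>) token.
    if step > 0 and rng.random() < sampling_probability:
        # Feed back the model's own greedy prediction from the previous step.
        return max(range(len(prev_logits)), key=lambda t: prev_logits[t])
    # Otherwise use teacher forcing with the ground-truth token.
    return gold_token


rng = random.Random(0)
print(choose_next_input(7, [0.1, 0.9, 0.0], 0.0, step=3, rng=rng))  # 7 (never samples)
print(choose_next_input(7, [0.1, 0.9, 0.0], 1.0, step=3, rng=rng))  # 1 (always samples)
```

With sampling_probability=0.0 (the default), training is pure teacher forcing.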
############### Examples
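>>> import torch
>>> from espnet2.asr.decoder.rnn_decoder import RNNDecoder
>>> # Initialize the decoder
>>> decoder = RNNDecoder(
...     vocab_size=5000, encoder_output_size=256, rnn_type="lstm", num_layers=2,
...     hidden_size=512, sampling_probability=0.1, dropout=0.2,
... )
>>> # Forward pass: hs_pad is the (batch, T, 256) encoder output, hlens its lengths
>>> output, lengths = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)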
########## NOTE
This class supports multiple encoders and can be used for multilingual translation tasks.
forward(hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0)
Performs the forward pass of the RNN decoder.
This method takes the padded encoder outputs and the target input sequences (teacher forcing), runs them step by step through the attention and RNN layers, and computes the logits for every output position. It supports both single- and multi-encoder modes.
- Parameters:
- hs_pad (torch.Tensor or List[torch.Tensor]) – The padded hidden states from the encoder. In multi-encoder mode, this should be a list of tensors, one per encoder.
- hlens (torch.Tensor or List[torch.Tensor]) – The lengths of the hidden states in hs_pad, in the same format as hs_pad.
- ys_in_pad (torch.Tensor) – The input token sequences (previous outputs), padded to the maximum length.
- ys_in_lens (torch.Tensor) – The lengths of the input sequences.
- strm_idx (int, optional) – The index of the stream, i.e. the attention module used in single-encoder mode. Defaults to 0.
- Returns: A tuple containing:
  - Logits for the output sequence, of shape (batch_size, max_length, vocab_size).
  - The input sequence lengths (ys_in_lens), returned unchanged.
- Return type: Tuple[torch.Tensor, torch.Tensor]
- Raises: ValueError – If the number of encoders is less than one.
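hlens tells the attention which encoder frames are real and which are padding. The sketch below illustrates the idea with plain scores and a softmax restricted to the first hlens frames. It is a simplification: ESPnet's default attention is location-aware (atype 'location'), and its scores come from learned layers. The helper names are hypothetical.

```python
import math
from typing import List


def masked_softmax(scores: List[float], length: int) -> List[float]:
    """Softmax over the first `length` frames; padded frames get weight 0."""
    valid = scores[:length]
    peak = max(valid)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in valid]
    total = sum(exps)
    return [e / total for e in exps] + [0.0] * (len(scores) - length)


def context_vector(frames: List[List[float]], weights: List[float]) -> List[float]:
    """Attention context: weighted sum of encoder frames."""
    dim = len(frames[0])
    return [sum(w * f[d] for w, f in zip(weights, frames)) for d in range(dim)]


# Two real frames plus one padded frame (hlens = 2 for this utterance).
frames = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]
weights = masked_softmax([0.0, 0.0, 100.0], length=2)
print(weights)                          # [0.5, 0.5, 0.0]
print(context_vector(frames, weights))  # [0.5, 0.5]
```

Even though the padded frame has the largest raw score, it receives zero weight and does not leak into the context.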
############### Examples
>>> import torch
>>> decoder = RNNDecoder(vocab_size=5000, encoder_output_size=256)
>>> hs_pad = torch.randn(10, 20, 256)  # Batch of 10, max length 20
>>> hlens = torch.randint(1, 21, (10,))  # Random lengths in [1, 20]
>>> hlens[0] = 20  # The longest length must equal hs_pad.size(1)
>>> ys_in_pad = torch.randint(0, 5000, (10, 15))  # Previous outputs
>>> ys_in_lens = torch.randint(1, 16, (10,))  # Random lengths
>>> logits, new_lengths = decoder.forward(hs_pad, hlens, ys_in_pad, ys_in_lens)
########## NOTE
This function handles both single- and multi-encoder scenarios. With a single encoder, strm_idx selects which attention module (stream) is applied. With multiple encoders, each encoder has its own attention module, and an additional encoder-level attention combines their context vectors.
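In multi-encoder mode, each encoder is first summarized by its own attention module, and the per-encoder context vectors must then be merged into one. The following hypothetical sketch shows that second, encoder-level attention step. It assumes the scores are already given. In the real model, the scores would come from a learned attention conditioned on the decoder state, which the han_* keys in att_conf appear to configure.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    peak = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def fuse_encoder_contexts(contexts: List[List[float]], scores: List[float]) -> List[float]:
    """Encoder-level attention: weight each encoder's context vector and sum them."""
    weights = softmax(scores)  # one weight per encoder
    dim = len(contexts[0])
    return [sum(w * c[d] for w, c in zip(weights, contexts)) for d in range(dim)]


# Two encoders, each already summarized by its own attention into a 2-dim context.
per_encoder = [[1.0, 0.0], [0.0, 1.0]]
print(fuse_encoder_contexts(per_encoder, [0.0, 0.0]))  # [0.5, 0.5]
```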
init_state(x)
Initializes the hidden and cell states for the RNN decoder.
This method prepares the initial states required for the recurrent neural network (RNN) decoder. It supports multiple encoder configurations and ensures that the initial states are correctly shaped for processing.
c_prev
A list containing the initial cell states for each RNN layer.
- Type: list
z_prev
A list containing the initial hidden states for each RNN layer.
- Type: list
a_prev
None if there is only one encoder; otherwise a list of None placeholders for the attention weights, with one slot per encoder plus one for the encoder-level attention.
- Type: list or None
workspace
A tuple containing the attention index and the lists of hidden and cell states.
- Type: tuple
Parameters: x (torch.Tensor) – The encoder output for a single utterance, of shape (T, encoder_output_size). Only one hypothesis is initialized, so the returned states have batch size 1.
Returns: A dictionary containing the initialized states and workspace.
Return type: dict
############### Examples
>>> import torch
>>> decoder = RNNDecoder(vocab_size=1000, encoder_output_size=512)
>>> x = torch.randn(32, 512)  # Encoder output of one utterance with 32 frames
>>> states = decoder.init_state(x)
>>> print(states['c_prev']) # Should print the initialized cell states
>>> print(states['z_prev']) # Should print the initialized hidden states
########## NOTE This method is primarily used during the decoding process, where it sets up the initial states before generating outputs from the decoder.
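To make the structure of the returned dictionary concrete, the following pure-Python stand-in builds the same four keys. `init_state_sketch` is hypothetical. It assumes that the attention index defaults to 0 and that multi-encoder mode keeps one attention slot per encoder plus one for the encoder-level attention. Lists of zeros stand in for the zero tensors.

```python
from typing import Any, Dict, List, Optional


def init_state_sketch(num_layers: int, hidden_size: int, num_encs: int = 1,
                      att_idx: int = 0) -> Dict[str, Any]:
    """Build a dict with the same keys as init_state returns (assumed layout).

    Lists of zeros stand in for the (1, hidden_size) zero tensors.
    """
    z_list = [[0.0] * hidden_size for _ in range(num_layers)]
    c_list = [[0.0] * hidden_size for _ in range(num_layers)]
    # No previous attention weights yet; with several encoders, one slot per
    # encoder plus one for the encoder-level attention.
    a_prev: Optional[List[None]] = None if num_encs == 1 else [None] * (num_encs + 1)
    return {
        "c_prev": c_list[:],  # shallow copies: the outer lists are distinct objects
        "z_prev": z_list[:],
        "a_prev": a_prev,
        "workspace": (att_idx, z_list, c_list),
    }


state = init_state_sketch(num_layers=2, hidden_size=4)
print(len(state["z_prev"]), state["a_prev"])  # 2 None
```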
rnn_forward(ey, z_list, c_list, z_prev, c_prev)
Performs a forward pass through the RNN layers.
This method processes the input embedding ey and updates the hidden state and cell state of the RNN layers based on the specified RNN type (LSTM or GRU). It handles multiple layers of RNNs, applying dropout to the outputs of the previous layer.
- Parameters:
- ey (torch.Tensor) – The input tensor of shape (batch_size, input_size), where input_size is the combined size of the embedding and the attention context (hidden_size + encoder_output_size).
- z_list (list) – A list of tensors containing the hidden states of each RNN layer, each of shape (batch_size, hidden_size).
- c_list (list) – A list of tensors containing the cell states of each LSTM layer, each of shape (batch_size, hidden_size). This argument is ignored for GRU.
- z_prev (list) – A list of tensors representing the previous hidden states for each RNN layer.
- c_prev (list) – A list of tensors representing the previous cell states for each LSTM layer; this argument is ignored for GRU.
- Returns: A tuple containing:
  - z_list (list): The updated hidden states for each RNN layer.
  - c_list (list): The updated cell states for each LSTM layer.
- Return type: tuple
############### Examples
>>> import torch
>>> rnn_decoder = RNNDecoder(vocab_size=100, encoder_output_size=256, num_layers=2)
>>> ey = torch.randn(32, 320 + 256)  # [embedding ; attention context] for a batch of 32
>>> z_list = [torch.zeros(32, 320) for _ in range(2)] # Hidden states
>>> c_list = [torch.zeros(32, 320) for _ in range(2)] # Cell states
>>> z_prev = [torch.zeros(32, 320) for _ in range(2)] # Previous hidden
>>> c_prev = [torch.zeros(32, 320) for _ in range(2)] # Previous cell
>>> z_list, c_list = rnn_decoder.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
########## NOTE
- The method handles both LSTM and GRU architectures.
- Dropout is applied to the output of each layer before it is fed to the next layer; the rate is set by the dropout parameter at initialization.
- Raises: ValueError – If the input tensor dimensions do not match the expected shape.
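The layer stacking described above can be sketched without PyTorch. In this hypothetical helper, cells are plain callables that stand in for torch.nn.LSTMCell (takes input and (h, c), returns (h, c)) and torch.nn.GRUCell (takes input and h, returns h). The dropout argument stands in for the per-layer dropout modules. Unlike the documented method, the sketch builds new lists rather than updating z_list and c_list in place.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vec = List[float]


def stacked_rnn_step(ey: Vec, cells: Sequence[Callable], z_prev: List[Vec],
                     c_prev: Optional[List[Vec]] = None,
                     dropout: Callable[[Vec], Vec] = lambda h: h
                     ) -> Tuple[List[Vec], List[Vec]]:
    """One time step through a stack of RNN cells, in the style of rnn_forward.

    With c_prev, cells are LSTM-like: cell(x, (h, c)) -> (h, c).
    Without it, cells are GRU-like: cell(x, h) -> h, and no cell states are returned.
    """
    z_list: List[Vec] = []
    c_list: List[Vec] = []
    for i, cell in enumerate(cells):
        # Layer 0 reads the input; deeper layers read the output of the layer below.
        x = ey if i == 0 else dropout(z_list[i - 1])
        if c_prev is not None:
            h, c = cell(x, (z_prev[i], c_prev[i]))
            c_list.append(c)
        else:
            h = cell(x, z_prev[i])
        z_list.append(h)
    return z_list, c_list


def add_cell(x: Vec, h: Vec) -> Vec:
    # Toy GRU-like cell: new hidden state = input + previous hidden state.
    return [a + b for a, b in zip(x, h)]


z, c = stacked_rnn_step([1.0, 2.0], [add_cell, add_cell], [[0.0, 0.0], [10.0, 10.0]])
print(z, c)  # [[1.0, 2.0], [11.0, 12.0]] []
```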
score(yseq, state, x)
Calculate the log probabilities of the next token in a sequence.
This method computes the log probabilities of the next token given the previous tokens, the current state, and the encoder output. It runs one step of the attention and RNN layers; with multiple encoders, the per-encoder attention contexts are combined by an encoder-level attention.
- Parameters:
- yseq (torch.Tensor) – A tensor containing the sequence of tokens (shape: (T,)) where T is the sequence length.
- state (dict) – A dictionary containing the current state, which includes the previous hidden states and the attention context.
- x (torch.Tensor) – The encoder output for a single utterance, of shape (T, E), where T is the number of frames and E is the encoder output size.
- Returns: A tuple containing:
  - logp (torch.Tensor): Log probabilities of the next token, of shape (vocab_size,).
  - state (dict): The updated state, with the previous hidden states and attention weights.
- Return type: tuple
############### Examples
>>> import torch
>>> decoder = RNNDecoder(vocab_size=100, encoder_output_size=256)
>>> yseq = torch.tensor([1, 2, 3])  # Example sequence of tokens
>>> x = torch.randn(20, 256)  # Encoder output of one utterance: 20 frames
>>> state = decoder.init_state(x)
>>> logp, new_state = decoder.score(yseq, state, x)
>>> print(logp.shape)  # Should output: torch.Size([100])
########## NOTE
This method supports both single- and multi-encoder modes. In single-encoder mode, one attention module is applied to the encoder output. In multi-encoder mode, attention weights are computed for each encoder output and the resulting contexts are combined.
- Raises: ValueError – If the number of encoders is less than one.
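Internally, score turns the output-layer logits for one hypothesis into log probabilities with a log-softmax. A beam search then keeps the best extensions of that hypothesis. The following pure-Python sketch shows that caller-side step. The names are illustrative and are not ESPnet's beam search API.

```python
import math
from typing import List, Tuple


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over a 1-D list of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def top_k_extensions(yseq: List[int], logits: List[float], k: int,
                     prefix_score: float = 0.0) -> List[Tuple[List[int], float]]:
    """Extend one hypothesis with its k most likely next tokens."""
    logp = log_softmax(logits)
    best = sorted(range(len(logp)), key=lambda t: logp[t], reverse=True)[:k]
    # Hypothesis scores accumulate log probabilities along the sequence.
    return [(yseq + [t], prefix_score + logp[t]) for t in best]


for hyp, hyp_score in top_k_extensions([1], [0.0, 3.0, 1.0], k=2):
    print(hyp, round(hyp_score, 3))
```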
zero_state(hs_pad)
Initialize the hidden state of the RNN decoder.
This method creates a zero-filled tensor to be used as the initial hidden state for the RNN cells in the decoder. The size of the tensor matches the batch size of the input tensor, with the second dimension being equal to the number of hidden units in the RNN.
- Parameters: hs_pad (torch.Tensor) – A tensor whose first dimension is the batch size, such as the padded encoder output; only its batch size (and its dtype and device) are used.
- Returns: A tensor of shape (batch_size, hidden_size) filled with zeros, which represents the initial hidden state.
- Return type: torch.Tensor
############### Examples
>>> decoder = RNNDecoder(vocab_size=100, encoder_output_size=256)
>>> hs_pad = torch.randn(32, 256) # Example input
>>> initial_state = decoder.zero_state(hs_pad)
>>> print(initial_state.shape)  # hidden_size defaults to 320
torch.Size([32, 320])
########## NOTE This method is primarily used to initialize the hidden state for the first step of the decoding process in RNNs.