espnet2.asr.ctc.CTC
class espnet2.asr.ctc.CTC(odim: int, encoder_output_size: int, dropout_rate: float = 0.0, ctc_type: str = 'builtin', reduce: bool = True, ignore_nan_grad: bool | None = None, zero_infinity: bool = True, brctc_risk_strategy: str = 'exp', brctc_group_strategy: str = 'end', brctc_risk_factor: float = 0.0)
Bases: Module
Connectionist Temporal Classification (CTC) module for sequence-to-sequence tasks.
This module implements the CTC loss function, which is commonly used in automatic speech recognition and other sequence prediction tasks where the alignment between input and output sequences is unknown.
- Parameters:
- odim (int) – Dimension of outputs (vocabulary size).
- encoder_output_size (int) – Number of encoder projection units.
- dropout_rate (float, optional) – Dropout rate (0.0 ~ 1.0) applied to hs_pad in forward() before the linear projection. Default is 0.0.
- ctc_type (str, optional) – Type of CTC loss to use. Options are “builtin”, “builtin2”, “gtnctc”, or “brctc”. Default is “builtin”.
- reduce (bool, optional) – Whether to reduce the CTC loss into a scalar. Default is True.
- ignore_nan_grad (Optional[bool], optional) – Legacy option kept for backward compatibility. If it is not None, its value overrides zero_infinity. Default is None.
- zero_infinity (bool, optional) – Whether to zero infinite losses and the associated gradients. Default is True.
- brctc_risk_strategy (str, optional) – Risk strategy for Bayes Risk CTC. Default is “exp”.
- brctc_group_strategy (str, optional) – Group strategy for Bayes Risk CTC. Default is “end”.
- brctc_risk_factor (float, optional) – Risk factor for Bayes Risk CTC. Default is 0.0.
- Raises:
- ValueError – If an invalid ctc_type is provided.
- ImportError – If “brctc” is selected but the K2 library is not installed.
############### Examples
>>> import torch
>>> from espnet2.asr.ctc import CTC
>>> ctc = CTC(odim=10, encoder_output_size=64)
>>> hs_pad = torch.randn(32, 100, 64)  # (B, Tmax, D)
>>> hlens = torch.full((32,), 100, dtype=torch.long)  # Lengths of hidden states
>>> ys_pad = torch.randint(1, 10, (32, 50))  # Padded target sequences; ID 0 is the CTC blank
>>> ys_lens = torch.randint(1, 50, (32,))  # Lengths of target sequences
>>> loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
######### NOTE
The “builtin” and “builtin2” types use PyTorch’s built-in CTC loss (torch.nn.CTCLoss), while “gtnctc” and “brctc” depend on additional libraries (gtn and K2, respectively).
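For reference, the quantity that every ctc_type ultimately scores is -log p(y | x), where p(y | x) sums the probabilities of all frame-level paths that collapse to the target y once repeated labels are merged and blanks are removed. The following is a minimal pure-Python sketch of the standard CTC forward (alpha) recursion in log space for a single utterance, assuming the blank has ID 0. The function names are hypothetical; this is not ESPnet's implementation, which operates on batched tensors.

```python
import math
from typing import List, Sequence

NEG_INF = float("-inf")


def logsumexp(values: Sequence[float]) -> float:
    """Stable log(sum(exp(v))) that also accepts -inf entries."""
    m = max(values)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(v - m) for v in values))


def ctc_neg_log_likelihood(
    log_probs: List[List[float]], target: List[int], blank: int = 0
) -> float:
    """Return -log p(target | frames) with the CTC forward (alpha) recursion.

    log_probs: one list of per-token log-probabilities per frame (T >= 1).
    target: label IDs, without blanks.
    """
    # Extended sequence: blank, y1, blank, y2, ..., yL, blank
    ext = [blank]
    for label in target:
        ext += [label, blank]
    num_states = len(ext)

    # At t = 0 a path may start on the leading blank or on the first label
    alpha = [NEG_INF] * num_states
    alpha[0] = log_probs[0][ext[0]]
    if num_states > 1:
        alpha[1] = log_probs[0][ext[1]]

    for frame in log_probs[1:]:
        new_alpha = [NEG_INF] * num_states
        for s in range(num_states):
            incoming = [alpha[s]]  # stay on the same state
            if s >= 1:
                incoming.append(alpha[s - 1])  # move forward by one state
            # Skip the blank in between only if the two labels differ
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                incoming.append(alpha[s - 2])
            new_alpha[s] = logsumexp(incoming) + frame[ext[s]]
        alpha = new_alpha

    # A complete path ends on the last label or on the trailing blank
    # (the slice also covers an empty target, where only one state exists)
    return -logsumexp(alpha[-2:])
```

For three frames with uniform probabilities over {blank, 1, 2}, six of the 27 possible paths collapse to [1], so ctc_neg_log_likelihood([[math.log(1 / 3)] * 3] * 3, [1]) returns -log(6/27) ≈ 1.504. Asking for [1, 1] from only two frames returns inf, because a blank frame is needed between the repeated labels; this is the situation zero_infinity guards against.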
argmax(hs_pad)
Compute the argmax of frame activations.
This method applies the CTC linear layer to the input tensor and computes the argmax across the output dimension. The input should be a 3D tensor representing the batch of padded hidden state sequences.
- Parameters: hs_pad (torch.Tensor) – A 3D tensor of shape (B, Tmax, eprojs), where B is the batch size, Tmax is the maximum sequence length, and eprojs is the number of encoder projection units.
- Returns: A 2D tensor of shape (B, Tmax) containing, for each time step, the index of the maximum value along the output dimension.
- Return type: torch.Tensor
############### Examples
>>> ctc = CTC(odim=10, encoder_output_size=20)
>>> hs_pad = torch.randn(4, 5, 20) # Example input
>>> argmax_output = ctc.argmax(hs_pad)
>>> print(argmax_output.shape)
torch.Size([4, 5])
######### NOTE
The frame-level argmax is the first step of greedy CTC decoding: repeated labels still have to be merged and blank frames removed to obtain the output token sequence.
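The merge-then-drop-blanks rule for greedy CTC decoding can be sketched in pure Python on one row of frame indices, assuming blank ID 0. The helper name is hypothetical and not part of ESPnet.

```python
from itertools import groupby
from typing import List


def ctc_greedy_collapse(frame_ids: List[int], blank: int = 0) -> List[int]:
    """Map a frame-level best path to an output token sequence.

    Runs of identical IDs are merged first, then blanks are dropped, so a
    blank between two equal labels keeps them as separate tokens.
    """
    merged = [token for token, _ in groupby(frame_ids)]
    return [token for token in merged if token != blank]
```

For example, ctc_greedy_collapse([0, 3, 3, 0, 3, 5, 5, 0]) gives [3, 3, 5]: the blank separates the two 3s, so they survive as distinct tokens. For a batch, one would apply it to each row of ctc.argmax(hs_pad).tolist() after truncating the row to that utterance's length in hlens, since frames beyond it are padding.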
forward(hs_pad, hlens, ys_pad, ys_lens)
Calculate the Connectionist Temporal Classification (CTC) loss.
This method applies dropout and the CTC linear layer to hs_pad, producing logits of shape (B, Tmax, odim), and then computes the CTC loss against the target sequences. The loss implementation is selected by the ctc_type specified during initialization.
- Parameters:
- hs_pad (torch.Tensor) – A batch of padded hidden state sequences with shape (B, Tmax, D), where B is the batch size, Tmax is the maximum sequence length, and D is the encoder output dimension (encoder_output_size).
- hlens (torch.Tensor) – A tensor containing the lengths of the hidden state sequences with shape (B,).
- ys_pad (torch.Tensor) – A batch of padded token ID sequences with shape (B, Lmax), where Lmax is the maximum target sequence length.
- ys_lens (torch.Tensor) – A tensor containing the lengths of the target sequences with shape (B,).
- Returns: The computed CTC loss as a tensor. The loss is returned in the same device and data type as the input hidden states.
- Return type: torch.Tensor
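Because forward() receives padded targets of shape (B, Lmax) while the builtin losses take the concatenated (N,) layout described under loss_fn, the padded rows have to be packed first; in torch this amounts to torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)]). The pure-Python sketch below shows the same packing on nested lists. The helper name is hypothetical, and the padding value -1 in the usage example is only an assumption for illustration, since entries past ys_lens are ignored whatever they contain.

```python
from typing import List


def flatten_targets(ys_pad: List[List[int]], ys_lens: List[int]) -> List[int]:
    """Concatenate the first ys_lens[i] entries of every padded row ys_pad[i]."""
    flat: List[int] = []
    for row, length in zip(ys_pad, ys_lens):
        flat.extend(row[:length])  # entries past the length are padding
    return flat
```

flatten_targets([[4, 7, -1], [2, -1, -1]], [2, 1]) returns [4, 7, 2], whose length equals sum(ys_lens).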
############### Examples
>>> ctc = CTC(odim=10, encoder_output_size=20)
>>> hs_pad = torch.randn(32, 50, 20) # Example hidden states
>>> hlens = torch.full((32,), 50, dtype=torch.long)  # Example lengths (all frames valid)
>>> ys_pad = torch.randint(1, 10, (32, 30))  # Example targets; ID 0 is the CTC blank
>>> ys_lens = torch.randint(1, 31, (32,)) # Example target lengths
>>> loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
>>> print(loss)
######### NOTE
Ensure that the input tensors are padded consistently with their length tensors: every entry of hlens must be at most Tmax, and every entry of ys_lens at most Lmax.
- Raises: NotImplementedError – If the ctc_type is not supported.
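Whether a target fits its input can also be checked before calling forward(). Under CTC, each label needs at least one frame, and two identical adjacent labels need an extra blank frame between them; if hlens[i] is smaller than that minimum, no alignment exists and the loss is infinite, which the builtin loss zeroes when zero_infinity=True instead of reporting it. A hypothetical pre-check helper (not part of ESPnet), sketched in pure Python:

```python
from typing import List


def min_ctc_frames(target: List[int]) -> int:
    """Fewest frames that can emit `target` under CTC.

    One frame per label, plus one blank frame between each pair of identical
    adjacent labels (otherwise the pair would be merged into one token).
    """
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


def infeasible_utterances(hlens: List[int], targets: List[List[int]]) -> List[int]:
    """Indices i for which hlens[i] frames cannot produce targets[i]."""
    return [
        i
        for i, (num_frames, target) in enumerate(zip(hlens, targets))
        if num_frames < min_ctc_frames(target)
    ]
```

min_ctc_frames([1, 1, 2]) is 4, so an utterance with an hlens entry of 3 would be flagged. Note that hlens counts encoder output frames, i.e. after any subsampling in the encoder.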
log_softmax(hs_pad)
Compute the log softmax of frame activations.
This method applies the CTC linear layer to hs_pad and then a log softmax over the output dimension, turning raw scores (logits) into per-frame log-probabilities. Computing log-probabilities directly is more numerically stable than taking the logarithm of softmax outputs, which can underflow to zero for unlikely tokens.
- Parameters: hs_pad (torch.Tensor) – A 3D tensor of shape (B, Tmax, eprojs), where B is the batch size, Tmax is the maximum time steps, and eprojs is the number of encoder projection units.
- Returns: A 3D tensor of shape (B, Tmax, odim) after applying the log softmax function, where odim is the dimension of outputs.
- Return type: torch.Tensor
############### Examples
>>> ctc = CTC(odim=10, encoder_output_size=5)
>>> hs_pad = torch.rand(2, 3, 5) # Example input tensor
>>> log_probs = ctc.log_softmax(hs_pad)
>>> log_probs.shape
torch.Size([2, 3, 10])
######### NOTE
The output is batch-major, (B, Tmax, odim). To feed it to torch.nn.CTCLoss, which expects time-major log-probabilities, transpose it to (Tmax, B, odim) first, e.g. with log_probs.transpose(0, 1).
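The stability argument in the log_softmax description can be seen with scalar arithmetic: log_softmax(x)_i = x_i - logsumexp(x), and subtracting max(x) before exponentiating keeps every exponent at or below zero. A minimal pure-Python sketch for a single frame (illustrative only, not ESPnet's or PyTorch's code):

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """log_softmax(x)_i = x_i - logsumexp(x), using the max-shift trick."""
    m = max(logits)
    # Every exponent x - m is <= 0, so math.exp cannot overflow here
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]
```

log_softmax([1000.0, 1001.0]) returns approximately [-1.3133, -0.3133], whereas a naive math.log(math.exp(1000.0) / total) raises OverflowError before any division happens.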
loss_fn(th_pred, th_target, th_ilen, th_olen) → Tensor
Compute the CTC loss for the given predictions and targets.
This function calculates the Connectionist Temporal Classification (CTC) loss between the predicted logits and the target sequences. It handles various types of CTC loss implementations based on the ctc_type specified during the initialization of the CTC module.
- Parameters:
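- th_pred (torch.Tensor) – Logits from the CTC linear layer. For “builtin” and “builtin2”, forward() passes them time-major with shape (Tmax, B, odim), the layout torch.nn.CTCLoss expects; for “gtnctc” and “brctc” they remain batch-major, (B, Tmax, odim).
- th_target (torch.Tensor) – For the builtin types, the target token IDs of all utterances concatenated into one tensor of shape (N,), where N is the sum of th_olen. The other types receive their targets in a different form (a list of per-utterance tensors for “gtnctc”, the padded ys_pad for “brctc”).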
- th_ilen (torch.Tensor) – The lengths of the predicted sequences. Shape should be (B,).
- th_olen (torch.Tensor) – The lengths of the target sequences. Shape should be (B,).
- Returns: The computed CTC loss. If reduce is True, a scalar (for the builtin types, the per-utterance losses summed and divided by the batch size); otherwise a tensor with one loss per utterance that was kept.
- Return type: torch.Tensor
- Raises:
- NotImplementedError – If the ctc_type is not recognized.
############### Examples
>>> ctc = CTC(odim=10, encoder_output_size=5)
>>> th_pred = torch.randn(4, 3, 10)  # Time-major logits: T=4 frames, batch size 3, odim=10
>>> th_target = torch.tensor([1, 2, 3]) # Example target character IDs
>>> th_ilen = torch.tensor([4, 4, 4]) # All sequences have length 4
>>> th_olen = torch.tensor([1, 1, 1]) # All targets have length 1
>>> loss = ctc.loss_fn(th_pred, th_target, th_ilen, th_olen)
>>> print(loss)
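######### NOTE
With ctc_type “builtin2”, utterances whose CTC gradients are not finite are excluded from the batch and the loss is recomputed on the remaining ones. With “builtin”, infinite losses and their gradients are instead zeroed when zero_infinity is True.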
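The reduce behaviour and the builtin2 idea of discarding bad utterances can be illustrated on plain per-utterance loss values. This is a simplified sketch, assuming reduce=True means summing over utterances and dividing by the number of utterances kept; it tests the loss values themselves for finiteness, whereas builtin2 inspects gradients. The function name is hypothetical.

```python
import math
from typing import List


def batch_average_ctc(per_utt: List[float], skip_nonfinite: bool = False) -> float:
    """Average per-utterance CTC losses over the utterances that are kept.

    skip_nonfinite=True drops inf/nan losses first, a simplified stand-in for
    the gradient-based filtering of ctc_type "builtin2".
    """
    kept = [x for x in per_utt if math.isfinite(x)] if skip_nonfinite else list(per_utt)
    if not kept:
        return math.nan  # nothing left to average
    return sum(kept) / len(kept)
```

With losses [2.0, 4.0, inf], the plain average is inf, which would poison the whole batch, while skipping the non-finite entry gives 3.0.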
softmax(hs_pad)
Compute the softmax of frame activations.
This method applies the CTC linear layer to hs_pad and then a softmax over the output dimension, converting raw logits into a probability distribution over the odim output units for every frame.
- Parameters: hs_pad (torch.Tensor) – A 3D tensor of shape (B, Tmax, eprojs) where B is the batch size, Tmax is the maximum time steps, and eprojs is the number of encoder projection units.
- Returns: A 3D tensor of shape (B, Tmax, odim) containing the softmax probabilities for each output dimension (odim).
- Return type: torch.Tensor
############### Examples
>>> ctc = CTC(odim=10, encoder_output_size=20)
>>> hs_pad = torch.randn(5, 15, 20) # Example input
>>> softmax_output = ctc.softmax(hs_pad)
>>> print(softmax_output.shape)
torch.Size([5, 15, 10])
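Softmax preserves the ordering of the logits within a frame (p_i / p_j = exp(x_i - x_j)), so the most probable unit is the same whether it is read from softmax probabilities or from raw logits; this is why argmax() can work on the CTC linear layer's output directly. A small pure-Python check for one frame (helper names are illustrative):

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Softmax over one frame of logits, shifted by the max for stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: List[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)
```

For the frame [2.0, -1.0, 0.5], both argmax(softmax(frame)) and argmax(frame) are 0, and the probabilities sum to 1.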