espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank._MultiblankRNNTNumba
class espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank._MultiblankRNNTNumba(*args, **kwargs)
Bases: torch.autograd.Function
Numba class for multi-blank RNN Transducer loss.
This class implements the forward and backward passes of the multi-blank RNN Transducer (RNNT) loss using Numba. Besides the standard blank, which advances the alignment by one frame, a multi-blank transducer has "big blank" symbols that each advance it by a fixed number of frames (given by big_blank_durations). Skipping several frames in a single step can reduce the number of decoding steps needed at inference time.
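The frame bookkeeping can be made concrete with a short sketch that walks one alignment and counts frames consumed against decoding steps. It is illustrative only: the symbol names (`<b>`, `<b2>`, `<b4>`) and the helper `walk_alignment` are hypothetical and do not reflect the token layout used by the Numba kernels.

```python
from typing import Dict, List, Tuple

BLANK = "<b>"  # hypothetical name for the standard blank (advances one frame)


def walk_alignment(alignment: List[str], big_blanks: Dict[str, int]) -> Tuple[int, int]:
    """Return (frames_consumed, decoding_steps) for one alignment.

    Labels are emitted at the current frame without advancing time, the
    standard blank advances one frame, and each big blank advances the frame
    index by its duration.
    """
    frames = 0
    for symbol in alignment:
        if symbol == BLANK:
            frames += 1
        elif symbol in big_blanks:
            frames += big_blanks[symbol]
        # Any other symbol is a label and leaves the frame index unchanged.
    return frames, len(alignment)


durations = {"<b2>": 2, "<b4>": 4}  # hypothetical big blanks, cf. big_blank_durations=[2, 4]
standard = ["a", "<b>", "<b>", "<b>", "<b>", "b", "<b>", "<b>", "<b>", "<b>"]
multi = ["a", "<b4>", "b", "<b4>"]
print(walk_alignment(standard, durations))  # (8, 10)
print(walk_alignment(multi, durations))     # (8, 4)
```

Both alignments cover the same eight frames and emit the same labels, but the one using big blanks needs four steps instead of ten.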
- Parameters:
- ctx – The context object for storing information for the backward pass.
- acts (torch.Tensor) – A tensor of shape (batch x seqLength x labelLength x outputDim) containing the output from the network.
- labels (torch.Tensor) – A 2D tensor containing the targets for the batch, zero-padded.
- act_lens (torch.Tensor) – A tensor of size (batch) containing the size of each output sequence from the network.
- label_lens (torch.Tensor) – A tensor of size (batch) containing the label lengths for each example.
- blank (int) – The index of the blank label.
- big_blank_durations (list) – A list of durations for the multi-blank transducer, e.g., [2, 4, 8].
- reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
- fastemit_lambda (float) – Scaling factor for FastEmit regularization.
- clamp (float) – Value to which gradients are clamped. Must be >= 0.
- sigma (float) – Hyper-parameter for logit under-normalization method for training multi-blank transducers. Recommended value is 0.05.
- Returns: The computed costs as a tensor.
- Return type: torch.Tensor
- Raises:
- ValueError – If clamp is negative or if any of the tensor shapes do not match the expected dimensions.
######### Examples
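Example usage. Like any torch.autograd.Function, this class is invoked through apply, which creates ctx and passes the remaining arguments to forward in signature order:
loss = _MultiblankRNNTNumba.apply(
    acts, labels, act_lens, label_lens,
    0,          # blank
    [2, 4, 8],  # big_blank_durations
    "mean",     # reduction
    0.0,        # fastemit_lambda
    0.0,        # clamp
    0.05,       # sigma
)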
NOTE
This implementation is specifically designed for GPU usage. If inputs are on CPU, consider using the appropriate CPU loss functions instead.
static backward(ctx, grad_output)
Compute the gradients for the backward pass of the RNNT loss.
This method computes the gradients of the loss with respect to the inputs of the forward pass. The gradients with respect to acts are already computed during the forward pass and stored in ctx.grads; backward scales them by grad_output to produce the final gradient.
- Parameters:
- ctx – The context object containing information from the forward pass, including stored gradients.
- grad_output – The gradient of the loss with respect to the output of the forward pass. This is typically provided by PyTorch during the backward pass.
- Returns: A tuple containing the gradients with respect to the inputs of the forward pass. The inputs that are not relevant for the gradient are returned as None.
- Return type: Tuple[Tensor, None, ...], with one entry per forward input (acts through sigma)
######### Examples
Example usage in a PyTorch training loop (model, input, labels, act_lens, and label_lens are placeholders):
loss = RNNTLossNumba()
output = model(input)
loss_value = loss(output, labels, act_lens, label_lens)
loss_value.backward()  # autograd calls the backward method of the underlying Function
NOTE
backward only computes gradients when both grad_output and ctx.grads are not None; if either is None, it performs no computation.
static forward(ctx, acts, labels, act_lens, label_lens, blank, big_blank_durations, reduction, fastemit_lambda, clamp, sigma)
MultiblankRNNTNumba Forward.
This method computes the forward pass for the multi-blank RNNT loss using Numba for performance optimization. It takes the predicted logits, target labels, and their respective lengths to calculate the loss values.
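The quantity being computed can be sketched in pure Python as a forward (alpha) recursion over the (t, u) lattice, extended with big-blank transitions that jump `duration` frames at once. This is a simplified reference under stated assumptions, not the Numba kernel: it takes already-normalized log-probabilities, uses an explicit, hypothetical list of `(vocab_index, duration)` pairs for the big blanks instead of the library's index layout, and omits `sigma` and FastEmit.

```python
import math
from typing import List, Sequence, Tuple

NEG_INF = float("-inf")


def log_sum_exp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))) that tolerates -inf entries."""
    m = max(values)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(v - m) for v in values))


def multiblank_log_likelihood(
    logp: List[List[List[float]]],
    labels: Sequence[int],
    blank: int,
    big_blanks: Sequence[Tuple[int, int]],
) -> float:
    """log P(labels | acts) for one utterance.

    logp[t][u][k] is the log-probability of vocab entry k at frame t after u
    labels have been emitted (shape T x (U + 1) x V). big_blanks holds
    hypothetical (vocab_index, duration) pairs.
    """
    T, U = len(logp), len(labels)
    alpha = [[NEG_INF] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            terms = []
            if t > 0:  # standard blank from frame t - 1
                terms.append(alpha[t - 1][u] + logp[t - 1][u][blank])
            if u > 0:  # emit label u - 1 at frame t
                terms.append(alpha[t][u - 1] + logp[t][u - 1][labels[u - 1]])
            for index, duration in big_blanks:  # big blank from frame t - duration
                if t >= duration:
                    terms.append(alpha[t - duration][u] + logp[t - duration][u][index])
            alpha[t][u] = log_sum_exp(terms)
    # Terminate with any blank that lands exactly at frame T after all labels.
    final = [alpha[T - 1][U] + logp[T - 1][U][blank]]
    for index, duration in big_blanks:
        if T >= duration:
            final.append(alpha[T - duration][U] + logp[T - duration][U][index])
    return log_sum_exp(final)


half = math.log(0.5)
# T = 2 frames, no labels, vocab {0: blank, 1: big blank of duration 2}.
logp = [[[half, half]], [[half, half]]]
print(math.exp(multiblank_log_likelihood(logp, [], 0, [(1, 2)])))  # approximately 0.75
```

Negating the result gives the per-example cost. In the toy case, the two paths are two standard blanks (0.25) and one big blank covering both frames (0.5). With `big_blanks=[]` the recursion reduces to the standard RNNT forward algorithm.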
- Parameters:
- ctx – The context object to save information for backward pass.
- acts (Tensor) – Tensor of shape (batch x seqLength x labelLength x outputDim) containing output logits from the network.
- labels (Tensor) – 2D tensor containing target labels for the batch, zero-padded as necessary.
- act_lens (Tensor) – 1D tensor of size (batch) containing the lengths of each output sequence from the network.
- label_lens (Tensor) – 1D tensor of size (batch) containing the lengths of the labels for each example.
- blank (int) – Index of the standard blank label used in RNNT.
- big_blank_durations (list) – List of durations for multi-blank transducer, e.g., [2, 4, 8].
- reduction (str) – Specifies the reduction method to apply to the output: 'none', 'mean', or 'sum'.
- fastemit_lambda (float) – Scaling factor for FastEmit regularization. Refer to the FastEmit paper for more details.
- clamp (float) – Value to clamp the gradients to, must be >= 0.0.
- sigma (float) – Hyper-parameter for logit under-normalization method for training multi-blank transducers. Recommended value is 0.05.
- Returns: A tensor containing the computed loss for each example in the batch, reduced according to the specified method.
- Return type: Tensor
- Raises:
- ValueError – If clamp is negative, or if the input tensors are not of the expected dimensions or types.
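The shape requirements behind the ValueError above can be sketched as a validation helper. `check_multiblank_inputs` is hypothetical, and the rule that `acts.shape[2]` equals `max(label_lens) + 1` is an assumption based on the usual RNNT lattice layout (one extra slot for the state before any label is emitted); ESPnet's actual checks may be stricter or worded differently.

```python
import torch


def check_multiblank_inputs(acts, labels, act_lens, label_lens, clamp):
    """Hypothetical validator for the shape rules assumed in the lead-in."""
    if clamp < 0:
        raise ValueError("clamp must be >= 0.0")
    if acts.dim() != 4:
        raise ValueError("acts must be 4-D: (batch, T, U + 1, V)")
    if labels.dim() != 2 or act_lens.dim() != 1 or label_lens.dim() != 1:
        raise ValueError("labels must be 2-D; act_lens and label_lens must be 1-D")
    batch, max_t, max_u_plus_one, _ = acts.shape
    if act_lens.size(0) != batch or label_lens.size(0) != batch:
        raise ValueError("act_lens and label_lens need one entry per example")
    # Time axis spans the longest input sequence.
    if max_t != int(act_lens.max()):
        raise ValueError("acts.shape[1] must equal max(act_lens)")
    # Label axis has one extra slot for the state before any label is emitted.
    if max_u_plus_one != int(label_lens.max()) + 1:
        raise ValueError("acts.shape[2] must equal max(label_lens) + 1")


acts = torch.zeros(2, 6, 4, 10)
labels = torch.zeros(2, 3, dtype=torch.int32)
act_lens = torch.tensor([6, 4], dtype=torch.int32)
label_lens = torch.tensor([3, 1], dtype=torch.int32)
check_multiblank_inputs(acts, labels, act_lens, label_lens, clamp=0.0)  # passes silently
```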
######### Examples
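>>> import torch
>>> # Assumes a CUDA device; acts has shape (batch, T, U + 1, V) with U = max(label_lens) = 3.
>>> acts = torch.randn(4, 10, 4, 20, device="cuda", requires_grad=True)
>>> labels = torch.tensor([[1, 2, 3, 0], [1, 0, 0, 0],
...                        [2, 3, 0, 0], [0, 0, 0, 0]],
...                       dtype=torch.int32, device="cuda")
>>> act_lens = torch.tensor([10, 5, 10, 7], dtype=torch.int32, device="cuda")
>>> label_lens = torch.tensor([3, 1, 2, 0], dtype=torch.int32, device="cuda")
>>> blank = 0
>>> big_blank_durations = [2, 4, 8]
>>> reduction = "mean"
>>> fastemit_lambda = 0.0
>>> clamp = 0.0
>>> sigma = 0.05
>>> # Function subclasses are called through apply, which creates ctx itself.
>>> loss = _MultiblankRNNTNumba.apply(
...     acts, labels, act_lens, label_lens,
...     blank, big_blank_durations, reduction,
...     fastemit_lambda, clamp, sigma)
>>> loss.backward()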