espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank._RNNTNumba
class espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank._RNNTNumba(*args, **kwargs)
Bases: Function
Autograd function that computes the RNN Transducer (RNNT) loss with Numba.
This class implements the forward and backward passes of the RNNT loss and uses Numba for performance. Given the output activations of a neural network and the corresponding labels, it computes the loss and the gradients with respect to the activations.
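For orientation, the standard RNNT forward (alpha) recursion that such a loss evaluates is written out below for one example b. The notation is introduced here and is not taken from the ESPnet source: T_b = act_lens[b], U_b = label_lens[b], y_1, ..., y_{U_b} are the first U_b entries of labels[b], ∅ is the blank index, 0 ≤ t < T_b, 0 ≤ u ≤ U_b, and any term with a negative index is dropped. How the Numba kernel lays this out internally may differ.

```latex
\ell_{t,u}(k) = \log \operatorname{softmax}_k\left(\mathrm{acts}_{b,t,u,:}\right),
\qquad \alpha_{0,0} = 0,

\alpha_{t,u} = \log\left( e^{\alpha_{t-1,u} + \ell_{t-1,u}(\varnothing)}
             + e^{\alpha_{t,u-1} + \ell_{t,u-1}(y_u)} \right),

\mathcal{L}_b = -\left( \alpha_{T_b-1,\,U_b} + \ell_{T_b-1,\,U_b}(\varnothing) \right)
```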
- Parameters:
- ctx – The context object that can be used to store information for the backward pass.
- acts (Tensor) – Tensor of shape (batch x seqLength x labelLength x outputDim) containing the output from the network.
- labels (Tensor) – 2D Tensor containing all the targets of the batch with zero padding.
- act_lens (Tensor) – 1D Tensor of size (batch) containing the size of each output sequence from the network.
- label_lens (Tensor) – 1D Tensor of size (batch) containing the label lengths of each example.
- blank (int) – The label index for the blank token in the RNNT loss.
- reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
- fastemit_lambda (float) – Scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
- clamp (float) – Value to clamp gradients to. Must be non-negative.
- Returns: The computed loss values for the input batch.
- Return type: Tensor
- Raises:ValueError – If clamp is less than 0.0.
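The parameters above imply a shape contract between acts, labels and the two length tensors. The following is a minimal plain-Python sketch of checks consistent with that contract, operating on shapes and length lists rather than tensors. `check_rnnt_shapes` is a hypothetical helper, not part of ESPnet, and the requirement that T equal max(act_lens) is an assumption about how the module validates its inputs; the checks it actually performs may differ.

```python
from typing import Sequence


def check_rnnt_shapes(
    acts_shape: Sequence[int],
    labels_shape: Sequence[int],
    act_lens: Sequence[int],
    label_lens: Sequence[int],
    clamp: float,
) -> None:
    """Hypothetical check of the (batch, T, U + 1, V) contract; raises ValueError."""
    if len(acts_shape) != 4 or len(labels_shape) != 2:
        raise ValueError("acts must be 4-D and labels must be 2-D")
    batch, t_max, u_plus_1, _ = acts_shape
    # One label row and one length entry per example in the batch
    if not (labels_shape[0] == len(act_lens) == len(label_lens) == batch):
        raise ValueError("labels, act_lens and label_lens need one row/entry per example")
    # Assumed: the time axis is sized to the longest output sequence
    if t_max != max(act_lens):
        raise ValueError(f"acts has T={t_max}, expected max(act_lens)={max(act_lens)}")
    # The label axis holds the "no label emitted yet" state plus one per label
    if u_plus_1 != max(label_lens) + 1:
        raise ValueError(
            f"acts has U+1={u_plus_1}, expected max(label_lens)+1={max(label_lens) + 1}"
        )
    if labels_shape[1] < max(label_lens):
        raise ValueError("labels has fewer columns than the longest label sequence")
    if clamp < 0.0:
        raise ValueError("clamp must be 0.0 or a positive float")
```

With these checks, acts of shape (2, 10, 5, 20) and label_lens [3, 2] are rejected, because the label axis must have size max(label_lens) + 1 = 4.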
######### Examples
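Example usage (assuming a CUDA device and int32 labels and lengths):
>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank import _RNNTNumba
>>> acts = torch.randn(2, 10, 4, 20, device='cuda')  # (batch x seqLength x maxLabelLength + 1 x outputDim)
>>> labels = torch.randint(1, 20, (2, 3), dtype=torch.int32, device='cuda')  # (batch x maxLabelLength), no blank (0)
>>> act_lens = torch.tensor([10, 8], dtype=torch.int32, device='cuda')  # length of each output sequence
>>> label_lens = torch.tensor([3, 2], dtype=torch.int32, device='cuda')  # length of each label sequence
>>> # Pass the remaining forward arguments positionally:
>>> # blank=0, reduction='mean', fastemit_lambda=0.0, clamp=0.0
>>> loss = _RNNTNumba.apply(acts, labels, act_lens, label_lens, 0, 'mean', 0.0, 0.0)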
static backward(ctx, grad_output)
Backward pass for the _RNNTNumba function.
This method computes the gradients of the RNNT loss with respect to the input activations. It uses the stored gradients from the forward pass to apply the chain rule, allowing for backpropagation through the network.
- Parameters:
- ctx – The context object that contains information from the forward pass.
- grad_output – A tensor containing the gradient of the loss with respect to the output of the RNNT loss function.
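- Returns: A tuple whose first element is the gradient with respect to acts, followed by None for each of the other seven forward inputs (labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp), which do not require gradients.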
NOTE
Gradients are returned only when both grad_output and the gradients stored during the forward pass are not None; otherwise the method returns None.
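The backward rule described above amounts to multiplying the gradient cached during the forward pass by the incoming grad_output, one factor per example. The plain-Python sketch below uses nested lists in place of tensors. `scale_cached_grads` is hypothetical and only mirrors the described behavior; broadcasting a single-element grad_output over the batch (for 'sum' and 'mean') is an assumption about how the module reshapes it. The 8-element return tuple follows from the autograd rule that backward returns one value per forward input.

```python
from typing import List, Optional, Sequence, Tuple

# Nested lists standing in for a (batch, T, U + 1, V) tensor
Grad4D = List[List[List[List[float]]]]


def scale_cached_grads(
    cached: Optional[Grad4D], grad_output: Optional[Sequence[float]]
) -> Optional[Tuple[Optional[Grad4D], ...]]:
    """Hypothetical plain-Python analogue of the backward rule."""
    if grad_output is None or cached is None:
        return None  # nothing to propagate
    batch = len(cached)
    if len(grad_output) == batch:
        scales = list(grad_output)  # one upstream gradient per example ('none')
    elif len(grad_output) == 1:
        scales = [grad_output[0]] * batch  # single reduced loss ('sum' or 'mean')
    else:
        raise ValueError("grad_output must have 1 or batch elements")
    # Chain rule: d(upstream)/d(acts) = grad_output[b] * d(loss_b)/d(acts)
    scaled = [
        [[[g * scales[b] for g in row] for row in frame] for frame in cached[b]]
        for b in range(batch)
    ]
    # Gradient for acts; None for labels, act_lens, label_lens,
    # blank, reduction, fastemit_lambda and clamp
    return (scaled,) + (None,) * 7
```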
######### Examples
>>> # Assuming acts (with requires_grad=True), labels, act_lens and label_lens
>>> # have been defined with consistent shapes, dtypes and device
>>> loss = _RNNTNumba.apply(acts, labels, act_lens, label_lens, 0, 'mean', 0.0, 0.0)
>>> # autograd calls backward() with grad_output during backpropagation
>>> loss.backward()
static forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp)
Compute the forward pass of the RNNT loss using Numba.
This method computes the RNN Transducer (RNNT) loss for a given set of inputs. The RNNT loss is useful for training sequence-to-sequence models, particularly in automatic speech recognition tasks.
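To make the computed quantity concrete, here is a plain-Python reference for a single utterance using the standard RNNT alpha recursion. `rnnt_loss_single` is a hypothetical, deliberately slow helper for cross-checking small cases; it is not the Numba kernel. It assumes acts holds unnormalized scores and applies log-softmax over the output dimension itself, and it expects acts already truncated to this utterance's act_lens and label_lens.

```python
import math
from typing import List, Sequence


def log_softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def logaddexp(a: float, b: float) -> float:
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_loss_single(
    acts: Sequence[Sequence[Sequence[float]]], labels: Sequence[int], blank: int = 0
) -> float:
    """Negative log-likelihood of `labels` under one (T, U + 1, V) grid of scores."""
    T, U = len(acts), len(labels)
    logp = [[log_softmax(acts[t][u]) for u in range(U + 1)] for t in range(T)]
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank at (t - 1, u) ...
            stay = alpha[t - 1][u] + logp[t - 1][u][blank] if t > 0 else -math.inf
            # ... or by emitting label u at (t, u - 1)
            emit = alpha[t][u - 1] + logp[t][u - 1][labels[u - 1]] if u > 0 else -math.inf
            alpha[t][u] = logaddexp(stay, emit)
    # A final blank at the last frame terminates every alignment
    return -(alpha[T - 1][U] + logp[T - 1][U][blank])
```

With all-zero scores every emission has probability 1/V, so for T = 2, U = 1 and V = 3 the two valid alignments give P = 2/27 and a loss of log(27/2), about 2.603.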
- Parameters:
- ctx – Context object that can be used to store information for backward computation.
- acts (Tensor) – A tensor of shape (batch x seqLength x labelLength x outputDim) containing the output activations from the network, where labelLength is the maximum label length plus one.
- labels (Tensor) – A 2-dimensional tensor containing the target labels for each example in the batch, padded with zeros.
- act_lens (Tensor) – A tensor of size (batch) that indicates the actual length of each output sequence from the network.
- label_lens (Tensor) – A tensor of size (batch) that contains the length of each target label sequence.
- blank (int) – The index of the blank label.
- reduction (str) – Specifies the reduction method to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction is applied and the per-example losses are returned; 'sum': the per-example losses are summed; 'mean': the summed loss is divided by the batch size.
- fastemit_lambda (float) – A scaling factor for FastEmit regularization. Refer to the FastEmit paper for details.
- clamp (float) – A value to clamp the gradients. Must be 0.0 or greater.
- Returns: A tensor containing the computed RNNT loss for the batch.
- Return type: Tensor
- Raises:ValueError – If clamp is negative.
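A short sketch of what a non-negative clamp value could mean for the gradients. Treating 0.0 as "no clamping" and otherwise limiting each element to [-clamp, clamp] is an assumption here, and `clamp_grads` is a hypothetical helper working on a flat list rather than on the Numba kernel's gradient buffer.

```python
from typing import List, Sequence


def clamp_grads(grads: Sequence[float], clamp: float) -> List[float]:
    """Hypothetical element-wise gradient clamp to [-clamp, clamp]."""
    if clamp < 0.0:
        raise ValueError("clamp must be 0.0 or a positive float")
    if clamp == 0.0:
        return list(grads)  # assumed: 0.0 disables clamping
    return [min(max(g, -clamp), clamp) for g in grads]
```

Clamping bounds the size of each gradient element without changing its sign, which can help keep training stable when a few alignments produce very large gradients.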
######### Examples
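>>> # Assuming a CUDA device and int32 labels and lengths
>>> acts = torch.randn(2, 10, 3, 20, device='cuda')  # (batch x T x maxLabelLength + 1 x outputDim)
>>> labels = torch.tensor([[1, 2], [1, 0]], dtype=torch.int32, device='cuda')  # zero-padded labels
>>> act_lens = torch.tensor([10, 5], dtype=torch.int32, device='cuda')  # actual lengths of acts
>>> label_lens = torch.tensor([2, 1], dtype=torch.int32, device='cuda')  # actual lengths of labels
>>> # Call through apply(); forward() needs a real autograd context, not None
>>> loss = _RNNTNumba.apply(acts, labels, act_lens, label_lens,
...                         0, 'mean', 0.0, 0.0)
>>> print(loss)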