espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.RNNTLossNumba
class espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.RNNTLossNumba(blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = -1)
Bases: Module
RNNTLossNumba is a PyTorch module that computes the RNN Transducer Loss using Numba.
This module leverages the Numba JIT compiler to optimize the forward and backward passes of the RNN Transducer loss, making it suitable for high-performance automatic speech recognition (ASR) applications.
blank
Standard blank label. Default is 0.
- Type: int
reduction
Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. Default is ‘mean’.
- Type: str
fastemit_lambda
Scaling factor for FastEmit regularization.
- Type: float
clamp
When set to a value >= 0.0, clamps the gradient to [-clamp, clamp].
- Type: float
Parameters:
- blank (int , optional) – Standard blank label. Default is 0.
- reduction (str , optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. Default is ‘mean’.
- fastemit_lambda (float , optional) – Scaling factor for FastEmit regularization. Default is 0.0.
- clamp (float , optional) – Clamping value for gradients. Default is -1.
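The two regularization knobs can be combined at construction. As a minimal sketch (the lambda and clamp values below are illustrative assumptions, not recommendations):
>>> loss_fn = RNNTLossNumba(
...     blank=0, reduction='mean',
...     fastemit_lambda=0.01,  # assumed value; scales the FastEmit regularization term
...     clamp=1.0,             # gradients clamped to [-1.0, 1.0]
... )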
####### Examples
>>> import torch
>>> loss_fn = RNNTLossNumba(blank=0, reduction='mean')
>>> acts = torch.randn(10, 20, 6, 30)  # (batch, maxSeqLength, maxLabelLength + 1, outputDim)
>>> labels = torch.randint(1, 30, (10, 5), dtype=torch.int32)  # (batch, maxLabelLength); no blank (0)
>>> act_lens = torch.full((10,), 20, dtype=torch.int32)   # (batch,)
>>> label_lens = torch.full((10,), 5, dtype=torch.int32)  # (batch,)
>>> loss = loss_fn(acts, labels, act_lens, label_lens)
>>> print(loss)
NOTE
Ensure that the input tensors are contiguous and of the correct type before passing them to the forward method.
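For instance, a hedged preparation step applied to the tensors from the example above (the float32/int32 dtypes are typical for this loss family, but are assumptions):
>>> acts = acts.float().contiguous()    # activations as contiguous float32
>>> labels = labels.int().contiguous()  # targets as contiguous int32
>>> act_lens = act_lens.int()
>>> label_lens = label_lens.int()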
forward(acts, labels, act_lens, label_lens)
Compute the forward pass of the RNNT loss.
This method computes the RNN Transducer (RNNT) loss based on the input activations (log probabilities) and target labels. It supports both GPU and CPU operations. The function returns the computed loss value(s) based on the specified reduction method.
- Parameters:
- acts (Tensor) – A tensor of shape (batch x seqLength x labelLength x outputDim) containing the output from the network, where labelLength is the maximum label length plus one.
- labels (Tensor) – A two-dimensional tensor containing all the targets of the batch, zero-padded.
- act_lens (Tensor) – A tensor of size (batch) containing the length of each output sequence from the network.
- label_lens (Tensor) – A tensor of size (batch) containing the label length of each example.

The blank index, reduction mode, FastEmit scaling factor, and gradient clamping value are not arguments to this call; they are taken from the values supplied to the constructor. For reduction, ‘none’ applies no reduction, while ‘mean’ divides the output losses by the target lengths and then takes the mean over the batch.
- Returns: The computed loss value(s), either as a single value or as a tensor of values based on the reduction method specified.
- Return type: Tensor
- Raises: ValueError – If clamp is negative, or if the input tensors do not meet the dimensionality or type requirements.
####### Examples
>>> import torch
>>> loss_fn = RNNTLossNumba(blank=0, reduction='mean', fastemit_lambda=0.0, clamp=0.0)
>>> acts = torch.randn(32, 100, 11, 256)  # (batch, maxSeqLength, maxLabelLength + 1, outputDim)
>>> labels = torch.randint(1, 256, (32, 10), dtype=torch.int32)  # no blank (0) in targets
>>> act_lens = torch.full((32,), 100, dtype=torch.int32)
>>> label_lens = torch.full((32,), 10, dtype=torch.int32)
>>> loss = loss_fn(acts, labels, act_lens, label_lens)
>>> print(loss)
NOTE
This function is optimized for use with both CUDA and CPU tensors, and it is important to ensure that the inputs are correctly formatted to avoid runtime errors.
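As a final hedged sketch, ‘none’ reduction can be used to inspect per-utterance losses on CPU (shapes and values below are illustrative assumptions):
>>> import torch
>>> loss_fn = RNNTLossNumba(blank=0, reduction='none')
>>> acts = torch.randn(4, 50, 6, 16)  # (batch, maxSeqLength, maxLabelLength + 1, outputDim)
>>> labels = torch.randint(1, 16, (4, 5), dtype=torch.int32)  # no blank (0) in targets
>>> act_lens = torch.full((4,), 50, dtype=torch.int32)
>>> label_lens = torch.full((4,), 5, dtype=torch.int32)
>>> per_utt = loss_fn(acts, labels, act_lens, label_lens)
>>> print(per_utt.shape)  # one loss per utterance: torch.Size([4])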