espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.MultiblankRNNTLossNumba
class espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.MultiblankRNNTLossNumba(blank, big_blank_durations, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = -1, sigma: float = 0.0)
Bases: Module
Multiblank RNNT Loss Numba.
This class implements the multi-blank RNN Transducer loss, using Numba for efficient computation. It is designed for automatic speech recognition tasks in which additional "big blank" labels with explicit durations are used alongside the standard blank to improve the model's performance. The forward pass computes the loss; the backward pass computes the gradients used for optimization.
blank
Standard blank label.
- Type: int
big_blank_durations
List of durations for multi-blank transducer, e.g., [2, 4, 8].
- Type: list
sigma
Hyper-parameter for logit under-normalization method for training multi-blank transducers. Recommended value: 0.05.
- Type: float
reduction
Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
- Type: str
fastemit_lambda
Float scaling factor for FastEmit regularization.
- Type: float
clamp
Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
- Type: float
Parameters:
- blank (int) – Standard blank label.
- big_blank_durations (list) – List of durations for multi-blank transducer, e.g., [2, 4, 8].
- reduction (str , optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
- fastemit_lambda (float , optional) – Float scaling factor for FastEmit regularization. Default: 0.0.
- clamp (float , optional) – Float value for gradient clamping. Default: -1.
- sigma (float , optional) – Hyper-parameter for logit under-normalization. Default: 0.0.
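As a rough illustration of the clamp parameter above: when clamp >= 0.0, each gradient entry is limited to the range [-clamp, clamp], and a negative value (the default, -1) disables clamping. A plain-Python sketch of that rule (illustrative only, not the actual CUDA kernel):

```python
def clamp_grad(g: float, clamp: float) -> float:
    """Limit a gradient entry to [-clamp, clamp] when clamping is enabled.

    A negative clamp (the default, -1) disables clamping entirely.
    """
    if clamp >= 0.0:
        return max(-clamp, min(clamp, g))
    return g

print(clamp_grad(2.5, 1.0))   # clamped down to 1.0
print(clamp_grad(-2.5, 1.0))  # clamped up to -1.0
print(clamp_grad(2.5, -1))    # clamping disabled: 2.5 unchanged
```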
Examples
>>> import torch
>>> loss_fn = MultiblankRNNTLossNumba(blank=0, big_blank_durations=[2, 4, 8])
>>> # acts: (batch x seqLength x maxLabelLength+1 x outputDim)
>>> acts = torch.randn(10, 20, 16, 40)
>>> labels = torch.randint(1, 40, (10, 15))
>>> act_lens = torch.randint(1, 21, (10,))
>>> label_lens = torch.randint(1, 16, (10,))
>>> loss = loss_fn(acts, labels, act_lens, label_lens)
NOTE
Refer to the paper at https://arxiv.org/pdf/2211.03541 for detailed explanations regarding the multi-blank transducer and its parameters.
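To make the sigma attribute concrete: per the paper referenced above, logit under-normalization subtracts a constant sigma from each normalized log-probability, so the output distribution sums to slightly less than one. A minimal sketch of that idea (illustrative only, not this module's CUDA implementation):

```python
import math

def under_normalized_log_probs(logits, sigma):
    # Standard log-softmax (with max-subtraction for numerical stability)...
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    # ...then subtract sigma from every normalized logit.
    return [x - log_sum - sigma for x in logits]

# With two equal logits, log-softmax gives -ln(2) each; sigma shifts both down.
out = under_normalized_log_probs([0.0, 0.0], sigma=0.05)
```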
forward(acts, labels, act_lens, label_lens)
MultiblankRNNTNumba Forward.
Computes the forward pass of the multi-blank RNNT loss function.
This function takes in the model’s output, the target labels, and their respective lengths to compute the loss for the multi-blank RNNT. The function supports gradient calculation for backpropagation.
- Parameters:
- acts (torch.Tensor) – A tensor of shape (batch x seqLength x labelLength x outputDim) containing the output activations (logits) from the network.
- labels (torch.Tensor) – A 2D tensor containing the zero-padded target labels for each element in the batch.
- act_lens (torch.Tensor) – A 1D tensor of size (batch) containing the lengths of each output sequence from the network.
- label_lens (torch.Tensor) – A 1D tensor containing the lengths of the label sequences for each example in the batch.

Note that blank, big_blank_durations, reduction, fastemit_lambda, clamp, and sigma are configured at initialization and applied internally during this computation; see the class parameters above for their meanings.
- Returns: A tensor containing the computed costs for each example in the batch.
- Return type: torch.Tensor
- Raises: ValueError – If clamp is less than 0, or if input tensor shapes do not match the expected dimensions.
Examples
>>> import torch
>>> loss_fn = MultiblankRNNTLossNumba(
...     blank=0, big_blank_durations=[2, 4, 8],
...     reduction="mean", fastemit_lambda=0.0, clamp=0.0, sigma=0.05)
>>> acts = torch.randn(2, 5, 5, 10)  # (batch x seqLength x maxLabelLength+1 x outputDim)
>>> labels = torch.randint(1, 10, (2, 4))  # (batch x maxLabelLength)
>>> act_lens = torch.tensor([5, 5])  # lengths of output sequences
>>> label_lens = torch.tensor([4, 3])  # lengths of label sequences
>>> costs = loss_fn(acts, labels, act_lens, label_lens)
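The 'mean' reduction described above divides each example's cost by its target length before averaging over the batch. A plain-Python sketch of that arithmetic (hypothetical numbers; semantics as stated in the docstring):

```python
per_example_costs = [12.0, 8.0]  # hypothetical per-example losses ('none' reduction)
label_lens = [4, 2]              # target lengths for each example

# 'sum': total cost over the batch
loss_sum = sum(per_example_costs)  # 20.0

# 'mean': divide each cost by its target length, then average over the batch
loss_mean = sum(c / l for c, l in zip(per_example_costs, label_lens)) / len(per_example_costs)
# (12/4 + 8/2) / 2 = (3.0 + 4.0) / 2 = 3.5
```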
NOTE
This implementation relies on CUDA for performance; make sure to run this on a compatible GPU.