espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.multiblank_rnnt_loss
espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.multiblank_rnnt_loss(acts, labels, act_lens, label_lens, blank, big_blank_durations=[], reduction='mean', fastemit_lambda: float = 0.0, clamp: float = 0.0)
Compute the multi-blank RNN Transducer loss.
This functional form of the loss trains multi-blank RNN transducers, which extend the standard blank (advancing one frame) with "big blank" symbols that each advance several input frames at once, as described in https://arxiv.org/pdf/2211.03541.pdf.
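To make the recursion concrete, here is a slow, pure-PyTorch reference sketch of the multi-blank forward (alpha) computation for a single utterance. It is not ESPnet's implementation, which runs as a CUDA kernel. It assumes `log_probs` is already log-softmax normalized and, as a labeled assumption, that the big blank of duration `durations[i]` uses output index `blank - 1 - i`. From lattice node (t, u), a label moves to (t, u + 1), the standard blank to (t + 1, u), and a big blank of duration d to (t + d, u). The loss is the negative log of the total probability of reaching frame T after the last label.

```python
import torch


def multiblank_nll(log_probs, labels, blank, durations):
    """Reference (slow) multi-blank transducer NLL for ONE utterance.

    log_probs: (T, U + 1, V) log-softmax outputs of the joint network.
    labels:    sequence of U target ids.
    blank:     index of the standard blank (advances one frame).
    durations: big-blank durations; durations[i] is ASSUMED to use output
               index blank - 1 - i (hypothetical layout for this sketch).
    Returns -log P(labels | x) as a 0-d tensor (differentiable).
    """
    T, U1, _ = log_probs.shape
    U = U1 - 1
    labels = [int(x) for x in labels]

    # alpha[t][u]: log-prob of all partial paths reaching node (t, u).
    # Row T is a virtual end frame that only the last label row may enter.
    alpha = [[None] * U1 for _ in range(T + 1)]
    alpha[0][0] = log_probs.new_zeros(())
    for t in range(T + 1):
        for u in range(U1):
            if t == 0 and u == 0:
                continue
            if t == T and u != U:
                continue
            terms = []
            if t > 0:  # standard blank: (t - 1, u) -> (t, u)
                terms.append(alpha[t - 1][u] + log_probs[t - 1, u, blank])
            if u > 0 and t < T:  # emit labels[u - 1]: (t, u - 1) -> (t, u)
                terms.append(alpha[t][u - 1] + log_probs[t, u - 1, labels[u - 1]])
            for i, d in enumerate(durations):  # big blank: (t - d, u) -> (t, u)
                if t - d >= 0:
                    terms.append(alpha[t - d][u] + log_probs[t - d, u, blank - 1 - i])
            alpha[t][u] = torch.logsumexp(torch.stack(terms), dim=0)
    return -alpha[T][U]
```

It loops in Python over every lattice node, so it is only practical for tiny inputs, for example to sanity-check results on a toy batch. With `durations=[]` it reduces to the ordinary RNN-T forward recursion.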
- Parameters:
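- acts (torch.Tensor) – A tensor of shape (batch x seqLength x labelLength x outputDim) containing the unnormalized joint-network outputs (log-softmax is applied internally). seqLength must equal max(act_lens), labelLength must equal max(label_lens) + 1, and outputDim covers the vocabulary, the big blanks, and the standard blank.
- labels (torch.Tensor) – A 2D int32 tensor of shape (batch x maxLabelLength) holding the zero-padded targets of the batch.
- act_lens (torch.Tensor) – An int32 tensor of size (batch) holding the number of valid frames in each sequence of acts.
- label_lens (torch.Tensor) – An int32 tensor of size (batch) holding the target length of each example.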
- blank (int) – The standard blank label used in the loss calculation.
- big_blank_durations (list of int, optional) – Durations, in frames, of the big blank symbols, e.g., [2, 4, 8]. With the default empty list only the standard blank is used, which reduces to the standard RNN-T loss.
- reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. Default is ‘mean’.
- fastemit_lambda (float, optional) – Scaling factor for FastEmit regularization; 0.0 disables it. Default is 0.0.
- clamp (float, optional) – If greater than 0.0, the gradient is clamped to [-clamp, clamp]; 0.0 disables clamping and negative values raise ValueError. Default is 0.0.
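The zero-padded labels and int32 length tensors described above can be built from ragged target lists with a small helper like the one below. `prepare_targets` is a hypothetical name, not part of ESPnet, and the int32 dtype and the (batch, max(act_lens), max(label_lens) + 1, outputDim) acts layout are assumptions based on the usual transducer input checks.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def prepare_targets(targets, act_lens, output_dim):
    """Hypothetical helper: pad ragged targets and build int32 length tensors.

    targets:    list of per-utterance token-id sequences (no blank symbols).
    act_lens:   number of encoder frames per utterance.
    output_dim: size of the joint network's last dimension.
    Returns (labels, label_lens, act_lens, expected_acts_shape).
    """
    seqs = [torch.as_tensor(t, dtype=torch.int32) for t in targets]
    # Zero-pad to (batch, max label length); padding positions are ignored
    # by the loss because label_lens marks the valid prefix of each row.
    labels = pad_sequence(seqs, batch_first=True, padding_value=0)
    label_lens = torch.tensor([len(s) for s in seqs], dtype=torch.int32)
    act_lens = torch.as_tensor(act_lens, dtype=torch.int32)
    # Assumed acts layout: (batch, max frames, max label length + 1, output_dim).
    expected_shape = (
        len(seqs),
        int(act_lens.max()),
        int(label_lens.max()) + 1,
        output_dim,
    )
    return labels, label_lens, act_lens, expected_shape
```

Before calling the loss, move the returned tensors to the GPU (e.g. with `.cuda()`) and allocate acts with the returned shape.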
- Returns: The computed loss value.
- Return type: torch.Tensor
- Raises:
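- ValueError – If the inputs fail validation, e.g., labels or length tensors that are not int32, a batch-size mismatch, acts.shape[1] != max(act_lens), acts.shape[2] != max(label_lens) + 1, or a negative clamp.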
- NotImplementedError – If the function is called with non-CUDA tensors.
Examples
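>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank import (
...     multiblank_rnnt_loss)
>>> B, T, U, V = 4, 50, 10, 32  # batch, frames, max label length, vocab size
>>> durations = [2, 4]
>>> blank = V + len(durations)  # assumed layout: big blanks use blank-1, blank-2, ...
>>> acts = torch.randn(B, T, U + 1, blank + 1, device="cuda")  # raw joint logits
>>> labels = torch.randint(0, V, (B, U), dtype=torch.int32, device="cuda")
>>> act_lens = torch.full((B,), T, dtype=torch.int32, device="cuda")  # max == T
>>> label_lens = torch.full((B,), U, dtype=torch.int32, device="cuda")  # max == U
>>> loss = multiblank_rnnt_loss(acts, labels, act_lens, label_lens, blank=blank,
...                             big_blank_durations=durations, reduction='mean')
>>> print(loss)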