espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.rnnt_loss
espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = 0.0)
RNN Transducer Loss (functional form).
Computes the RNN Transducer (RNN-T) loss for a batch of network outputs and their target label sequences. The reduction parameter controls how the per-example losses are combined.
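For intuition about what is being computed, the NumPy sketch below evaluates the transducer loss for a single example. It computes the negative log of the total probability of all alignments of the target through the T × (U+1) lattice, using the standard forward (alpha) recursion. This is a hypothetical reference: `rnnt_nll_reference` and `log_softmax` are illustrative names, not ESPnet's implementation. It also ignores batching, padding, reduction, FastEmit, and gradient clamping.

```python
from typing import Sequence

import numpy as np


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def rnnt_nll_reference(logits: np.ndarray, targets: Sequence[int], blank: int = 0) -> float:
    """Hypothetical reference: RNN-T negative log-likelihood of one example.

    logits:  unnormalized outputs of shape (T, U + 1, V), where U = len(targets).
    targets: the U label ids of this example (no blanks, no padding).
    """
    lp = log_softmax(logits, axis=-1)
    T, u_positions, _ = lp.shape
    U = len(targets)
    if u_positions != U + 1:
        raise ValueError("second axis must have len(targets) + 1 positions")

    # alpha[t, u] = log-probability of having emitted the first u labels
    # and being at frame t.
    alpha = np.full((T, U + 1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Reach (t, u) by emitting blank at (t - 1, u): advance one frame.
            from_blank = alpha[t - 1, u] + lp[t - 1, u, blank] if t > 0 else -np.inf
            # Reach (t, u) by emitting targets[u - 1] at (t, u - 1): advance one label.
            from_label = alpha[t, u - 1] + lp[t, u - 1, targets[u - 1]] if u > 0 else -np.inf
            alpha[t, u] = np.logaddexp(from_blank, from_label)
    # Every complete alignment ends with a blank emitted at (T - 1, U).
    return float(-(alpha[T - 1, U] + lp[T - 1, U, blank]))
```

Each lattice point is visited once, so the recursion costs O(T · U) log-sum-exp operations per example.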
- Parameters:
- acts (Tensor) – Output of the network (unnormalized logits), of shape (batch x seqLength x labelLength x outputDim), where labelLength is the maximum target length plus one.
- labels (Tensor) – A 2-dimensional tensor containing the targets of the batch, zero-padded to the length of the longest target.
- act_lens (Tensor) – A 1-dimensional tensor of size (batch) containing the length of each output sequence from the network.
- label_lens (Tensor) – A 1-dimensional tensor of size (batch) containing the target length of each example.
- blank (int , optional) – The blank label. Default is 0.
- reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’, ‘mean’, or ‘sum’. ‘none’: no reduction is applied; ‘mean’: the output losses are divided by the target lengths and then averaged over the batch; ‘sum’: the output losses are summed over the batch. Default is ‘mean’.
- fastemit_lambda (float, optional) – Scaling factor for FastEmit regularization. Default is 0.0 (no FastEmit regularization).
- clamp (float, optional) – If positive, the gradient is clamped to [-clamp, clamp]. Default is 0.0, which disables clamping.
- Returns: The loss after the specified reduction is applied; with reduction=‘none’, one loss value per example.
- Return type: Tensor
- Raises: ValueError – If clamp is negative, or if the input tensors do not have the expected number of dimensions or mutually consistent shapes.
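The shape conditions behind the ValueError case can be checked before calling the loss. The sketch below is a hypothetical pre-flight check, not part of ESPnet. It assumes the layout described in the parameter list: the time axis of acts equals max(act_lens), and the label axis equals max(label_lens) + 1. The function name `check_rnnt_inputs` and its error messages are illustrative.

```python
import numpy as np


def check_rnnt_inputs(acts_shape, labels, act_lens, label_lens, clamp=0.0):
    """Hypothetical pre-flight check for rnnt_loss inputs.

    Raises ValueError for a negative clamp or for inputs whose dimensions
    or shapes do not fit together (assumed layout: acts is (B, T, U + 1, V)).
    """
    if clamp < 0:
        raise ValueError("clamp must be 0.0 or a positive float")
    labels = np.asarray(labels)
    act_lens = np.asarray(act_lens)
    label_lens = np.asarray(label_lens)
    if len(acts_shape) != 4:
        raise ValueError("acts must be 4-D: (batch, T, U + 1, V)")
    if labels.ndim != 2 or act_lens.ndim != 1 or label_lens.ndim != 1:
        raise ValueError("labels must be 2-D; act_lens and label_lens must be 1-D")
    batch, t_max, u_plus_one, _ = acts_shape
    if not (labels.shape[0] == act_lens.shape[0] == label_lens.shape[0] == batch):
        raise ValueError("every input needs exactly one entry per batch element")
    # Assumed: the time axis spans exactly the longest output sequence.
    if t_max != act_lens.max():
        raise ValueError(f"acts has T={t_max}, expected max(act_lens)={act_lens.max()}")
    # Assumed: the label axis has one more position than the longest target.
    if u_plus_one != label_lens.max() + 1:
        raise ValueError(
            f"acts has labelLength={u_plus_one}, "
            f"expected max(label_lens) + 1={label_lens.max() + 1}"
        )
    if labels.shape[1] < label_lens.max():
        raise ValueError("labels is narrower than the longest target")
```

For example, with label_lens = [8, 6] the check accepts acts of shape (2, 5, 9, 20) and rejects (2, 5, 10, 20).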
Examples
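>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank import rnnt_loss
>>> acts = torch.randn(2, 5, 9, 20)  # logits: (batch, max(act_lens), max(label_lens) + 1, vocab)
>>> labels = torch.randint(1, 20, (2, 8), dtype=torch.int32)  # non-blank target ids
>>> act_lens = torch.tensor([5, 5], dtype=torch.int32)  # lengths of output sequences
>>> label_lens = torch.tensor([8, 6], dtype=torch.int32)  # lengths of label sequences
>>> loss = rnnt_loss(acts, labels, act_lens, label_lens)
>>> print(loss)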
NOTE
The input tensors must be contiguous and have the expected dtypes: floating-point acts, and int32 labels, act_lens, and label_lens. Where necessary, convert them with .to(torch.int32) and .contiguous() before calling this function.
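Slicing or padding a batch often produces tensors that violate these conditions. In PyTorch the usual fix is `tensor.to(torch.int32).contiguous()`. The NumPy sketch below shows the same two conditions, dtype and memory layout, so that it runs without PyTorch. `as_int32_contiguous` is a hypothetical helper, and treating int32 as the required integer dtype is an assumption.

```python
import numpy as np


def as_int32_contiguous(x) -> np.ndarray:
    """Hypothetical helper: return `x` as a C-contiguous int32 array."""
    # np.ascontiguousarray copies only when the layout or dtype must change.
    return np.ascontiguousarray(x, dtype=np.int32)


# A strided slice is a view into the original buffer and is not C-contiguous.
label_matrix = np.arange(16, dtype=np.int64).reshape(2, 8)
labels_view = label_matrix[:, ::2]
labels_fixed = as_int32_contiguous(labels_view)
print(labels_view.flags["C_CONTIGUOUS"], labels_fixed.flags["C_CONTIGUOUS"], labels_fixed.dtype)
```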