espnet2.asr.transducer.rnnt_multi_blank.rnnt.multiblank_rnnt_loss_gpu
espnet2.asr.transducer.rnnt_multi_blank.rnnt.multiblank_rnnt_loss_gpu(acts: Tensor, labels: Tensor, input_lengths: Tensor, label_lengths: Tensor, costs: Tensor, grads: Tensor, blank_label: int, big_blank_durations: list, fastemit_lambda: float, clamp: float, num_threads: int, sigma: float)
Wrapper function for the GPU multi-blank RNNT loss.
This function computes the RNNT loss for models that utilize a multi-blank approach, as described in the paper: https://arxiv.org/pdf/2211.03541.pdf. It is a CUDA implementation ported from the Warp Transducer framework.
- Parameters:
- acts – Activation tensor of shape [B, T, U, V + num_big_blanks + 1].
- labels – Ground truth labels of shape [B, U].
- input_lengths – Lengths of the acoustic sequence as a vector of ints [B].
- label_lengths – Lengths of the target sequence as a vector of ints [B].
- costs – Zero vector of length [B] in which costs will be set.
- grads – Zero tensor of shape [B, T, U, V + num_big_blanks + 1] where the gradient will be set.
- blank_label – Index of the standard blank token in the vocabulary.
- big_blank_durations – A list of supported durations for big blank symbols in the model, e.g. [2, 4, 8]. This should not include 1 for the standard blank.
- fastemit_lambda – Float scaling factor for FastEmit regularization. Refer to the FastEmit paper for more details.
- clamp – Float value. When set to a value >= 0.0, the gradient is clamped to [-clamp, clamp].
- num_threads – Number of threads for OpenMP.
- sigma – Logit-undernormalization weight used in the multi-blank model. Refer to the multi-blank paper for detailed explanations.
- Returns: True if the computation was successful.
- Return type: bool
- Raises:
- RuntimeError – If there is an issue with memory allocation or if the RNNT status is not successful during calculations.
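As a rough illustration of the shape conventions in the parameter list, the following sketch computes the size of the last dimension of acts and grads from a vocabulary size and a list of big-blank durations. The helper name `multiblank_buffer_shapes` and its arguments are hypothetical and not part of ESPnet. It only does shape bookkeeping, and it rejects any duration below 2 on the assumption that such values are never valid, since the standard blank already covers duration 1.

```python
from typing import Dict, Sequence, Tuple


def multiblank_buffer_shapes(
    batch: int,
    max_t: int,
    max_u: int,
    vocab_size: int,
    big_blank_durations: Sequence[int],
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: expected tensor shapes for the loss call."""
    # The standard blank has duration 1 and must not be listed (assumption:
    # anything below 2 is treated as invalid here).
    if any(d <= 1 for d in big_blank_durations):
        raise ValueError("big_blank_durations must only contain values > 1")
    # V real tokens + one output per big blank + the standard blank.
    out_dim = vocab_size + len(big_blank_durations) + 1
    return {
        "acts": (batch, max_t, max_u, out_dim),
        "grads": (batch, max_t, max_u, out_dim),
        "labels": (batch, max_u),
        "input_lengths": (batch,),
        "label_lengths": (batch,),
        "costs": (batch,),
    }


shapes = multiblank_buffer_shapes(
    8, 50, 20, vocab_size=6, big_blank_durations=[2, 4, 8]
)
print(shapes["acts"])  # (8, 50, 20, 10)
```

With V = 6 and three big blanks, the last dimension is 6 + 3 + 1 = 10, which matches the tensors used in the example on this page.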
Examples
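>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.rnnt import multiblank_rnnt_loss_gpu
>>> device = torch.device("cuda")  # the GPU wrapper needs CUDA tensors
>>> big_blank_durations = [2, 4, 8]
>>> # Last dimension: V + num_big_blanks + 1 = 6 + 3 + 1 = 10
>>> acts = torch.randn(8, 50, 20, 10, device=device)  # Example activation tensor
>>> # Labels drawn from the 6 real tokens only; int32 dtype is an assumption
>>> labels = torch.randint(0, 6, (8, 20), dtype=torch.int32, device=device)
>>> input_lengths = torch.randint(1, 50, (8,), dtype=torch.int32, device=device)
>>> label_lengths = torch.randint(1, 20, (8,), dtype=torch.int32, device=device)
>>> costs = torch.zeros(8, device=device)
>>> grads = torch.zeros(8, 50, 20, 10, device=device)
>>> result = multiblank_rnnt_loss_gpu(
...     acts, labels, input_lengths, label_lengths,
...     costs, grads,
...     blank_label=9,  # assumed layout: standard blank is the last output index
...     big_blank_durations=big_blank_durations,
...     fastemit_lambda=0.1, clamp=0.1,
...     num_threads=4, sigma=0.5
... )
>>> print(result)  # Should print True if successful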
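Because the wrapper raises RuntimeError on allocation or RNNT status failures and otherwise returns True, a caller may want to handle both outcomes in one place. The sketch below takes the loss function as an argument so that it runs without a GPU. `run_loss` and the two stand-in callables are hypothetical and not part of ESPnet. In practice one would pass `multiblank_rnnt_loss_gpu` together with the arguments from the example above.

```python
from typing import Any, Callable


def run_loss(loss_fn: Callable[..., bool], *args: Any, **kwargs: Any) -> bool:
    """Hypothetical helper: call a loss wrapper and report success as a bool."""
    try:
        ok = loss_fn(*args, **kwargs)
    except RuntimeError as err:
        # Documented failure mode: memory allocation or RNNT status errors.
        print(f"RNNT loss failed: {err}")
        return False
    return bool(ok)


# Stand-ins for the real wrapper, used only to exercise both paths.
def fake_success(*args: Any, **kwargs: Any) -> bool:
    return True


def fake_failure(*args: Any, **kwargs: Any) -> bool:
    raise RuntimeError("simulated workspace allocation failure")


print(run_loss(fake_success))  # True
print(run_loss(fake_failure))  # prints the error message, then False
```

Returning False instead of re-raising is only one possible design. A training loop may prefer to let the exception propagate so that the failing batch is not silently skipped.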