espnet2.asr.transducer.rnnt_multi_blank.rnnt.rnnt_loss_cpu
espnet2.asr.transducer.rnnt_multi_blank.rnnt.rnnt_loss_cpu(acts: Tensor, labels: Tensor, input_lengths: Tensor, label_lengths: Tensor, costs: Tensor, grads: Tensor, blank_label: int, fastemit_lambda: float, clamp: float, num_threads: int)
Calculate the RNNT loss using a CPU implementation.
This function wraps the CPU computation of the RNNT loss for a batch of activations and labels. The implementation is based on HawkAaron's warp-transducer repository.
- Parameters:
- acts – Activation tensor of shape [B, T, U, V+1], where B is the batch size, T is the time dimension, U is the target sequence dimension, and V is the vocabulary size (the extra entry is the blank token).
- labels – Ground truth labels of shape [B, U], where U is the target sequence length.
- input_lengths – A tensor of shape [B] representing the lengths of the acoustic sequences.
- label_lengths – A tensor of shape [B] representing the lengths of the target sequences.
- costs – A tensor of shape [B], initialized to zero, that receives the computed costs.
- grads – A tensor of shape [B, T, U, V+1], initialized to zero, that receives the computed gradients.
- blank_label – An integer indicating the index of the blank token in the vocabulary.
- fastemit_lambda – A float scaling factor for FastEmit regularization, which encourages earlier token emission and can reduce latency in streaming ASR.
- clamp – A float; when set to a value >= 0.0, the gradients are clamped to the range [-clamp, clamp].
- num_threads – An integer specifying the number of threads to use for OpenMP parallelization.
- Returns: True if the computation was successful.
- Return type: bool
- Raises:
- RuntimeError – If the working space memory allocation fails or if the RNNT status indicates an error during computation.
Examples
>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.rnnt import rnnt_loss_cpu
>>> acts = torch.randn(32, 100, 20, 30)  # [B, T, U, V+1] activations
>>> labels = torch.randint(1, 30, (32, 20))  # labels avoid the blank index 0
>>> input_lengths = torch.randint(1, 100, (32,))  # acoustic lengths in [1, 99]
>>> label_lengths = torch.randint(1, 20, (32,))  # target lengths in [1, 19]
>>> costs = torch.zeros(32)  # receives the per-utterance costs
>>> grads = torch.zeros(32, 100, 20, 30)  # receives the gradients, same shape as acts
>>> blank_label = 0
>>> fastemit_lambda = 0.1
>>> clamp = 1.0
>>> num_threads = 4
>>> rnnt_loss_cpu(acts, labels, input_lengths, label_lengths,
...               costs, grads, blank_label, fastemit_lambda,
...               clamp, num_threads)