espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt_kernel.compute_multiblank_alphas_kernel
espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt_kernel.compute_multiblank_alphas_kernel(acts: Tensor, denom: Tensor, sigma: float, alphas: Tensor, llForward: Tensor, xlen: Tensor, ylen: Tensor, mlabels: Tensor, minibatch: int, maxT: int, maxU: int, alphabet_size: int, blank_: int, big_blank_duration: Tensor, num_big_blanks: int)
Compute alpha (forward variable) probabilities for multi-blank transducer loss (https://arxiv.org/pdf/2211.03541).
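This kernel computes the forward variables of a multi-blank transducer, in which big blank tokens consume several frames at once. It applies the paper's logit under-normalization, which offsets the log probabilities by sigma so that alignments using fewer, longer blanks are favored during training.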
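For reference, the forward recursion can be sketched on the CPU for a single utterance. This is a minimal NumPy sketch under stated assumptions, not the ESPnet kernel: the input is already normalized log-probabilities of shape [T, U, V_total], big blank i is assumed to sit at vocabulary index blank - 1 - i, and sigma is assumed to be subtracted from every emission (label, standard blank, or big blank). The function name multiblank_forward_reference is hypothetical.

```python
import numpy as np


def multiblank_forward_reference(log_probs, labels, blank, big_blank_durations, sigma):
    """CPU sketch of the multi-blank forward recursion for ONE utterance.

    log_probs: [T, U, V_total] normalized log-probabilities, U = len(labels) + 1.
    labels: the U - 1 target token ids.
    blank: index of the standard (one-frame) blank.
    big_blank_durations: frames consumed by each big blank; big blank i is
        ASSUMED to sit at vocabulary index blank - 1 - i.
    sigma: under-normalization offset, ASSUMED subtracted from every emission.
    Returns (alphas [T, U], forward log-likelihood).
    """
    T, U, _ = log_probs.shape
    alphas = np.full((T, U), -np.inf)
    alphas[0, 0] = 0.0  # empty prefix at frame 0
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue
            terms = []
            if t > 0:  # standard blank: advance one frame
                terms.append(alphas[t - 1, u] + log_probs[t - 1, u, blank] - sigma)
            for i, d in enumerate(big_blank_durations):  # big blank i: advance d frames
                if t >= d:
                    terms.append(alphas[t - d, u] + log_probs[t - d, u, blank - 1 - i] - sigma)
            if u > 0:  # emit labels[u - 1] without advancing the frame
                terms.append(alphas[t, u - 1] + log_probs[t, u - 1, labels[u - 1]] - sigma)
            alphas[t, u] = np.logaddexp.reduce(terms)

    # Terminate with a final standard blank, or a big blank that ends at frame T.
    final = [alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank] - sigma]
    for i, d in enumerate(big_blank_durations):
        if T >= d:
            final.append(alphas[T - d, U - 1] + log_probs[T - d, U - 1, blank - 1 - i] - sigma)
    return alphas, float(np.logaddexp.reduce(final))


# Toy case: T=2 frames, no labels, blank at index 4, one big blank (duration 2) at index 3.
# Two alignments exist: blank + blank (two emissions) or a single big blank (one emission).
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 1, 5))
log_probs = logits - np.logaddexp.reduce(logits, axis=-1, keepdims=True)
alphas, ll = multiblank_forward_reference(log_probs, [], blank=4, big_blank_durations=[2], sigma=0.3)
```

Visiting cells in row-major (t, u) order guarantees that every predecessor is final before it is read; a GPU version can instead sweep anti-diagonals t + u = n in parallel, since cells on one anti-diagonal never depend on each other.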
- Parameters:
- acts – Tensor of shape [B, T, U, V + 1 + num_big_blanks] flattened. Represents the log probabilities activation tensor.
- denom – Tensor of shape [B, T, U] flattened. Represents the denominator of the log probabilities activation tensor across the entire vocabulary.
- sigma – Hyper-parameter for logit under-normalization technique for training multi-blank transducers.
- alphas – Zero tensor of shape [B, T, U]. Will be updated inside the kernel with the forward variable probabilities.
- llForward – Zero tensor of shape [B]. Receives the per-utterance forward log-likelihood; its negation is the loss value returned by the forward pass.
- xlen – Vector of length B which contains the actual acoustic sequence lengths in the padded activation tensor.
- ylen – Vector of length B which contains the actual target sequence lengths in the padded activation tensor.
- mlabels – Matrix of shape [B, U+1] (the +1 is due to the <SOS> token, usually the RNNT blank). The matrix contains the padded target transcription that must be predicted.
- minibatch – Int representing the batch size.
- maxT – The maximum possible acoustic sequence length. Represents T in the log probabilities tensor.
- maxU – The maximum possible target sequence length. Represents U in the log probabilities tensor.
- alphabet_size – Size of the last dimension of acts, i.e. V + 1 + num_big_blanks (inclusive of the RNNT blank and the big blanks).
- blank_ – Index of the standard RNNT blank token in the vocabulary.
- big_blank_duration – Vector of length num_big_blanks holding the duration, in frames, of each big blank supported by the model.
- num_big_blanks – Number of big blanks of the model.
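The flattened acts and denom buffers can be read back as log-probabilities with plain index arithmetic. The sketch below assumes a row-major [B, T, U, alphabet_size] layout and that denom stores the negative log-sum-exp of acts over the vocabulary, so that acts + denom is a log-softmax; flat_logp is a hypothetical helper name.

```python
import numpy as np


def flat_logp(acts_flat, denom_flat, maxT, maxU, alphabet_size, b, t, u, v):
    """Return log P(v | b, t, u) from flattened acts/denom buffers (sketch)."""
    col = (b * maxT + t) * maxU + u  # flat position of (b, t, u) in the [B, T, U] denom buffer
    return denom_flat[col] + acts_flat[col * alphabet_size + v]


# Shapes from the docstring example: B=2, T=5, U=3, alphabet_size=10.
B, T, U, A = 2, 5, 3, 10
rng = np.random.default_rng(0)
acts = rng.normal(size=(B, T, U, A))
denom = -np.logaddexp.reduce(acts, axis=-1)  # assumed convention: -logsumexp over the vocabulary
log_softmax = acts + denom[..., None]        # reference 4-D log-probabilities

value = flat_logp(acts.ravel(), denom.ravel(), T, U, A, b=1, t=3, u=2, v=7)
print(np.isclose(value, log_softmax[1, 3, 2, 7]))  # True
```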
Updates: The kernel updates the following inputs in place:
- alphas: forward variable scores.
- llForward: log-likelihood of forward variable.
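Because the buffers are padded, the last valid alpha cell of utterance b is (t, u) = (xlen[b] - 1, ylen[b]), not (maxT - 1, maxU - 1). The sketch below, using a hypothetical helper name and placeholder values rather than real kernel output, views the flat alphas buffer as [B, T, U], locates those cells, and turns llForward into per-utterance costs by negation, the usual convention for a negative log-likelihood loss.

```python
import numpy as np


def last_alpha_cells(xlen, ylen):
    """(t, u) of the last valid alpha cell for each padded utterance (sketch)."""
    return [(int(T) - 1, int(U)) for T, U in zip(xlen, ylen)]


B, maxT, maxU = 2, 5, 3
alphas_flat = np.zeros(B * maxT * maxU)      # placeholder for the kernel's alphas buffer
alphas = alphas_flat.reshape(B, maxT, maxU)  # row-major [B, T, U] view, no copy
xlen, ylen = np.array([5, 3]), np.array([2, 1])
print(last_alpha_cells(xlen, ylen))          # [(4, 2), (2, 1)]

ll_forward = np.log(np.array([0.25, 0.5]))   # placeholder log-likelihoods
costs = -ll_forward                          # per-utterance loss: negative log-likelihood
```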
Examples
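>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.gpu_rnnt_kernel import (
...     compute_multiblank_alphas_kernel,
... )
>>> dev = torch.device("cuda")  # all tensor arguments must live on the GPU
>>> minibatch, maxT, maxU = 2, 5, 3
>>> num_big_blanks = 2
>>> alphabet_size = 10  # last dim of acts: 7 labels + 2 big blanks + 1 blank
>>> blank_ = alphabet_size - 1  # assumed layout: blank last, big blank i at blank_ - 1 - i
>>> acts = torch.randn(minibatch, maxT, maxU, alphabet_size, device=dev)
>>> denom = -torch.logsumexp(acts, dim=-1)  # assumed: acts + denom is a log-softmax
>>> sigma = 0.5
>>> alphas = torch.zeros(minibatch * maxT * maxU, device=dev)
>>> llForward = torch.zeros(minibatch, device=dev)
>>> xlen = torch.tensor([5, 5], device=dev)
>>> ylen = torch.tensor([2, 2], device=dev)
>>> mlabels = torch.randint(0, 7, (minibatch, maxU), device=dev)  # non-blank ids only
>>> big_blank_duration = torch.tensor([2, 4], device=dev)  # big blanks span > 1 frame
>>> compute_multiblank_alphas_kernel[minibatch, maxU](  # grid: B blocks; block: maxU threads
...     acts.view(-1), denom.view(-1), sigma, alphas, llForward,
...     xlen, ylen, mlabels, minibatch, maxT, maxU,
...     alphabet_size, blank_, big_blank_duration, num_big_blanks)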
NOTE
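This is a Numba CUDA kernel: it must be launched on GPU tensors with an explicit configuration, kernel[grid, block](...), where the grid has one block per utterance (minibatch blocks) and each block has one thread per target position (maxU threads).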
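The launch shape follows from how the work is split. A minimal sketch, assuming the anti-diagonal schedule of the NeMo-derived RNNT kernels (the batch index comes from the block index, thread u updates cell (n - u, u) at step n, and a block-wide barrier separates steps), checks that every predecessor of a cell, including the cell d frames back for each big-blank duration d, finishes in an earlier step. The helper names are hypothetical; the 1024-thread cap is the CUDA per-block limit on current GPUs.

```python
MAX_THREADS_PER_BLOCK = 1024  # CUDA per-block thread limit (compute capability >= 2.0)


def launch_config(minibatch, maxU):
    """(grid, block) for one block per utterance and one thread per u (assumed)."""
    if maxU > MAX_THREADS_PER_BLOCK:
        raise ValueError(f"maxU={maxU} exceeds {MAX_THREADS_PER_BLOCK} threads per block")
    return minibatch, maxU


def wavefront_schedule(T, U):
    """Cells (t, u) updated at each step n >= 1; cell (0, 0) is the initial state."""
    return [[(n - u, u) for u in range(U) if 0 <= n - u < T] for n in range(1, T + U - 1)]


def schedule_is_valid(T, U, big_blank_durations):
    """True if every cell's predecessors finish in an earlier step and all cells are covered."""
    done = {(0, 0)}
    for cells in wavefront_schedule(T, U):
        for t, u in cells:
            preds = [(t - 1, u)] if t > 0 else []                          # standard blank
            preds += [(t - d, u) for d in big_blank_durations if t >= d]  # big blanks
            if u > 0:
                preds.append((t, u - 1))                                   # label emission
            if not all(p in done for p in preds):
                return False
        done.update(cells)  # barrier: this step's cells become readable
    return len(done) == T * U


print(launch_config(2, 3))              # (2, 3)
print(schedule_is_valid(5, 3, [2, 4]))  # True
```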