espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.certify_inputs
espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank.certify_inputs(log_probs, labels, lengths, label_lengths)
Validate input tensors for RNNT loss computations.
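This function checks the data type, memory contiguity, and number of dimensions of each input tensor, and verifies that the shapes are mutually consistent: lengths and label_lengths must each hold one entry per batch example, the seqLength axis of log_probs must equal max(lengths), and its label axis must equal max(label_lengths) + 1. It raises an exception at the first check that fails.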
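The shape-consistency rules can be sketched in plain Python over shapes and length lists. This is a minimal illustration, not the library's code: `check_shapes` is a hypothetical name, and the sketch assumes the rules are the ones listed above (seqLength equal to max(lengths), label axis equal to max(label_lengths) + 1). Dtype and contiguity checks are left out because they need real tensors.

```python
from typing import Sequence, Tuple


def check_shapes(
    log_probs_shape: Tuple[int, ...],
    labels_shape: Tuple[int, ...],
    lengths: Sequence[int],
    label_lengths: Sequence[int],
) -> None:
    """Hypothetical restatement of the shape rules; raises ValueError on mismatch."""
    # Dimensionality: log_probs is 4-D, labels is 2-D.
    if len(log_probs_shape) != 4:
        raise ValueError("log_probs must be 4D")
    if len(labels_shape) != 2:
        raise ValueError("labels must be 2D")

    batch, T, U, _ = log_probs_shape

    # One length and one label length per example in the batch.
    if len(lengths) != batch:
        raise ValueError("Must have a length per example")
    if len(label_lengths) != batch:
        raise ValueError("Must have a label length per example")

    # The time axis must match the longest network output.
    if T != max(lengths):
        raise ValueError(f"Input length mismatch: T={T}, max(lengths)={max(lengths)}")

    # The label axis has one extra position for the state before any label.
    if U != max(label_lengths) + 1:
        raise ValueError(
            f"Output length mismatch: U={U}, expected {max(label_lengths) + 1}"
        )
```

For instance, `check_shapes((2, 10, 4, 20), (2, 3), [10, 8], [3, 2])` passes, while a label axis of 5 for the same label lengths raises `ValueError`.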
- Parameters:
- log_probs (torch.Tensor) – Tensor of shape (batch, seqLength, labelLength + 1, outputDim) containing the log probabilities output by the network, where seqLength is max(lengths) and labelLength is max(label_lengths).
- labels (torch.Tensor) – 2D tensor of shape (batch, labelLength) containing the target labels for each example in the batch, zero-padded.
- lengths (torch.Tensor) – 1D tensor of shape (batch) containing the number of valid time frames in each example's network output.
- label_lengths (torch.Tensor) – 1D tensor of shape (batch) containing the number of labels in each example's target sequence.
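To make the relationship between the four parameters concrete, the sketch below zero-pads variable-length label sequences, derives label_lengths, and computes the log_probs shape those lengths imply. Both `pad_labels` and `expected_log_probs_shape` are hypothetical helpers, and the + 1 on the label axis assumes the usual RNN-T lattice with one extra position for the state before any label is emitted.

```python
from typing import List, Sequence, Tuple


def pad_labels(sequences: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: zero-pad label sequences to a common length."""
    label_lengths = [len(s) for s in sequences]
    max_len = max(label_lengths)
    # Padding values are ignored downstream because label_lengths marks the valid part.
    padded = [s + [0] * (max_len - len(s)) for s in sequences]
    return padded, label_lengths


def expected_log_probs_shape(
    lengths: Sequence[int], label_lengths: Sequence[int], output_dim: int
) -> Tuple[int, int, int, int]:
    """Hypothetical helper: (batch, seqLength, labelLength + 1, outputDim)."""
    return (len(lengths), max(lengths), max(label_lengths) + 1, output_dim)


padded, label_lengths = pad_labels([[1, 2, 3], [1, 2]])
print(padded)                                               # [[1, 2, 3], [1, 2, 0]]
print(label_lengths)                                        # [3, 2]
print(expected_log_probs_shape([10, 8], label_lengths, 20))  # (2, 10, 4, 20)
```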
- Raises:
- TypeError – If any of the input tensors have incorrect data types.
- ValueError – If any input tensor is not contiguous in memory, has the wrong number of dimensions, or has a batch size or maximum length that is inconsistent with the other inputs.
NOTE
The log_probs tensor is expected to be of type float32, while labels, lengths, and label_lengths should be of type int32.
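Because `torch.tensor` infers int64 for Python integers, integer inputs built without an explicit dtype will not match the int32 requirement. The following is a minimal sketch of a hypothetical helper, `prepare_rnnt_inputs`, that casts the tensors to the dtypes stated in the note and makes them contiguous before validation; it is not part of the ESPnet API.

```python
import torch


def prepare_rnnt_inputs(log_probs, labels, lengths, label_lengths):
    """Hypothetical helper: cast to the expected dtypes and make tensors contiguous."""
    # float32 log-probabilities; .contiguous() fixes e.g. transposed views
    log_probs = log_probs.to(torch.float32).contiguous()
    # int32 integer tensors (torch.tensor([...]) defaults to int64 for ints)
    labels = labels.to(torch.int32).contiguous()
    lengths = lengths.to(torch.int32).contiguous()
    label_lengths = label_lengths.to(torch.int32).contiguous()
    return log_probs, labels, lengths, label_lengths


# A float64, non-contiguous (transposed) tensor of shape (2, 10, 4, 20).
lp = torch.randn(2, 4, 10, 20, dtype=torch.float64).transpose(1, 2)
lp, lab, lens, lab_lens = prepare_rnnt_inputs(
    lp,
    torch.tensor([[1, 2, 3], [1, 2, 0]]),
    torch.tensor([10, 8]),
    torch.tensor([3, 2]),
)
print(lp.dtype, lp.is_contiguous(), lab.dtype)  # torch.float32 True torch.int32
```

Casting with `.to()` allocates a new tensor whenever the dtype changes, so creating the tensors with the right dtype up front avoids the extra copy.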
Examples
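>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank import certify_inputs
>>> # seqLength = max(lengths) = 10, label axis = max(label_lengths) + 1 = 4, outputDim = 20
>>> log_probs = torch.randn(2, 10, 4, 20).log_softmax(dim=-1)
>>> labels = torch.tensor([[1, 2, 3], [1, 2, 0]], dtype=torch.int32)  # zero-padded
>>> lengths = torch.tensor([10, 8], dtype=torch.int32)
>>> label_lengths = torch.tensor([3, 2], dtype=torch.int32)
>>> certify_inputs(log_probs, labels, lengths, label_lengths)  # no exception: inputs are valid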