espnet2.asr_transducer.utils.get_transducer_task_io
espnet2.asr_transducer.utils.get_transducer_task_io(labels: Tensor, encoder_out_lens: Tensor, ignore_id: int = -1, blank_id: int = 0) → Tuple[Tensor, Tensor, Tensor, Tensor]
Get Transducer loss I/O.
This function prepares the input and target sequences required for calculating the Transducer loss. It removes the padding symbols from the label sequences, prepends the blank symbol to each sequence to build the decoder inputs, and converts the encoder output lengths into a length tensor.
Parameters:
- labels – Label ID sequences. Shape: (B, L) where B is the batch size and L is the maximum label length.
- encoder_out_lens – Encoder output lengths. Shape: (B,) indicating the length of the encoder output for each sequence.
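- ignore_id – Padding symbol ID. Positions holding this value are removed from labels before the outputs are built. Default is -1.
- blank_id – Blank symbol ID. It is prepended to each decoder input sequence and is also used as the padding value in decoder_in and target. Default is 0.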
Returns:
- decoder_in – Decoder inputs. Shape: (B, U + 1), where U is the maximum number of valid (non-padding) tokens in the batch. Each sequence starts with blank_id and is right-padded with blank_id.
- target – Target label ID sequences (int32). Shape: (B, U), with the padding symbols removed and each sequence right-padded with blank_id.
- t_len – Encoder output lengths (int32). Shape: (B,), one entry per input sequence.
- u_len – Target label lengths (int32). Shape: (B,), the number of valid tokens in each target sequence.
Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
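The transformation described above can be illustrated on plain Python lists. The helper transducer_task_io_sketch below is hypothetical and is not part of ESPnet; it is a minimal sketch, assuming the behaviour documented here (padding removed, blank prepended, blank_id used as the padding value), and it uses lists instead of PyTorch tensors so the logic is easy to follow.

```python
from typing import List, Tuple


def transducer_task_io_sketch(
    labels: List[List[int]],
    encoder_out_lens: List[int],
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[List[List[int]], List[List[int]], List[int], List[int]]:
    """Plain-Python illustration of the I/O preparation (hypothetical helper)."""
    # Drop padding positions (ignore_id) from every label sequence.
    unpadded = [[tok for tok in seq if tok != ignore_id] for seq in labels]
    max_len = max((len(seq) for seq in unpadded), default=0)

    # Decoder input: blank prepended, then right-padded with blank_id to U + 1.
    decoder_in = [
        [blank_id] + seq + [blank_id] * (max_len - len(seq)) for seq in unpadded
    ]
    # Target: the unpadded labels, right-padded with blank_id to U.
    target = [seq + [blank_id] * (max_len - len(seq)) for seq in unpadded]

    # Lengths: encoder frames per utterance and valid labels per utterance.
    t_len = [int(n) for n in encoder_out_lens]
    u_len = [len(seq) for seq in unpadded]
    return decoder_in, target, t_len, u_len


decoder_in, target, t_len, u_len = transducer_task_io_sketch(
    [[1, 2, 3, -1], [1, -1, -1, -1]], [4, 1]
)
print(decoder_in)    # [[0, 1, 2, 3], [0, 1, 0, 0]]
print(target)        # [[1, 2, 3], [1, 0, 0]]
print(t_len, u_len)  # [4, 1] [3, 1]
```

Because blank_id doubles as the padding value, padded positions in target are only distinguishable from real blanks through u_len, which is why the lengths are returned alongside the padded sequences.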
Examples
>>> import torch
>>> from espnet2.asr_transducer.utils import get_transducer_task_io
>>> labels = torch.tensor([[1, 2, 3, -1], [1, -1, -1, -1]])
>>> encoder_out_lens = torch.tensor([4, 1])
>>> decoder_in, target, t_len, u_len = get_transducer_task_io(
...     labels, encoder_out_lens
... )
>>> print(decoder_in)
tensor([[0, 1, 2, 3],
        [0, 1, 0, 0]])
>>> print(target)
tensor([[1, 2, 3],
        [1, 0, 0]], dtype=torch.int32)
>>> print(t_len)
tensor([4, 1], dtype=torch.int32)
>>> print(u_len)
tensor([3, 1], dtype=torch.int32)
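A Transducer loss is evaluated over a T × (U + 1) alignment grid for each utterance: T encoder frames by U + 1 decoder states (the leading blank plus one state per label). The returned t_len and u_len therefore delimit the valid region of each grid, and the width of decoder_in (U_max + 1) matches the label axis of a padded joint-network output. The helpers lattice_shapes and joint_output_shape below are hypothetical and not part of ESPnet; the (B, T_max, U_max + 1, V) layout is an assumption about a typical joint network, used here only for illustration.

```python
from typing import List, Tuple


def lattice_shapes(t_len: List[int], u_len: List[int]) -> List[Tuple[int, int]]:
    """Per-utterance (T, U + 1) alignment grid sizes (hypothetical helper)."""
    # Each utterance has T encoder frames and U + 1 decoder states
    # (the leading blank plus one state per emitted label).
    return [(t, u + 1) for t, u in zip(t_len, u_len)]


def joint_output_shape(
    t_len: List[int], u_len: List[int], vocab_size: int
) -> Tuple[int, int, int, int]:
    """Assumed padded joint-network output shape (B, T_max, U_max + 1, V)."""
    return (len(t_len), max(t_len), max(u_len) + 1, vocab_size)


# Length values returned in the example above.
t_len, u_len = [4, 1], [3, 1]
print(lattice_shapes(t_len, u_len))                     # [(4, 4), (1, 2)]
print(joint_output_shape(t_len, u_len, vocab_size=10))  # (2, 4, 4, 10)
```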
NOTE
The function assumes that the input tensors are on the same device (CPU or GPU); all returned tensors are placed on the device of labels.
Raises: ValueError – If the input tensors have incompatible shapes or if the label sequences contain invalid IDs.