espnet2.asr_transducer.utils.make_source_mask
espnet2.asr_transducer.utils.make_source_mask(lengths: Tensor) → Tensor
Create source mask for given lengths.
This function builds a padding mask for a batch of sequences from their lengths. The mask is a boolean tensor of shape (B, max_len), where max_len is the largest value in lengths. A position is True when it lies at or beyond the length of its sequence (i.e., it is padding) and False when it is a valid frame. This is the same convention as the padding-mask helper in the referenced icefall utilities.
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
- Parameters: lengths – Sequence lengths. (B,)
- Returns: Mask for the sequence lengths, True at padded positions. (B, max_len)
- Return type: torch.Tensor
Examples
>>> import torch
>>> from espnet2.asr_transducer.utils import make_source_mask
>>> lengths = torch.tensor([3, 5, 2])
>>> mask = make_source_mask(lengths)
>>> print(mask)
tensor([[False, False, False,  True,  True],
        [False, False, False, False, False],
        [False, False,  True,  True,  True]])
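A length mask of this kind can be built with a single broadcast comparison between a row of frame indices and a column of sequence lengths. The following is a minimal sketch under stated assumptions. `sketch_source_mask` is a hypothetical name, not part of ESPnet, and it is not guaranteed to match the library's implementation line for line. It uses the convention that True marks padded positions. The sketch also shows how to flip the mask to mark valid frames, how to recover the lengths from it, and how to zero out padded frames of an illustrative (B, T, D) `features` tensor (random data, assumed shape).

```python
import torch


def sketch_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: returns a (B, max_len) bool mask, True at padding.

    Assumption: illustrates the broadcast-comparison technique; this is
    not the ESPnet source code.
    """
    # The longest sequence in the batch sets the mask width.
    max_len = int(lengths.max())
    # Frame indices 0..max_len-1 as a row vector of shape (1, max_len),
    # created on the same device as lengths.
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    # Broadcast (1, max_len) against (B, 1): any index at or beyond a
    # sequence's length is padding, so it becomes True.
    return positions >= lengths.unsqueeze(1)


lengths = torch.tensor([3, 5, 2])
pad_mask = sketch_source_mask(lengths)

# Inverting the mask marks the real (non-padded) frames instead.
valid_mask = ~pad_mask

# Counting valid frames per row recovers the original lengths.
recovered = valid_mask.sum(dim=1)

# Illustrative use: zero out padded frames of a (B, T, D) feature tensor.
# masked_fill writes 0.0 wherever the (B, T, 1) mask broadcasts to True.
features = torch.randn(3, int(lengths.max()), 4)
features = features.masked_fill(pad_mask.unsqueeze(-1), 0.0)

print(pad_mask)
print(recovered)  # tensor([3, 5, 2])
```

With True at padded positions, the mask can be passed directly to `masked_fill`, which writes wherever the mask is True. Code that needs a "valid frame" mask can take `~mask`.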