espnet2.legacy.nets.pytorch_backend.nets_utils.mask_by_length
espnet2.legacy.nets.pytorch_backend.nets_utils.mask_by_length(xs, lengths, fill=0)
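Mask a batched tensor according to per-sample lengths. For each batch entry `i`, the first `lengths[i]` positions along the second (time) dimension are kept and all later positions are replaced with `fill`.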
- Parameters:
  - xs (Tensor) – Batch of input tensors (B, *).
  - lengths (LongTensor or List) – Batch of lengths (B,).
  - fill (int or float) – Value used to fill the masked positions.
- Returns: Batch of masked input tensors (B, *).
- Return type: Tensor
Examples
>>> import torch
>>> from espnet2.legacy.nets.pytorch_backend.nets_utils import mask_by_length
>>> x = torch.arange(5).repeat(3, 1) + 1
>>> x
tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]])
>>> lengths = [5, 3, 2]
>>> mask_by_length(x, lengths)
tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 0, 0],
        [1, 2, 0, 0, 0]])
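To make the masking rule concrete without depending on PyTorch, here is a minimal sketch on nested Python lists, assuming 2-D input of shape (B, T). The helper `mask_rows_by_length` is hypothetical and is not part of ESPnet. It only mirrors the behavior shown in the example above: row `i` keeps its first `lengths[i]` entries, and every later position is replaced with `fill`.

```python
from typing import Any, List, Sequence


def mask_rows_by_length(
    rows: Sequence[Sequence[Any]], lengths: Sequence[int], fill: Any = 0
) -> List[List[Any]]:
    """Hypothetical pure-Python illustration of length masking on (B, T) data.

    Row i keeps positions t < lengths[i]; positions t >= lengths[i]
    (the padding region) are replaced with `fill`.
    """
    if len(rows) != len(lengths):
        raise ValueError("rows and lengths must have the same batch size")
    masked = []
    for row, n in zip(rows, lengths):
        # Build a new row so the input is never modified in place.
        masked.append([value if t < n else fill for t, value in enumerate(row)])
    return masked


# Same data as the doctest: three identical rows with lengths 5, 3 and 2.
x = [[1, 2, 3, 4, 5]] * 3
print(mask_rows_by_length(x, [5, 3, 2]))
# [[1, 2, 3, 4, 5], [1, 2, 3, 0, 0], [1, 2, 0, 0, 0]]
```

A per-row loop is the simplest statement of the rule. In PyTorch code, the same padding mask is commonly built in one vectorized step by comparing `torch.arange(T)` against the lengths and passing the result to `Tensor.masked_fill`.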