espnet2.asr.state_spaces.pool.downsample
espnet2.asr.state_spaces.pool.downsample(x, stride=1, expand=1, transposed=False)
Downsample a sequence tensor along its length and optionally expand its feature dimension.
This function shortens the sequence length by a factor of stride and enlarges the feature dimension by a factor of expand. With the defaults (stride=1, expand=1), the output has the same shape as the input. Setting transposed=True applies the same operation to inputs whose sequence and feature axes are swapped.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (B, L, D), where B is the batch size, L is the sequence length, and D is the feature dimension. When transposed is True, the layout is (B, D, L) instead.
- stride (int, optional) – The downsampling factor for the sequence length. Default is 1 (no downsampling).
- expand (int, optional) – The factor by which to expand the feature dimension. Default is 1 (no expansion).
- transposed (bool, optional) – If True, treats the input as (B, D, L), so the sequence axis is last. Default is False.
- Returns: The tensor with its sequence length reduced by stride and its feature dimension expanded by expand.
- Return type: torch.Tensor
- Raises: AssertionError – If stride is greater than 1 and the input tensor has more than 3 dimensions.
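The shape arithmetic implied by these parameters can be sketched in plain Python. The helper `downsampled_shape` below is hypothetical and not part of ESPnet. It assumes that downsampling keeps every stride-th step starting at index 0 (leaving ceil(L / stride) steps) and that expansion multiplies the feature dimension by expand. For simplicity it only accepts 3-D shapes.

```python
import math


def downsampled_shape(shape, stride=1, expand=1, transposed=False):
    """Predict the output shape for a 3-D input (hypothetical helper)."""
    if len(shape) != 3:
        raise ValueError("this sketch only handles 3-D shapes")
    batch = shape[0]
    # Non-transposed layout is (B, L, D); transposed layout is assumed to be (B, D, L).
    if transposed:
        dim, length = shape[1], shape[2]
    else:
        length, dim = shape[1], shape[2]
    # Assumption: every stride-th step is kept, starting at index 0.
    new_length = math.ceil(length / stride)
    # Assumption: the feature dimension grows by an integer factor.
    new_dim = dim * expand
    if transposed:
        return (batch, new_dim, new_length)
    return (batch, new_length, new_dim)


print(downsampled_shape((2, 8, 4), stride=2, expand=2))         # (2, 4, 8)
print(downsampled_shape((2, 8, 4), expand=2, transposed=True))  # (2, 16, 4)
```

Both printed shapes agree with the shapes reported in the examples for this function.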
Examples
>>> import torch
>>> from espnet2.asr.state_spaces.pool import downsample
>>> x = torch.randn(2, 8, 4)  # Batch of 2, sequence length 8, features 4
>>> downsampled = downsample(x, stride=2, expand=2)
>>> downsampled.shape  # Sequence length reduced to 4, features expanded to 8
torch.Size([2, 4, 8])
>>> expanded = downsample(x, stride=1, expand=2, transposed=True)
>>> expanded.shape  # x read as (B, D, L): features expanded from 8 to 16, length stays 4
torch.Size([2, 16, 4])
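For readers who want to see the behaviour without installing ESPnet, the following is a minimal sketch in plain PyTorch. The function `downsample_sketch` is an illustrative stand-in, not the ESPnet source. It assumes that downsampling is done by strided slicing and that expansion repeats each feature expand times, which is consistent with the shapes shown above but is not confirmed by this page.

```python
import torch


def downsample_sketch(x, stride=1, expand=1, transposed=False):
    """Illustrative re-implementation under the stated assumptions."""
    if stride > 1:
        assert x.ndim == 3, "strided downsampling expects a 3-D tensor"
        # Keep every stride-th step along the sequence axis
        # (last axis when transposed, middle axis otherwise).
        x = x[..., ::stride] if transposed else x[:, ::stride, :]
    if expand > 1:
        # Repeat each feature `expand` times along the feature axis
        # (axis 1 when transposed, last axis otherwise).
        x = torch.repeat_interleave(x, expand, dim=1 if transposed else -1)
    return x


x = torch.arange(2 * 8 * 4, dtype=torch.float32).reshape(2, 8, 4)
y = downsample_sketch(x, stride=2, expand=2)
print(y.shape)  # torch.Size([2, 4, 8])
```

In this sketch, output features 0 and 1 are both copies of input feature 0 at the retained time steps. Whether the library orders the repeated features the same way is an assumption.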