espnet2.asr.state_spaces.pool.DownSpectralPool
class espnet2.asr.state_spaces.pool.DownSpectralPool(d_input, stride=1, expand=1, transposed=True)
Bases: SequenceModule
Downsampling using spectral pooling.
This class downsamples sequences with spectral methods. It applies an inverse Fast Fourier Transform (iFFT) along the sequence dimensions, keeps the lowest-frequency components so that each sequence dimension shrinks by a factor of stride, and then applies a second iFFT to obtain the downsampled output.
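The idea behind spectral pooling can be illustrated without PyTorch. The following is a minimal, dependency-free sketch of textbook 1-D spectral pooling, not the ESPnet implementation: it assumes a forward DFT followed by an inverse DFT, whereas the description above names an iFFT for both steps, so normalization and sample ordering may differ from the library. The helper names `dft`, `kept_bins`, and `spectral_pool_1d` are hypothetical.

```python
import cmath


def dft(values, inverse=False):
    """Naive O(n^2) discrete Fourier transform (inverse is scaled by 1/n)."""
    n = len(values)
    sign = 1 if inverse else -1
    out = []
    for k in range(n):
        acc = sum(
            values[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n)
            for t in range(n)
        )
        out.append(acc / n if inverse else acc)
    return out


def kept_bins(length, stride):
    """Indices of the lowest-frequency bins that survive pooling."""
    assert length % stride == 0, "input length must be divisible by stride"
    new_len = length // stride
    n_neg = new_len // 2       # negative-frequency bins (end of the spectrum)
    n_pos = new_len - n_neg    # DC plus positive-frequency bins (start)
    return list(range(n_pos)) + list(range(length - n_neg, length))


def spectral_pool_1d(signal, stride):
    """Downsample a real 1-D signal by truncating its spectrum."""
    spectrum = dft(signal)
    # Dividing by stride keeps amplitudes unchanged after the shorter inverse DFT.
    cropped = [spectrum[k] / stride for k in kept_bins(len(signal), stride)]
    return [v.real for v in dft(cropped, inverse=True)]
```

For a signal of length 8 and stride 2, `kept_bins` selects bins 0, 1, 6, and 7, and the pooled signal has length 4. A signal whose content lies entirely in the kept bins is reproduced exactly at every second sample.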
d_input
The dimensionality of the input features.
- Type: int
stride
The factor by which to downsample the input.
- Type: int
expand
The factor by which to expand the output features.
- Type: int
transposed
Whether to perform the operation in transposed mode.
- Type: bool
Parameters:
- d_input (int) – The input feature dimension.
- stride (int , optional) – The downsampling factor. Defaults to 1.
- expand (int , optional) – The expansion factor for the output. Defaults to 1.
- transposed (bool , optional) – Whether to use transposed operations. Defaults to True.
Returns: The downsampled output tensor.
Return type: Tensor
Raises: AssertionError – If the input length is not divisible by the stride.
########### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.pool import DownSpectralPool
>>> pool = DownSpectralPool(d_input=64, stride=2, expand=1, transposed=False)
>>> input_tensor = torch.randn(10, 6, 64)  # (B, L, D); L must be divisible by stride
>>> output_tensor = pool(input_tensor)
>>> output_tensor.shape
torch.Size([10, 3, 64])
####### NOTE Spectral pooling discards high-frequency components, so it is best suited to signals whose information is concentrated at low frequencies. It may not be suitable for all types of data.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
property d_output
The output feature dimension.
It equals the input dimension (d_input) multiplied by the expand factor, which is the size of the feature axis of the tensor produced by the pooling layer. Use it to size any layer that consumes this module's output.
- Returns: The computed output dimension.
- Return type: int
########### Examples
>>> downsample_layer = DownSpectralPool(d_input=128, expand=2)
>>> downsample_layer.d_output
256
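When several pooling stages are stacked, the same arithmetic applies stage by stage: the sequence length is divided by each stride and the feature dimension is multiplied by each expand factor. The sketch below tracks both quantities in plain Python. The function `stacked_pool_shape` is a hypothetical helper for illustration and is not part of ESPnet.

```python
def stacked_pool_shape(length, d_input, stages):
    """Return (length, feature_dim) after a chain of pooling stages.

    stages is an iterable of (stride, expand) pairs, applied in order.
    """
    for stride, expand in stages:
        if length % stride != 0:
            raise ValueError(
                f"length {length} is not divisible by stride {stride}"
            )
        length //= stride    # each stage shortens the sequence
        d_input *= expand    # and widens the features (d_output of the stage)
    return length, d_input
```

For example, two stages with stride=2 and expand=2 turn a sequence of length 64 with 128 features into one of length 16 with 512 features.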
forward(x)
Downsamples the input sequence using spectral pooling.
The input is transformed to the frequency domain, truncated according to stride, and transformed back. The transposed flag controls the expected tensor layout: (B, D, L, …) when True and (B, L, …, D) when False. It does not switch the operation to upsampling.
Parameters:
- x (torch.Tensor) – Input tensor of shape (B, D, L, …) if transposed is True, otherwise (B, L, …, D). Each sequence dimension must be divisible by stride.
Returns: The downsampled output tensor. With transposed=True its shape is (B, D’, …), where D’ = d_input * expand and each sequence dimension is reduced by a factor of stride.
Return type: torch.Tensor
Raises: AssertionError – If the input length is not divisible by stride.
########### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.pool import DownSpectralPool
>>> down_pool = DownSpectralPool(d_input=64, stride=2, expand=1)
>>> x = torch.randn(8, 64, 10)  # (B, D, L), since transposed=True by default
>>> output = down_pool(x)
>>> output.shape
torch.Size([8, 64, 5])
####### NOTE This method requires the input length to be divisible by the stride to ensure valid downsampling.
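If a sequence length is not already a multiple of the stride, one common workaround is to pad it before pooling. The sketch below computes the padded length and right-pads a plain Python list. Both `padded_length` and `pad_right` are hypothetical helpers, and zero-padding is an assumption rather than something this module does for you.

```python
def padded_length(length, stride):
    """Smallest multiple of stride that is greater than or equal to length."""
    if stride < 1:
        raise ValueError("stride must be a positive integer")
    return -(-length // stride) * stride  # ceiling division, then scale back


def pad_right(sequence, stride, fill=0.0):
    """Right-pad a sequence with fill so its length is divisible by stride."""
    target = padded_length(len(sequence), stride)
    return list(sequence) + [fill] * (target - len(sequence))
```

With tensors, the same effect could be obtained with `torch.nn.functional.pad` on the sequence axis. Note that padding changes the spectrum of the signal, so the pooled values near the padded end will reflect the fill value.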
step(x, state, **kwargs)
Processes a single step of input in recurrent (step-wise) mode.
Step-wise processing is supported only when the layer performs no resampling. If stride or expand is greater than 1, this method raises NotImplementedError.
Parameters:
- x (torch.Tensor) – Input for the current step.
- state – Recurrent state carried between steps.
- **kwargs – Additional keyword arguments.
Returns: The input x and the state, passed through unchanged when stride and expand are both 1.
Return type: tuple
Raises: NotImplementedError – If stride or expand is greater than 1.
########### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.pool import DownSpectralPool
>>> down_pool = DownSpectralPool(d_input=128, stride=1, expand=1)
>>> x = torch.randn(10, 128)  # one step: batch size 10, 128 features
>>> y, state = down_pool.step(x, state=None)
>>> print(y.shape)
torch.Size([10, 128])
####### NOTE Step-wise processing requires stride=1 and expand=1. To downsample, call forward on the full sequence instead.