espnet2.asr.state_spaces.pool.DownAvgPool
class espnet2.asr.state_spaces.pool.DownAvgPool(d_input, stride=1, expand=1, transposed=True)
Bases: SequenceModule
Downsample input sequences using average pooling.
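This module averages the input over non-overlapping windows of stride steps along the length (sequence) dimension. It can then widen the feature dimension by repeating each feature expand times. Inputs can be given either in transposed layout (B, D, L…), with features before the length dimensions, or in non-transposed layout (B, L…, D).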
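To make the two operations concrete, here is a minimal pure-Python sketch of what the module computes for one sequence in the non-transposed (length, features) layout. It is not the ESPnet implementation. The helper name avg_pool_expand is hypothetical, and the sketch works on nested lists instead of tensors. It assumes that windows are non-overlapping, that trailing steps which do not fill a window are dropped, and that expansion repeats each feature expand times in place ([f0, f0, f1, f1, ...]).

```python
from typing import List

Seq = List[List[float]]  # one sequence: L time steps, each with D features


def avg_pool_expand(seq: Seq, stride: int = 1, expand: int = 1) -> Seq:
    """Hypothetical helper: average-pool along time, then repeat each feature."""
    n_windows = len(seq) // stride  # trailing steps that do not fill a window are dropped
    out: Seq = []
    for w in range(n_windows):
        window = seq[w * stride:(w + 1) * stride]
        # Mean of each feature over the `stride` time steps in this window.
        pooled = [sum(col) / stride for col in zip(*window)]
        # Repeat every feature `expand` times in place: [f0, f0, f1, f1, ...].
        out.append([v for v in pooled for _ in range(expand)])
    return out


x = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0], [7.0, 40.0]]  # L=4, D=2
print(avg_pool_expand(x, stride=2, expand=2))
# [[2.0, 2.0, 15.0, 15.0], [6.0, 6.0, 35.0, 35.0]]
```

With stride=1 and expand=1 the sketch returns the input values unchanged, which matches the module's identity configuration.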
d_input
The number of input features.
- Type: int
stride
The downsampling factor for the length (sequence) dimension.
- Type: int
expand
The repetition factor for the feature dimension.
- Type: int
transposed
If True, inputs are expected in (B, D, L…) layout, with features before the length dimensions; if False, in (B, L…, D) layout.
- Type: bool
Parameters:
- d_input (int) – Number of input features.
- stride (int, optional) – Downsampling factor along the length dimension. Default is 1.
- expand (int, optional) – Feature dimension expansion factor. Default is 1.
- transposed (bool, optional) – If True, inputs are laid out as (B, D, L…); if False, as (B, L…, D). Default is True.
Returns: forward() returns a tuple (output, None), where output is the downsampled (and possibly expanded) tensor.
Return type: tuple
Raises: NotImplementedError – Raised by step() if stride or expand is greater than 1.
#### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.pool import DownAvgPool
>>> down_avg_pool = DownAvgPool(d_input=64, stride=2, expand=2, transposed=False)
>>> input_tensor = torch.randn(10, 20, 64)  # (batch_size, length, features)
>>> output_tensor, _ = down_avg_pool(input_tensor)  # forward returns (output, None)
>>> print(output_tensor.shape)  # length 20 -> 10, features 64 -> 128
torch.Size([10, 10, 128])
#### NOTE
This module expects the input tensor to have at least 3 dimensions.
property d_output
The output feature dimension of the module.
It depends only on the input dimension and the expand factor (stride changes the length, not the feature width): d_output = d_input * expand.
- Returns: The computed output dimension.
- Return type: int
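The formula can be checked against the repetition itself: repeating each entry of a d_input-length feature vector expand times yields exactly d_input * expand entries. This is a minimal sketch in which d_output is a plain function standing in for the property, not the property itself.

```python
def d_output(d_input: int, expand: int = 1) -> int:
    """Stand-in for the DownAvgPool.d_output property: d_input * expand."""
    return d_input * expand


features = list(range(128))                         # a d_input = 128 feature vector
expanded = [f for f in features for _ in range(2)]  # repeat each feature twice
assert len(expanded) == d_output(128, expand=2) == 256
assert d_output(64) == 64                           # expand=1 leaves the width unchanged
```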
#### Examples
>>> down_avg_pool = DownAvgPool(d_input=128, expand=2)
>>> down_avg_pool.d_output
256
forward(x)
Perform the forward pass of the DownAvgPool layer.
This method takes an input tensor x and applies average pooling and optional expansion based on the specified stride and expand parameters.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, D, L…) if transposed is True, or (B, L…, D) if transposed is False, where:
- B is the batch size,
- L… is one or more length (sequence) dimensions,
- D is the number of features (d_input).
- Returns: A tuple (output, None). output has shape (B, D’, L’…) or (B, L’…, D’), depending on transposed, where each length dimension is reduced by a factor of stride and D’ = D * expand.
- Return type: tuple
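The shape rules above can be summarized as a small calculator. This is a sketch under stated assumptions: down_avg_pool_shape is a hypothetical helper, not part of ESPnet. It assumes each length dimension is divisible by stride and floors the result otherwise.

```python
from typing import Sequence, Tuple


def down_avg_pool_shape(
    shape: Sequence[int], stride: int = 1, expand: int = 1, transposed: bool = True
) -> Tuple[int, ...]:
    """Hypothetical helper predicting the shape of the forward() output tensor.

    transposed=True:  (B, D, L...) -> (B, D * expand, L'...)
    transposed=False: (B, L..., D) -> (B, L'..., D * expand)
    with each L' = L // stride.
    """
    if len(shape) < 3:
        raise ValueError("expected at least 3 dimensions")
    if transposed:
        b, d, *lengths = shape
    else:
        b, *lengths, d = shape
    pooled = [n // stride for n in lengths]  # every length dimension is downsampled
    if transposed:
        return (b, d * expand, *pooled)
    return (b, *pooled, d * expand)


print(down_avg_pool_shape((8, 4, 16), stride=2, expand=2, transposed=False))  # (8, 2, 32)
print(down_avg_pool_shape((8, 4, 16), stride=2, expand=2, transposed=True))   # (8, 8, 8)
```

The same input tensor gives different output shapes depending on transposed, so the layout flag must match how the data is actually arranged.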
#### Examples
>>> pool = DownAvgPool(d_input=16, stride=2, expand=2, transposed=False)
>>> input_tensor = torch.randn(8, 4, 16)  # batch of 8, 4 time steps, 16 features
>>> output_tensor, _ = pool(input_tensor)
>>> output_tensor.shape  # length 4 -> 2, features 16 -> 32
torch.Size([8, 2, 32])
#### NOTE
- If self.transposed is False, the input is rearranged from (B, L…, D) to (B, D, L…) before pooling and rearranged back afterwards; if True, it is pooled as given.
- forward() supports stride and expand greater than 1; the NotImplementedError restriction applies only to step().
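The layout handling can be illustrated without tensors. Pooling the (length, features) layout along its first axis gives the same numbers as transposing to (features, length), pooling along the last axis, and transposing back. This is a minimal pure-Python sketch. transpose, pool_last_axis and pool_first_axis are hypothetical helpers that only mirror the idea of the rearrange step.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Swap the two axes: (L, D) <-> (D, L)."""
    return [list(row) for row in zip(*m)]


def pool_last_axis(m: Matrix, stride: int) -> Matrix:
    """Average non-overlapping windows along the last axis (transposed layout)."""
    n = len(m[0]) // stride
    return [[sum(row[i * stride:(i + 1) * stride]) / stride for i in range(n)] for row in m]


def pool_first_axis(m: Matrix, stride: int) -> Matrix:
    """Average non-overlapping windows along the first axis (non-transposed layout)."""
    n = len(m) // stride
    return [
        [sum(col) / stride for col in zip(*m[i * stride:(i + 1) * stride])]
        for i in range(n)
    ]


x_ld = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # (L=4, D=2)
direct = pool_first_axis(x_ld, stride=2)
via_transpose = transpose(pool_last_axis(transpose(x_ld), stride=2))
assert direct == via_transpose == [[2.0, 3.0], [6.0, 7.0]]
```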
step(x, state, **kwargs)
Process a single time step in a recurrent model.
This method is meant for recurrent (step-by-step) inference, where state is carried across time steps. Step-wise pooling and expansion are not implemented, so it only works with stride=1 and expand=1. In that case it returns x and state unchanged.
- Parameters:
- x (torch.Tensor) – Input for a single time step, of shape (…, H), where H is the feature dimension.
- state (list) – The current recurrent state. It is returned as-is.
- **kwargs – Additional keyword arguments (unused).
- Returns: A tuple (output, state), where output is x unchanged and state is the state that was passed in.
- Return type: tuple
- Raises:
- NotImplementedError – If stride or expand is greater than 1, since step-wise pooling and expansion are not implemented in this method.
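Based on the documented behavior, the step logic reduces to a guard followed by a pass-through. The sketch below mirrors that contract with a hypothetical StepOnlyPool class. It uses plain Python lists in place of tensors and is an illustration, not the ESPnet source.

```python
from typing import Any, List, Tuple


class StepOnlyPool:
    """Hypothetical stand-in mirroring the documented DownAvgPool.step contract."""

    def __init__(self, d_input: int, stride: int = 1, expand: int = 1) -> None:
        self.d_input = d_input
        self.stride = stride
        self.expand = expand

    def step(self, x: Any, state: List[Any], **kwargs: Any) -> Tuple[Any, List[Any]]:
        # Step-wise pooling/expansion is not supported: only the identity case works.
        if self.stride > 1 or self.expand > 1:
            raise NotImplementedError
        return x, state  # input and state pass through unchanged


y, s = StepOnlyPool(64).step([0.5] * 64, [])
assert y == [0.5] * 64 and s == []

try:
    StepOnlyPool(64, stride=2).step([0.5] * 64, [])
except NotImplementedError:
    print("stride > 1 is not supported in step()")
```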
#### Examples
>>> model = DownAvgPool(d_input=64, stride=1, expand=1)
>>> state = []
>>> output, new_state = model.step(torch.randn(10, 64), state)
>>> assert output.shape == (10, 64)  # stride=1, expand=1: input passes through unchanged
#### NOTE
This method is intended for recurrent inference. Each call processes one time step rather than a whole sequence at once.
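In a recurrent setting, step() is called once per time step and the state is carried forward between calls. This minimal sketch shows that loop. identity_step is a hypothetical function standing in for DownAvgPool.step in its only supported configuration (stride=1, expand=1). With a real module you would call module.step(x_t, state) instead.

```python
from typing import Any, Callable, List, Tuple

StepFn = Callable[[Any, List[Any]], Tuple[Any, List[Any]]]


def identity_step(x: Any, state: List[Any]) -> Tuple[Any, List[Any]]:
    """Hypothetical stand-in for DownAvgPool.step with stride=1 and expand=1."""
    return x, state


def run_recurrently(sequence: List[Any], step_fn: StepFn = identity_step) -> List[Any]:
    """Feed a sequence to step_fn one time step at a time, carrying the state forward."""
    state: List[Any] = []
    outputs: List[Any] = []
    for x_t in sequence:  # each x_t is one time step (a (batch, H) tensor in the real module)
        y_t, state = step_fn(x_t, state)
        outputs.append(y_t)
    return outputs


frames = [[float(t)] * 4 for t in range(3)]  # 3 time steps, 4 features each
assert run_recurrently(frames) == frames
```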