espnet2.enh.layers.ncsnpp_utils.layers.Downsample
class espnet2.enh.layers.ncsnpp_utils.layers.Downsample(channels, with_conv=False)
Bases: Module
Downsample the input tensor using average pooling or convolution.
This class implements a downsampling layer that reduces the spatial dimensions of the input tensor. It can perform downsampling using average pooling or a convolutional layer, depending on the with_conv parameter. When using a convolutional layer, ‘SAME’ padding is applied to ensure the output dimensions are halved.
Conv_0
A 3x3 convolutional layer used for downsampling if with_conv is set to True.
- Type: nn.Module
with_conv
Indicates whether to use convolution for downsampling or average pooling.
- Type: bool
Parameters:
- channels (int) – Number of input channels.
- with_conv (bool) – If True, uses convolution for downsampling. Defaults to False.
Returns: The downsampled tensor.
Return type: torch.Tensor
####### Examples
>>> downsample_layer = Downsample(channels=64, with_conv=True)
>>> input_tensor = torch.randn(1, 64, 128, 128) # (B, C, H, W)
>>> output_tensor = downsample_layer(input_tensor)
>>> print(output_tensor.shape) # Should be (1, 64, 64, 64)
NOTE
If with_conv is set to True, the input tensor will be padded before applying the convolution. This ensures that the output size is consistent with the downsampling operation.
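The padding behavior described in the note can be sketched as follows. This is a minimal, hypothetical re-implementation illustrating the two downsampling paths, not the ESPnet source itself: it assumes 'SAME' padding is emulated by padding one pixel on the right and bottom before a stride-2, 3x3 convolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampleSketch(nn.Module):
    """Illustrative sketch of the Downsample layer (not the ESPnet code)."""

    def __init__(self, channels, with_conv=False):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # 3x3 conv with stride 2; padding is applied manually in forward()
            self.Conv_0 = nn.Conv2d(channels, channels, 3, stride=2, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.with_conv:
            # Emulate 'SAME' padding: pad one pixel on the right and bottom
            # so the 3x3/stride-2 conv halves the spatial dimensions exactly.
            x = F.pad(x, (0, 1, 0, 1))
            x = self.Conv_0(x)
        else:
            # Average-pooling path: 2x2 window, stride 2
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        assert x.shape == (B, C, H // 2, W // 2)
        return x
```

With an even-sized input such as (1, 64, 32, 32), both paths produce a (1, 64, 16, 16) output.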
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Downsample the input tensor using average pooling or convolution.
This module downsamples the input tensor either by using average pooling or by applying a 3x3 convolution with a stride of 2. The option to perform convolution can be enabled during initialization.
with_conv
A flag to determine whether to use convolution for downsampling.
- Type: bool
Parameters:
- channels (int) – The number of input channels.
- with_conv (bool) – If True, uses convolution for downsampling. Defaults to False.
Returns: The downsampled tensor of shape (B, C, H//2, W//2), where B is the batch size, C is the number of channels, and H and W are the height and width of the input tensor.
Return type: torch.Tensor
####### Examples
>>> downsample_layer = Downsample(channels=64, with_conv=True)
>>> input_tensor = torch.randn(1, 64, 32, 32)
>>> output_tensor = downsample_layer(input_tensor)
>>> output_tensor.shape
torch.Size([1, 64, 16, 16])
NOTE
When using convolution for downsampling, the input tensor will be padded to emulate ‘SAME’ padding, ensuring that the output tensor has the correct spatial dimensions.
- Raises: AssertionError – If the shape of the output tensor does not match the expected shape (B, C, H//2, W//2).
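The examples above all use with_conv=True. When with_conv=False, the layer instead applies a 2x2 average pool with stride 2, which can be sketched with plain torch as follows (an illustration of the pooling path only, not a call into the ESPnet layer):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 32, 32)                # (B, C, H, W)
y = F.avg_pool2d(x, kernel_size=2, stride=2)  # halves H and W
print(y.shape)                                # torch.Size([1, 64, 16, 16])
```

The output shape matches the convolutional path, so the two modes are interchangeable with respect to tensor dimensions.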