espnet2.enh.layers.ncsnpp_utils.layers.Upsample
class espnet2.enh.layers.ncsnpp_utils.layers.Upsample(channels, with_conv=False)
Bases: Module
Upsampling layer with optional convolution.
This class upsamples its input with nearest-neighbor interpolation, doubling the height and width, and can optionally apply a 3x3 convolution to the result.
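To make the nearest-neighbor step concrete, here is a minimal pure-Python sketch for a scale factor of exactly 2. Output pixel (i, j) copies input pixel (i // 2, j // 2), so each input value fills a 2x2 block. The helper `nearest_upsample_2x` is hypothetical and works on a single H x W grid of nested lists. It only illustrates the indexing and is not ESPnet's code, which operates on torch tensors.

```python
from typing import List

Grid = List[List[float]]


def nearest_upsample_2x(grid: Grid) -> Grid:
    """Hypothetical helper: 2x nearest-neighbor upsampling of one H x W grid.

    Output pixel (i, j) takes the value of input pixel (i // 2, j // 2),
    so the result has shape (2H, 2W) and each input value fills a 2x2 block.
    """
    height = len(grid)
    width = len(grid[0]) if height else 0
    return [
        [grid[i // 2][j // 2] for j in range(2 * width)]
        for i in range(2 * height)
    ]


if __name__ == "__main__":
    for row in nearest_upsample_2x([[1.0, 2.0], [3.0, 4.0]]):
        print(row)
    # [1.0, 1.0, 2.0, 2.0]
    # [1.0, 1.0, 2.0, 2.0]
    # [3.0, 3.0, 4.0, 4.0]
    # [3.0, 3.0, 4.0, 4.0]
```

Nearest-neighbor interpolation adds no new values; it only repeats existing ones, which is why an optional convolution afterwards can be useful for smoothing the blocky result.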
Conv_0
A 3x3 convolution mapping channels to channels, applied to the upsampled output. It is only created when with_conv is True.
Type: nn.Conv2d
with_conv
Indicates whether to apply a convolution after upsampling.
Type: bool
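When `with_conv=True`, `Conv_0` is a channels-to-channels 3x3 convolution. The sketch below spells out the cross-correlation that a 3x3 `nn.Conv2d` computes on one sample of shape (C, H, W), assuming stride 1 and zero padding of 1. Those settings are an assumption here, but they are consistent with the example output shape, where the convolution keeps the upsampled spatial size. The function `conv3x3_same` and its nested-list weight layout `weight[out_ch][in_ch][ki][kj]` are hypothetical.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # one sample, indexed [channel][row][col]
Weight = List[List[List[List[float]]]]  # indexed [out_ch][in_ch][ki][kj], 3x3 kernels


def conv3x3_same(x: Tensor3, weight: Weight, bias: List[float]) -> Tensor3:
    """Hypothetical helper: 3x3 cross-correlation with stride 1 and zero padding 1.

    Positions outside the input are treated as zeros, so every output
    channel keeps the input's spatial size (H, W).
    """
    in_ch, height, width = len(x), len(x[0]), len(x[0][0])
    out: Tensor3 = []
    for o in range(len(weight)):
        plane = []
        for i in range(height):
            row = []
            for j in range(width):
                acc = bias[o]
                for c in range(in_ch):
                    for ki in range(3):
                        for kj in range(3):
                            # Kernel center (1, 1) lines up with output position (i, j).
                            r, s = i + ki - 1, j + kj - 1
                            if 0 <= r < height and 0 <= s < width:
                                acc += weight[o][c][ki][kj] * x[c][r][s]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out


if __name__ == "__main__":
    x = [[[1.0, 2.0], [3.0, 4.0]]]  # C=1, H=W=2
    identity = [[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]]]
    print(conv3x3_same(x, identity, [0.0]))  # [[[1.0, 2.0], [3.0, 4.0]]]
    ones = [[[[1.0] * 3 for _ in range(3)]]]
    print(conv3x3_same(x, ones, [0.0]))  # [[[10.0, 10.0], [10.0, 10.0]]]
```

With an identity kernel the input passes through unchanged; with an all-ones kernel each output pixel is the sum of its in-bounds 3x3 neighborhood. A learned kernel sits between these extremes and can blend the repeated pixels produced by the nearest-neighbor step.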
Parameters:
- channels (int) – The number of input and output channels for the convolution layer.
- with_conv (bool) – If True, a convolution is applied after the upsampling. Defaults to False.
Returns: The upsampled tensor, optionally processed by a convolution layer.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import Upsample
>>> upsample_layer = Upsample(channels=64, with_conv=True)
>>> input_tensor = torch.randn(1, 64, 32, 32)  # batch of 1, 64 channels, 32x32
>>> output_tensor = upsample_layer(input_tensor)
>>> print(output_tensor.shape)
torch.Size([1, 64, 64, 64])
NOTE
The input tensor is expected to have a shape of (B, C, H, W), where B is the batch size, C is the number of channels, H is the height, and W is the width.
forward(x)
Upsample the input by a factor of 2 in both spatial dimensions.
The input is resized to (2H, 2W) with nearest-neighbor interpolation; if with_conv is True, Conv_0 is then applied to the upsampled tensor.
Parameters:
- x (torch.Tensor) – Input tensor of shape (B, C, H, W). When with_conv is True, C must equal channels.
Returns: The upsampled tensor of shape (B, C, 2H, 2W), passed through Conv_0 when with_conv is True.
Return type: torch.Tensor
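The shape rules for `forward` fit in a few lines. The helper below, `upsample_output_shape`, is hypothetical (it is not part of ESPnet) and only mirrors the documented behavior: a 4-D (B, C, H, W) input becomes (B, C, 2H, 2W), and with `with_conv=True` the channel count C must equal `channels`, since `Conv_0` maps `channels` to `channels`. It assumes the convolution preserves spatial size, as the example output indicates.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def upsample_output_shape(
    input_shape: Tuple[int, ...], channels: int, with_conv: bool = False
) -> Shape:
    """Hypothetical helper: output shape of Upsample.forward for a given input shape."""
    if len(input_shape) != 4:
        raise ValueError(f"expected a 4-D (B, C, H, W) shape, got {input_shape}")
    batch, chans, height, width = input_shape
    if with_conv and chans != channels:
        # Conv_0 is built as a channels -> channels convolution.
        raise ValueError(f"with_conv=True expects C == {channels}, got C == {chans}")
    # Interpolation doubles H and W; Conv_0 is assumed to preserve spatial size.
    return (batch, chans, 2 * height, 2 * width)
```

For the documented example, `upsample_output_shape((1, 64, 32, 32), channels=64, with_conv=True)` returns `(1, 64, 64, 64)`. Without the convolution, any channel count passes through unchanged.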