espnet2.enh.layers.ncsnpp_utils.up_or_down_sampling.conv_downsample_2d
espnet2.enh.layers.ncsnpp_utils.up_or_down_sampling.conv_downsample_2d(x, w, k=None, factor=2, gain=1)
Fused tf.nn.conv2d() followed by downsample_2d().
This function combines a 2D convolution and a downsampling step into a single operation. It convolves the input with the weights w and then downsamples the result using the FIR filter k. The input is padded only once, at the beginning, so the fused version is more efficient than running the two operations separately.
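The following is a minimal single-channel NumPy sketch of the fused pipeline. It is not the ESPnet implementation, and the helper names (`setup_kernel`, `correlate_valid`, `conv_downsample_2d_sketch`) are hypothetical. It rests on these assumptions:

- w is a square 2-D kernel applied as a plain cross-correlation, as `torch.nn.functional.conv2d` does.
- The single combined padding is p = (len(k) - factor) + (convW - 1), split as ((p + 1) // 2, p // 2).
- FIR filtering is a true (flipped-kernel) convolution.

Convolution with w and FIR filtering with k are both linear convolutions and therefore commute. The sketch filters first and lets the stride of the final convolution perform the decimation.

```python
import numpy as np


def setup_kernel(k):
    """Turn a 1-D (separable) or square 2-D FIR filter into a normalized 2-D kernel."""
    k = np.asarray(k, dtype=np.float64)
    if k.ndim == 1:
        k = np.outer(k, k)  # separable filter -> 2-D kernel
    return k / k.sum()  # normalize to unit DC gain


def correlate_valid(x, kern, stride=1):
    """'Valid' 2-D cross-correlation (no padding) with an optional stride."""
    kh, kw = kern.shape
    oh = (x.shape[0] - kh) // stride + 1
    ow = (x.shape[1] - kw) // stride + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            patch = x[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = np.sum(patch * kern)
    return out


def conv_downsample_2d_sketch(x, w, k=None, factor=2, gain=1.0):
    """Single-channel sketch: pad once, FIR-filter, then strided convolution."""
    if k is None:
        k = [1] * factor  # default filter: average pooling
    kern = setup_kernel(k) * gain
    conv_size = w.shape[0]
    assert w.shape == (conv_size, conv_size), "sketch assumes a square kernel"
    # Assumed: one combined padding amount covering both the FIR filter and the conv.
    p = (kern.shape[0] - factor) + (conv_size - 1)
    x = np.pad(x, ((p + 1) // 2, p // 2))  # zero-pad every axis once
    # FIR filtering as a true convolution (flipped kernel), stride 1.
    x = correlate_valid(x, kern[::-1, ::-1])
    # The strided convolution with w performs the actual downsampling.
    return correlate_valid(x, w, stride=factor)


# Usage: an identity (center-tap) 3x3 weight isolates the effect of the FIR filter.
x = np.arange(64, dtype=np.float64).reshape(8, 8)
w = np.zeros((3, 3))
w[1, 1] = 1.0
out = conv_downsample_2d_sketch(x, w)
print(out.shape)  # (4, 4)
print(np.allclose(out, x.reshape(4, 2, 4, 2).mean(axis=(1, 3))))  # True
```

With an identity weight and the default k = [1] * factor, the output equals 2x2 average pooling of the input. This matches the description of the default filter below.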
- Parameters:
  - x – Input tensor of shape [N, C, H, W] or [N, H, W, C], where N is the batch size, C is the number of channels, H is the height, and W is the width.
  - w – Weight tensor of shape [filterH, filterW, inChannels, outChannels]. Grouped convolution can be performed by setting inChannels = C // numGroups, where C is the number of channels of x.
  - k – FIR filter of shape [firH, firW] or [firN] (separable). The default is [1] * factor, which corresponds to average pooling.
  - factor – Integer downsampling factor (default: 2).
  - gain – Scaling factor for signal magnitude (default: 1.0).
- Returns: Tensor of shape [N, C, H // factor, W // factor] or [N, H // factor, W // factor, C], with the same datatype as x.
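A little arithmetic shows why the spatial size comes out as exactly H // factor. This assumes a single combined padding p = (len(k) - factor) + (convW - 1), which this page does not state.

1. "Valid" FIR filtering of the padded input leaves H + p - len(k) + 1 = H - factor + convW rows.
2. A stride-factor convolution with no extra padding then leaves (H - factor) // factor + 1 = H // factor rows.

The hypothetical helper below replays that arithmetic:

```python
def fused_output_size(size, fir_taps, conv_size, factor):
    """Hypothetical helper: spatial size after pad -> FIR filter -> strided conv."""
    # Assumed single combined padding for the FIR filter and the convolution.
    p = (fir_taps - factor) + (conv_size - 1)
    # 'Valid' FIR filtering of the padded input.
    after_fir = size + p - fir_taps + 1
    # Convolution with stride = factor and no extra padding.
    return (after_fir - conv_size) // factor + 1


# The default k = [1] * factor has `factor` taps; use a 3x3 convolution.
for size in (7, 8, 64, 65):
    print(size, fused_output_size(size, fir_taps=2, conv_size=3, factor=2))
# 7 3
# 8 4
# 64 32
# 65 32
```

The FIR length and kernel size cancel out, so the result depends only on H and factor. Odd sizes are floored, which is why 65 maps to 32.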
Examples
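>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.up_or_down_sampling import conv_downsample_2d
>>> x = torch.randn(1, 3, 64, 64)  # input batch: N=1, C=3, H=W=64
>>> w = torch.randn(3, 3, 3, 3)    # square 3x3 kernel, 3 input and 3 output channels
>>> output = conv_downsample_2d(x, w, factor=2)
>>> print(output.shape)  # spatial size halved by factor=2
torch.Size([1, 3, 32, 32])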