espnet2.enh.layers.ncsnpp_utils.layerspp.ResnetBlockBigGANpp
class espnet2.enh.layers.ncsnpp_utils.layerspp.ResnetBlockBigGANpp(act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), skip_rescale=True, init_scale=0.0)
Bases: Module
ResNet Block adapted for BigGAN++ architecture.
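This class implements a residual block that can optionally upsample or downsample its input by a factor of 2. The residual branch applies group normalization, the activation, optional resampling, a 3x3 convolution, an optional time-embedding bias, a second group normalization and activation, dropout, and a second 3x3 convolution. The skip path is resampled in the same way and projected with a 1x1 convolution when the channel count or resolution changes. The two paths are then summed.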
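The data flow of the residual block can be sketched as a simplified PyTorch module. This is a minimal illustration under stated assumptions, not the ESPnet implementation. The class name `ToyBigGANResBlock` is hypothetical. The time-embedding path and FIR filtering are omitted. Resampling uses nearest-neighbor upsampling and 2x2 average pooling. The GroupNorm group count `min(ch // 4, 32)` is an assumption modeled on the score_sde reference code.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBigGANResBlock(nn.Module):
    """Hypothetical, simplified BigGAN++-style residual block.

    Omits the time embedding and FIR resampling; this is not the ESPnet class.
    """

    def __init__(self, act, in_ch, out_ch=None, up=False, down=False,
                 dropout=0.1, skip_rescale=True):
        super().__init__()
        out_ch = out_ch or in_ch  # out_ch defaults to in_ch
        self.act, self.up, self.down = act, up, down
        self.skip_rescale = skip_rescale
        # Assumed group count, modeled on the score_sde reference code.
        self.norm0 = nn.GroupNorm(min(in_ch // 4, 32), in_ch, eps=1e-6)
        self.conv0 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(min(out_ch // 4, 32), out_ch, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.conv1 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection on the skip path when channels or resolution change.
        if in_ch != out_ch or up or down:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def _resample(self, t):
        if self.up:  # nearest-neighbor upsampling by 2
            return F.interpolate(t, scale_factor=2, mode="nearest")
        if self.down:  # 2x2 average pooling
            return F.avg_pool2d(t, kernel_size=2)
        return t

    def forward(self, x):
        h = self.act(self.norm0(x))
        # Resample the residual branch and the skip path identically.
        h, x = self._resample(h), self._resample(x)
        h = self.conv0(h)
        h = self.drop(self.act(self.norm1(h)))
        h = self.conv1(h)
        out = self.skip(x) + h
        # skip_rescale: divide the sum by sqrt(2) (assumed rescaling rule).
        return out / math.sqrt(2.0) if self.skip_rescale else out


block = ToyBigGANResBlock(F.relu, in_ch=64, out_ch=128, up=True)
print(block(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 64, 64])
```

Both paths are resampled before they are summed, so their spatial sizes always agree. The 1x1 projection is needed only when the skip tensor's shape would otherwise differ from the residual branch.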
act
The activation function to apply.
in_ch
Number of input channels.
out_ch
Number of output channels.
temb_dim
Dimension of the time embedding (optional).
up
Boolean indicating if upsampling should be performed.
down
Boolean indicating if downsampling should be performed.
dropout
Dropout rate to apply.
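fir
Boolean indicating if a FIR filter should be used for upsampling or downsampling.
fir_kernel
Kernel to use for the FIR filter.
skip_rescale
Boolean indicating if the sum of the skip and residual paths should be divided by sqrt(2).
init_scale
Scale used when initializing the weights of the last convolution in the residual branch.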
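The effect of `skip_rescale` can be checked numerically. This sketch assumes the rescaling divides the sum of the two paths by sqrt(2), as in the score_sde reference code. It also assumes the skip and residual activations are roughly independent with unit variance. It uses synthetic Gaussian samples only, with no ESPnet code.

```python
import math
import random
import statistics

random.seed(0)
n = 20_000
skip = [random.gauss(0.0, 1.0) for _ in range(n)]   # stand-in for the skip path
resid = [random.gauss(0.0, 1.0) for _ in range(n)]  # stand-in for the residual branch

# Plain residual sum: variances of independent signals add, giving about 2.
plain = [s + r for s, r in zip(skip, resid)]
# skip_rescale=True (assumed rule): divide by sqrt(2) to bring variance back to about 1.
rescaled = [v / math.sqrt(2.0) for v in plain]

print(round(statistics.pvariance(plain), 2))
print(round(statistics.pvariance(rescaled), 2))
```

Keeping the variance near 1 stops activations from growing as many residual blocks are stacked.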
- Parameters:
- act (callable) – Activation function.
- in_ch (int) – Number of input channels.
- out_ch (int, optional) – Number of output channels. Defaults to in_ch.
- temb_dim (int, optional) – Dimension of the time embedding. Defaults to None (no time conditioning).
- up (bool, optional) – If True, upsamples by a factor of 2. Defaults to False.
- down (bool, optional) – If True, downsamples by a factor of 2. Defaults to False.
- dropout (float, optional) – Dropout rate. Defaults to 0.1.
- fir (bool, optional) – If True, uses a FIR filter for upsampling/downsampling. Otherwise, nearest-neighbor upsampling or average-pooling downsampling is used. Defaults to False.
- fir_kernel (tuple, optional) – 1-D kernel for the FIR filter. Defaults to (1, 3, 3, 1).
- skip_rescale (bool, optional) – If True, divides the sum of the skip and residual paths by sqrt(2). Defaults to True.
- init_scale (float, optional) – Initialization scale for the weights of the last convolution in the residual branch. Defaults to 0.0.
- Returns: The output tensor after applying the residual block.
- Return type: torch.Tensor
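The 1-D `fir_kernel` is typically expanded into a separable 2-D filter. This sketch assumes the convention of the score_sde up/down-sampling helpers. In that convention, the outer product of the kernel with itself is normalized to sum to 1. For upsampling, it is also multiplied by `factor ** 2` to compensate for the inserted zeros. The function name `make_fir_kernel_2d` is hypothetical.

```python
from fractions import Fraction


def make_fir_kernel_2d(k1d, factor=2, upsample=False):
    """Hypothetical helper: build a separable 2-D FIR kernel from a 1-D kernel."""
    # Outer product k[i] * k[j] gives a separable 2-D filter.
    k2d = [[Fraction(a) * b for b in k1d] for a in k1d]
    total = sum(sum(row) for row in k2d)
    # Normalize so the taps sum to 1 (unit gain on constant signals).
    k2d = [[v / total for v in row] for row in k2d]
    if upsample:
        # Assumed gain correction for the zeros inserted during upsampling.
        k2d = [[v * factor ** 2 for v in row] for row in k2d]
    return k2d


for row in make_fir_kernel_2d((1, 3, 3, 1)):
    print([str(v) for v in row])
```

`Fraction` keeps the arithmetic exact for illustration. A tensor implementation would use floating-point values instead.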
####### Examples
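>>> import torch
>>> import torch.nn.functional as F
>>> from espnet2.enh.layers.ncsnpp_utils.layerspp import ResnetBlockBigGANpp
>>> block = ResnetBlockBigGANpp(act=F.relu, in_ch=64, out_ch=128, up=True)
>>> input_tensor = torch.randn(1, 64, 32, 32)
>>> output_tensor = block(input_tensor)
>>> print(output_tensor.shape)
torch.Size([1, 128, 64, 64])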
NOTE
The channel dimension of the input tensor must equal in_ch. The output has shape (batch_size, out_ch, H', W'). H' and W' are twice the input height and width when up=True, half of them when down=True, and unchanged otherwise.
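The shape rule can be written as a small helper. `expected_output_shape` is a hypothetical function for illustration only. It assumes a resampling factor of 2 and, for downsampling, even spatial dimensions.

```python
def expected_output_shape(shape, in_ch, out_ch=None, up=False, down=False):
    """Hypothetical helper: predict (B, out_ch, H', W') for the residual block."""
    b, c, h, w = shape
    # This check belongs to the hypothetical helper, not to the ESPnet block.
    if c != in_ch:
        raise ValueError(f"expected {in_ch} input channels, got {c}")
    out_ch = out_ch or in_ch  # out_ch defaults to in_ch
    # If both flags are set, up is checked first here (assumed precedence).
    if up:
        h, w = 2 * h, 2 * w
    elif down:
        if h % 2 or w % 2:
            raise ValueError("downsampling by 2 assumes even H and W")
        h, w = h // 2, w // 2
    return (b, out_ch, h, w)


print(expected_output_shape((1, 64, 32, 32), in_ch=64, out_ch=128, up=True))
```

The helper reproduces the shapes given in the examples for this class. It can also serve as a quick sanity check when stacking blocks.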
forward(x, temb=None)
Perform a forward pass through the ResNet block.
This method processes the input tensor x and the optional time embedding temb through normalization, convolution, activation, and dropout layers. If upsampling or downsampling was enabled at initialization, it is applied to both the residual branch and the skip path.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (B, C, H, W) where B is the batch size, C is the number of channels, H is the height, and W is the width.
- temb (torch.Tensor, optional) – Time embedding tensor of shape (B, temb_dim). Requires temb_dim to have been set at initialization. Default is None.
- Returns: Output tensor of shape (B, out_ch, H’, W’). H’ and W’ equal 2H and 2W when upsampling, H/2 and W/2 when downsampling, and H and W otherwise.
- Return type: torch.Tensor
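The optional temb input conditions the block by adding a per-channel bias to the feature maps. This sketch assumes the score_sde-style recipe. The activation is applied to temb, a linear layer projects it from temb_dim to out_ch, and the result is broadcast over height and width. ReLU stands in for act, and the layer name `dense` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, temb_dim, out_ch, H, W = 8, 256, 128, 32, 32
dense = nn.Linear(temb_dim, out_ch)  # hypothetical projection temb_dim -> out_ch

h = torch.randn(B, out_ch, H, W)     # features after the first 3x3 convolution
temb = torch.randn(B, temb_dim)      # time embedding of shape (B, temb_dim)

bias = dense(F.relu(temb))           # one value per (sample, channel): (B, out_ch)
h_cond = h + bias[:, :, None, None]  # broadcast over H and W -> (B, out_ch, H, W)
print(h_cond.shape)
```

The bias is constant across spatial positions, so it shifts each channel as a whole rather than changing the spatial layout.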
####### Examples
>>> import torch
>>> import torch.nn as nn
>>> from espnet2.enh.layers.ncsnpp_utils.layerspp import ResnetBlockBigGANpp
>>> block = ResnetBlockBigGANpp(act=nn.ReLU(), in_ch=64, out_ch=128)
>>> x = torch.randn(8, 64, 32, 32)  # Batch of 8, 64 channels, 32x32
>>> output = block(x)
>>> output.shape
torch.Size([8, 128, 32, 32])
NOTE
If up or down was set to True at initialization, the input is upsampled or downsampled by a factor of 2 accordingly.
- Raises: RuntimeError – Raised by the underlying PyTorch layers if the channel dimension of x does not match in_ch. The block itself performs no explicit shape validation.