espnet2.enh.layers.ncsnpp_utils.layers.ResidualBlock
class espnet2.enh.layers.ncsnpp_utils.layers.ResidualBlock(input_dim, output_dim, resample=None, act=ELU(alpha=1.0), normalization=<class 'torch.nn.modules.instancenorm.InstanceNorm2d'>, adjust_padding=False, dilation=1)
Bases: Module
Residual block for convolutional networks.
This class implements a residual block that consists of two convolutional layers and a skip (shortcut) connection. Depending on the constructor arguments, the block can downsample its input (resample='down'), and it applies the configured normalization layer and activation function before each convolution.
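Written as an equation, the computation described above is roughly the following. The symbols are introduced here for illustration only: $\mathcal{N}_1$ and $\mathcal{N}_2$ stand for the normalize1 and normalize2 attributes listed below, $\sigma$ for the activation act, $C_1$ and $C_2$ for the first and second convolutions, and $S$ for the shortcut path.

```latex
y = S(x) + C_2\bigl(\sigma(\mathcal{N}_2(C_1(\sigma(\mathcal{N}_1(x)))))\bigr),
\qquad
S(x) =
\begin{cases}
x, & \text{if the input and output shapes match},\\
\text{a convolution applied to } x, & \text{otherwise}.
\end{cases}
```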
non_linearity
The activation function to apply.
- Type: callable
input_dim
The number of input channels.
- Type: int
output_dim
The number of output channels.
- Type: int
resample
If ‘down’, the block will downsample the input.
- Type: str or None
normalization
The normalization layer to use.
- Type: callable
shortcut
The shortcut connection layer.
- Type: nn.Module
normalize1
The normalization layer applied before the first convolution.
- Type: callable
normalize2
The normalization layer applied after the first convolution.
- Type: callable
Parameters:
- input_dim (int) – Number of input channels.
- output_dim (int) – Number of output channels.
- resample (str, optional) – 'down' for downsampling, None (default) to keep the spatial size. Any other value raises an exception.
- act (callable, optional) – Activation function; default is nn.ELU().
- normalization (callable, optional) – Normalization layer class; default is nn.InstanceNorm2d.
- adjust_padding (bool, optional) – If True, adjusts the padding of the convolutions; default is False.
- dilation (int, optional) – Dilation rate for the convolutions; default is 1.
Returns: The output tensor after applying the residual block (returned by forward()).
Return type: Tensor
Raises: Exception – If resample is neither 'down' nor None.
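The accepted resample values and the resulting output shapes can be checked before a network is built. The sketch below is a hypothetical pure-Python helper, not part of ESPnet. It assumes dilation=1 and even spatial sizes when resample='down', and it raises ValueError (a subclass of Exception) where the real class raises a plain Exception.

```python
from typing import Optional, Tuple


def residual_block_output_shape(
    input_shape: Tuple[int, int, int, int],
    input_dim: int,
    output_dim: int,
    resample: Optional[str] = None,
) -> Tuple[int, int, int, int]:
    """Hypothetical helper: predict ResidualBlock's output shape (dilation=1 assumed)."""
    # Mirror the documented constraint: only 'down' or None are accepted.
    if resample not in ("down", None):
        raise ValueError(f"invalid resample value: {resample!r}")
    batch, channels, height, width = input_shape
    if channels != input_dim:
        raise ValueError(f"expected {input_dim} input channels, got {channels}")
    if resample == "down":
        # Assumption of this sketch: even H and W, halved by downsampling.
        if height % 2 or width % 2:
            raise ValueError("this sketch only handles even spatial sizes")
        height, width = height // 2, width // 2
    # The channel count always becomes output_dim.
    return (batch, output_dim, height, width)


# Matches the class-level example: 32x32 is halved to 16x16.
print(residual_block_output_shape((1, 64, 32, 32), 64, 128, "down"))  # (1, 128, 16, 16)
```

Keeping the check separate from the module lets a configuration be validated without allocating any weights.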
Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import ResidualBlock
>>> block = ResidualBlock(input_dim=64, output_dim=128, resample='down')
>>> input_tensor = torch.randn(1, 64, 32, 32)
>>> output_tensor = block(input_tensor)
>>> output_tensor.shape
torch.Size([1, 128, 16, 16])
forward(x)
Apply the residual block to an input tensor.
The residual branch applies normalization, the non-linearity, and a convolution twice in sequence (normalize1, non_linearity, first convolution, then normalize2, non_linearity, second convolution). Its output is then added to the shortcut path.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, input_dim, H, W), where B is the batch size and H and W are the height and width of the input feature map.
- Returns: Output tensor of shape (B, output_dim, H', W'). With the default dilation=1, H' = H and W' = W when resample is None, and H' = H/2 and W' = W/2 when resample='down'.
- Return type: torch.Tensor
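The order of operations described above can be sketched with NumPy. This is a simplified, hypothetical re-implementation for illustration, not the ESPnet code: the block's convolutions are replaced by 1×1 channel-mixing matrices (w1, w2, w_shortcut are made-up names), normalization is a plain instance norm without affine parameters (the InstanceNorm2d default), the activation is ELU with alpha=1.0, and downsampling is omitted.

```python
import numpy as np


def instance_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize each (sample, channel) feature map over its spatial dims (H, W).
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def elu(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    # ELU; expm1 on the clipped negative part avoids overflow for large inputs.
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0)))


def conv1x1(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # Stand-in for a convolution: mix channels with an (out, in) weight matrix.
    return np.einsum("oc,bchw->bohw", w, x)


def residual_forward_sketch(x, w1, w2, w_shortcut=None):
    # Residual branch: norm -> activation -> conv, applied twice.
    h = conv1x1(elu(instance_norm(x)), w1)
    h = conv1x1(elu(instance_norm(h)), w2)
    # Shortcut: the input itself, or a projection when the channel count changes.
    shortcut = x if w_shortcut is None else conv1x1(x, w_shortcut)
    return shortcut + h


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 8))   # (B, input_dim=4, H, W)
w1 = rng.standard_normal((6, 4))        # first conv maps 4 -> 6 channels
w2 = rng.standard_normal((6, 6))        # second conv keeps 6 channels
w_short = rng.standard_normal((6, 4))   # channels differ, so project the input
y = residual_forward_sketch(x, w1, w2, w_short)
print(y.shape)  # (2, 6, 8, 8)
```

Because the output is a sum, setting the second convolution to zero makes the block reduce to its shortcut path, which is the usual motivation for residual connections.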
NOTE
The block output is the sum of the shortcut path and the residual branch. When the input and output tensors have the same shape, the shortcut is the input tensor itself; otherwise (different channel counts, or downsampling), a convolution is applied to the input so that its shape matches the residual branch before the summation.
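The condition in the note can be written as a small predicate. The function name is hypothetical, and the rule assumes the default dilation=1, where resample='down' always changes the spatial size and therefore also requires a projection.

```python
from typing import Optional


def needs_projection_shortcut(
    input_dim: int, output_dim: int, resample: Optional[str] = None
) -> bool:
    """Hypothetical predicate: True when the shortcut must be a convolution."""
    # A change in channel count makes input and output shapes differ.
    channels_change = input_dim != output_dim
    # Downsampling (dilation=1 assumed) changes H and W, so shapes differ too.
    spatial_change = resample == "down"
    return channels_change or spatial_change


# Both documented examples change the channel count, so both need a projection.
print(needs_projection_shortcut(64, 128, "down"))  # True
print(needs_projection_shortcut(64, 128))          # True
# Same channels and no resampling: the input is added unchanged.
print(needs_projection_shortcut(64, 64))           # False
```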
Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import ResidualBlock
>>> block = ResidualBlock(input_dim=64, output_dim=128)
>>> input_tensor = torch.randn(32, 64, 128, 128)
>>> output_tensor = block(input_tensor)
>>> print(output_tensor.shape)
torch.Size([32, 128, 128, 128])