espnet2.enh.layers.ncsnpp_utils.layers.ConditionalResidualBlock
class espnet2.enh.layers.ncsnpp_utils.layers.ConditionalResidualBlock(input_dim, output_dim, num_classes, resample=1, act=ELU(alpha=1.0), normalization=<class 'espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalInstanceNorm2dPlus'>, adjust_padding=False, dilation=None)
Bases: Module
Residual block whose normalization layers are conditioned on a class label.
The block holds two convolutions (conv1, conv2), two conditional normalization layers (normalize1, normalize2), an activation function, and a shortcut layer that is used when the input and output shapes differ. Each normalization layer receives the class label alongside the feature map. The block can optionally downsample its input and use dilated convolutions.
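The idea of a residual connection combined with class-conditional normalization can be illustrated without PyTorch. The sketch below is a toy on a single 1-D channel, not the ESPnet implementation: the convolutions are omitted, the shortcut is the identity, the normalization is a plain conditional instance norm rather than ConditionalInstanceNorm2dPlus, and the order of operations is an assumption. All function names are hypothetical.

```python
import math


def elu(v, alpha=1.0):
    """ELU activation, the same family as the default act=ELU(alpha=1.0)."""
    return v if v > 0 else alpha * (math.exp(v) - 1.0)


def conditional_instance_norm(values, label, gammas, betas, eps=1e-5):
    """Normalize one channel, then scale and shift with per-class parameters."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    gamma, beta = gammas[label], betas[label]  # looked up by class label
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


def toy_conditional_residual(values, label, gammas, betas):
    """Residual branch (norm, then activation) added to an identity shortcut."""
    normed = conditional_instance_norm(values, label, gammas, betas)
    branch = [elu(v) for v in normed]
    return [s + b for s, b in zip(values, branch)]


# Two classes, each with its own scale (gamma) and shift (beta).
gammas = [1.0, 2.0]
betas = [0.0, 0.5]
channel = [1.0, 2.0, 3.0, 4.0]

out_class0 = toy_conditional_residual(channel, 0, gammas, betas)
out_class1 = toy_conditional_residual(channel, 1, gammas, betas)
```

The same input produces different outputs for label 0 and label 1, which is the effect the `num_classes` argument enables in the real block.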
non_linearity
Activation function to apply.
- Type: callable
input_dim
Number of input channels.
- Type: int
output_dim
Number of output channels.
- Type: int
resample
Resampling mode stored from the constructor argument. The documented modes are ‘down’ and None.
- Type: Union[str, int, None]
normalization
Normalization method to use for the layers.
- Type: callable
conv1
First convolutional layer.
- Type: nn.Module
normalize1
First normalization layer.
- Type: nn.Module
normalize2
Second normalization layer.
- Type: nn.Module
conv2
Second convolutional layer.
- Type: nn.Module
shortcut
Shortcut connection layer if needed.
- Type: nn.Module
Parameters:
- input_dim (int) – Number of input channels.
- output_dim (int) – Number of output channels.
- num_classes (int) – Number of classes for conditioning.
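- resample (Union[str, int, None]) – Resampling mode: ‘down’ to downsample, or None to keep the spatial size. The signature default is 1, which is neither of these documented modes, so pass this argument explicitly.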
- act (callable) – Activation function (default: nn.ELU()).
- normalization (callable) – Normalization layer (default: ConditionalInstanceNorm2dPlus).
- adjust_padding (bool) – Whether to adjust padding (default: False).
- dilation (Optional[int]) – Dilation rate for convolutions (default: None).
Returns: The output tensor after applying the block.
Return type: Tensor
###### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import ConditionalResidualBlock
>>> block = ConditionalResidualBlock(input_dim=64, output_dim=128,
...                                  num_classes=10, resample='down')
>>> x = torch.randn(16, 64, 32, 32)  # batch of 16, 64 channels, 32x32 feature maps
>>> y = torch.randint(0, 10, (16,))  # one class label per batch item
>>> output = block(x, y)
>>> output.shape  # spatial size is halved by resample='down'
torch.Size([16, 128, 16, 16])
NOTE
The block assumes that the input tensor is in the format (N, C, H, W), where N is the batch size, C is the number of channels, and H and W are the height and width of the input feature maps.
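The layout requirement in the note can be checked before calling the block. The helper below is a hypothetical pre-flight check, not part of ESPnet; it works on plain shape tuples, and the rule that labels lie in `[0, num_classes)` is inferred from the example above rather than stated by the library.

```python
def check_block_inputs(x_shape, y_shape, input_dim, num_classes, labels=None):
    """Validate an (N, C, H, W) input shape and an (N,) label shape.

    Hypothetical helper: returns (N, C, H, W) on success and raises
    ValueError with a descriptive message otherwise.
    """
    if len(x_shape) != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) input, got {tuple(x_shape)}")
    n, c, h, w = x_shape
    if c != input_dim:
        raise ValueError(f"expected {input_dim} channels, got {c}")
    if tuple(y_shape) != (n,):
        raise ValueError(f"expected labels of shape ({n},), got {tuple(y_shape)}")
    # Assumed label range, based on torch.randint(0, num_classes, ...) in the example.
    if labels is not None and any(not 0 <= int(v) < num_classes for v in labels):
        raise ValueError(f"labels must lie in [0, {num_classes})")
    return n, c, h, w
```

With tensors `x` and `y` as in the example, it could be called as `check_block_inputs(tuple(x.shape), tuple(y.shape), 64, 10, y.tolist())`.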
- Raises: Exception – If an invalid resample value is provided.
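To fail early with a clearer message, the resample argument can be validated before the block is constructed. This is a minimal sketch that assumes the valid values are exactly those named in the parameter list (‘down’ and None); `check_resample` and `DOCUMENTED_RESAMPLE_VALUES` are hypothetical names, and the helper raises ValueError rather than the bare Exception documented for the class.

```python
# Values named in the parameter documentation; treated here as the full valid set
# (an assumption, not confirmed against the library source).
DOCUMENTED_RESAMPLE_VALUES = ("down", None)


def check_resample(resample):
    """Return resample unchanged if it is a documented mode, else raise ValueError."""
    if resample not in DOCUMENTED_RESAMPLE_VALUES:
        raise ValueError(
            f"resample must be 'down' or None, got {resample!r}"
        )
    return resample
```

Under this reading, the signature default of 1 would be rejected, which is one reason to pass `resample` explicitly.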
forward(x, y)
Apply the block to a batch of feature maps, conditioned on class labels.
The input x passes through the conditional normalization layers, the activation, and the convolutions, and the result is added to the shortcut path. The labels y are passed to the normalization layers to condition them on the class. With resample=‘down’ the spatial resolution is reduced; with resample=None it is preserved.
Parameters:
- x (torch.Tensor) – Input feature maps of shape (N, C, H, W), where C equals input_dim.
- y (torch.Tensor) – Class labels of shape (N,), one integer label per batch item, used by the conditional normalization layers.
Returns: The output tensor after applying the residual block.
Return type: torch.Tensor
###### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import ConditionalResidualBlock
>>> block = ConditionalResidualBlock(input_dim=64, output_dim=128,
...                                  num_classes=10, resample='down')
>>> x = torch.randn(8, 64, 32, 32)  # batch of 8 inputs
>>> y = torch.randint(0, 10, (8,))  # batch of 8 class labels
>>> output = block(x, y)
>>> print(output.shape)  # expected shape: (8, 128, 16, 16)
- Raises: Exception – If an invalid resample value is provided.
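The shapes in the examples follow a simple rule that can be written down as a small helper. This sketch assumes that ‘down’ halves the height and width (as in the 32x32 to 16x16 example) and that None leaves them unchanged; the floor division for odd sizes is a guess, and `expected_output_shape` is a hypothetical name, not an ESPnet function.

```python
def expected_output_shape(x_shape, output_dim, resample=None):
    """Predict the (N, C, H, W) output shape for a given input shape.

    Assumes 'down' halves H and W and None keeps them, matching the
    documented examples; other values are rejected.
    """
    n, _, h, w = x_shape
    if resample == "down":
        return (n, output_dim, h // 2, w // 2)
    if resample is None:
        return (n, output_dim, h, w)
    raise ValueError(f"resample must be 'down' or None, got {resample!r}")


# Mirrors the documented example: 64 -> 128 channels, 32x32 -> 16x16.
shape = expected_output_shape((8, 64, 32, 32), output_dim=128, resample="down")
```

Comparing this prediction with `tuple(output.shape)` is a cheap sanity check when wiring the block into a larger network.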