espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalInstanceNorm2d
class espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalInstanceNorm2d(num_features, num_classes, bias=True)
Bases: Module
Applies Conditional Instance Normalization over a 4D input.
ConditionalInstanceNorm2d first normalizes each channel of each sample with that sample's own spatial mean and variance (instance normalization). It then applies a class-conditional affine transform. An embedding layer indexed by the class label supplies a per-channel scale (gamma) and, when bias is True, a per-channel shift (beta).
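Written out per element, the computation can be summarized as follows. This assumes the standard nn.InstanceNorm2d statistics: the mean and biased variance over the H and W axes, plus a small constant ε for numerical stability. The γ and β values are read from the embedding row selected by the label y_n of sample n. The β term is absent when bias is False.

```latex
\begin{aligned}
\mu_{n,c} &= \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w}, \qquad
\sigma^{2}_{n,c} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{n,c,h,w} - \mu_{n,c} \right)^{2}, \\
\mathrm{out}_{n,c,h,w} &= \gamma_{y_n,c} \, \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^{2}_{n,c} + \epsilon}} + \beta_{y_n,c}.
\end{aligned}
```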
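Attributes:
num_features
Number of features (channels) in the input.
- Type: int
bias
If True, a class-conditional shift (beta) is added after scaling.
- Type: bool
instance_norm
Instance normalization layer applied before the class-conditional affine transform.
- Type: nn.InstanceNorm2d
embed
Embedding layer that maps each class label to per-channel scale (and, if bias is True, shift) parameters.
- Type: nn.Embedding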
Parameters:
- num_features (int) – Number of features (channels) in the input.
- num_classes (int) – Number of classes for the conditional embedding.
- bias (bool, optional) – If True, also learns a class-conditional shift (beta) in addition to the scale (gamma). Defaults to True.
Returns: The output tensor after applying conditional instance normalization.
Return type: Tensor
Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalInstanceNorm2d
>>> norm = ConditionalInstanceNorm2d(num_features=3, num_classes=10)
>>> x = torch.randn(5, 3, 32, 32) # Batch of 5, 3 channels, 32x32
>>> y = torch.randint(0, 10, (5,)) # Random class labels for batch
>>> output = norm(x, y)
>>> output.shape
torch.Size([5, 3, 32, 32])
NOTE
The input tensor should have shape (N, C, H, W), where N is the batch size, C is the number of channels (which must equal num_features), H is the height, and W is the width.
- Raises: RuntimeError – If num_classes is negative, because the underlying nn.Embedding cannot allocate a weight with a negative dimension.
forward(x, y)
Applies Conditional Instance Normalization to the input tensor.
This method performs instance normalization on the input tensor x and then applies a class-conditional affine transform selected by the labels y. Every channel is always scaled by a learned gamma. If bias is True, it is also shifted by a learned beta.
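To make the scale-only and scale-plus-shift cases concrete, here is a framework-free sketch of the same computation in plain Python. It is not the ESPnet implementation. Nested lists stand in for tensors, and the function name conditional_instance_norm and the Tensor4D alias are hypothetical. The sketch also relies on two assumptions. First, eps = 1e-5 is assumed to mirror the nn.InstanceNorm2d default. Second, each embedding row is assumed to store the C gamma values followed by the C beta values, or only the gamma values when bias is False.

```python
import math
from typing import List, Sequence

# Hypothetical alias: a 4D "tensor" as nested lists indexed [n][c][h][w].
Tensor4D = List[List[List[List[float]]]]


def conditional_instance_norm(
    x: Tensor4D,
    labels: Sequence[int],
    embed: Sequence[Sequence[float]],
    num_features: int,
    bias: bool = True,
    eps: float = 1e-5,  # assumption: mirrors the nn.InstanceNorm2d default
) -> Tensor4D:
    """Sketch of conditional instance normalization on nested lists.

    Assumed layout: embed[k] holds C gammas followed by C betas for class k
    when bias is True, and only the C gammas when bias is False.
    """
    out: Tensor4D = []
    for n, sample in enumerate(x):
        row = embed[labels[n]]  # look up this sample's class parameters
        gamma = row[:num_features]
        # Without bias there is no shift, which is the same as adding zeros.
        beta = row[num_features:2 * num_features] if bias else [0.0] * num_features
        out_sample = []
        for c, channel in enumerate(sample):
            # Instance statistics: mean and biased variance over H and W only.
            values = [v for line in channel for v in line]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            inv_std = 1.0 / math.sqrt(var + eps)
            # Normalize, then apply the class-specific scale and shift.
            out_sample.append(
                [[gamma[c] * (v - mean) * inv_std + beta[c] for v in line] for line in channel]
            )
        out.append(out_sample)
    return out
```

For example, take one single-channel 2x2 sample [[1, 2], [3, 4]] with class label 1 and embedding rows [[1.0, 0.0], [2.0, 0.5]]. The normalized values are scaled by 2.0 and shifted by 0.5, so the output map has mean 0.5. With bias=False the shift disappears and the output mean is 0.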
- Parameters:
- x (torch.Tensor) – Input tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height, and W is the width.
- y (torch.Tensor) – Class labels of shape (N,) with an integer dtype (e.g. torch.long). Each entry is the class index, in [0, num_classes), of the corresponding sample in x.
- Returns: The output tensor after applying conditional instance normalization. The output has the same shape as the input tensor.
- Return type: torch.Tensor
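The shape and label requirements of forward can be checked before the call. The helper below is hypothetical and not part of ESPnet. It works on plain Python values so that it stays framework-agnostic. Raising ValueError is a choice made for this sketch; without such a check, the errors would come from PyTorch later.

```python
from typing import Sequence


def check_conditional_norm_inputs(
    x_shape: Sequence[int],
    labels: Sequence[int],
    num_features: int,
    num_classes: int,
) -> None:
    """Hypothetical pre-flight check for the inputs of forward(x, y)."""
    if len(x_shape) != 4:
        raise ValueError(f"x must be 4D (N, C, H, W), got shape {tuple(x_shape)}")
    n, c, _, _ = x_shape
    if c != num_features:
        raise ValueError(f"x has {c} channels, expected num_features={num_features}")
    if len(labels) != n:
        raise ValueError(f"expected {n} labels (one per sample), got {len(labels)}")
    for i, label in enumerate(labels):
        # bool is a subclass of int, so exclude it explicitly.
        if not isinstance(label, int) or isinstance(label, bool):
            raise ValueError(f"label {i} must be an int class index, got {label!r}")
        # Embedding lookups require indices in [0, num_classes).
        if not 0 <= label < num_classes:
            raise ValueError(f"label {i} = {label} is outside [0, {num_classes})")
```

With PyTorch tensors, a call such as check_conditional_norm_inputs(tuple(x.shape), y.tolist(), 64, 10) would validate the inputs before model(x, y) is invoked.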
Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalInstanceNorm2d
>>> model = ConditionalInstanceNorm2d(num_features=64, num_classes=10)
>>> x = torch.randn(8, 64, 32, 32)  # Batch of 8 images
>>> y = torch.randint(0, 10, (8,))  # Random class labels
>>> output = model(x, y)
>>> print(output.shape)  # Output shape matches input shape
torch.Size([8, 64, 32, 32])