espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalInstanceNorm2dPlus
class espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalInstanceNorm2dPlus(num_features, num_classes, bias=True)
Bases: Module
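Conditional Instance Normalization layer with class-dependent scaling and bias.
This class implements a conditional instance normalization layer. Each channel of each input is normalized over its spatial dimensions. The result is then rescaled, and shifted if bias is enabled, using parameters looked up from a per-class embedding indexed by the provided class labels. The "Plus" variant also adds back a class-scaled term derived from the per-channel means, which plain instance normalization would otherwise discard.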
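Written out as equations, the operation for channel k of a sample with label y looks as follows. This is a sketch that assumes the layer follows the InstanceNorm++ formulation used in NCSN-style score models. The exact placement of the scale gamma relative to the mean term, and the value of epsilon, are assumptions. When bias=False, the beta term is dropped.

```latex
\begin{aligned}
\mu_k &= \frac{1}{HW}\sum_{h,w} x_{k,h,w}, \qquad
\sigma_k^2 = \frac{1}{HW}\sum_{h,w}\left(x_{k,h,w}-\mu_k\right)^2 \\
m &= \frac{1}{C}\sum_{k=1}^{C}\mu_k, \qquad
s^2 = \operatorname{Var}_k\left(\mu_k\right) \\
z_k &= \gamma_k[y]\left(\frac{x_k-\mu_k}{\sqrt{\sigma_k^2+\epsilon}}
      + \alpha_k[y]\,\frac{\mu_k-m}{\sqrt{s^2+\epsilon}}\right) + \beta_k[y]
\end{aligned}
```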
num_features
The number of input features (channels).
- Type: int
bias
Indicates whether to include a bias term in the layer.
- Type: bool
instance_norm
Instance normalization module.
- Type: nn.InstanceNorm2d
embed
Embedding layer for learning scaling and bias.
- Type: nn.Embedding
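One plausible way to lay out such a per-class parameter table is to store all parameters in one embedding row and split it with chunk. The sketch below assumes a [gamma | alpha | beta] row layout (width 3 * num_features with bias, 2 * num_features without) and assumes that the scales start near 1 and the bias starts at 0. The helper make_class_embedding is hypothetical and is not part of ESPnet.

```python
import torch
import torch.nn as nn


def make_class_embedding(num_features: int, num_classes: int, bias: bool = True) -> nn.Embedding:
    """Hypothetical helper: build a per-class parameter table.

    Assumed row layout: [gamma | alpha | beta] when bias=True,
    [gamma | alpha] when bias=False, each block num_features wide.
    """
    width = 3 * num_features if bias else 2 * num_features
    embed = nn.Embedding(num_classes, width)
    with torch.no_grad():
        # Assumed init: gamma and alpha start near 1
        embed.weight[:, : 2 * num_features].normal_(1.0, 0.02)
        if bias:
            # Assumed init: beta starts at exactly 0
            embed.weight[:, 2 * num_features :].zero_()
    return embed


embed = make_class_embedding(num_features=64, num_classes=10)
labels = torch.tensor([3, 7])
# Look up one row per label and split it into the three parameter blocks
gamma, alpha, beta = embed(labels).chunk(3, dim=-1)  # each of shape (2, 64)
```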
Parameters:
- num_features (int) – Number of input features (channels).
- num_classes (int) – Number of classes for the conditional embedding.
- bias (bool) – If True, includes a bias term (default: True).
Returns: Normalized output tensor produced by forward(), with the same shape as the input.
Return type: Tensor
####### Examples
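>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalInstanceNorm2dPlus
>>> layer = ConditionalInstanceNorm2dPlus(num_features=64, num_classes=10)
>>> x = torch.randn(8, 64, 32, 32)  # batch of 8 inputs, 64 channels, 32x32
>>> y = torch.randint(0, 10, (8,))  # one class label per batch element
>>> output = layer(x, y)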
NOTE
The input tensor x is expected to have the shape (batch_size, num_features, height, width). The class labels y should have the shape (batch_size,).
forward(x, y)
Applies the Conditional Instance Normalization operation.
This method performs Conditional Instance Normalization on the input tensor x based on the class label y. The operation normalizes the input tensor using the instance normalization technique and applies learnable scaling and shifting parameters based on the class embedding.
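The computation described above can be sketched as a standalone function. This is a minimal sketch that assumes the InstanceNorm++ scheme, a [gamma | alpha | beta] embedding row layout, and eps = 1e-5. The function name instance_norm_plus and its signature are hypothetical and are not part of ESPnet's API.

```python
import torch
import torch.nn.functional as F


def instance_norm_plus(x, y, embed_weight, bias=True, eps=1e-5):
    """Hypothetical functional sketch of a class-conditional InstanceNorm++ forward pass.

    x: (N, C, H, W) input; y: (N,) integer labels;
    embed_weight: (num_classes, 3*C) if bias else (num_classes, 2*C).
    """
    n, c = x.shape[:2]
    # Per-sample, per-channel spatial means, standardized across channels
    means = x.mean(dim=(2, 3))                    # (N, C)
    m = means.mean(dim=-1, keepdim=True)          # (N, 1)
    v = means.var(dim=-1, keepdim=True)           # (N, 1), unbiased variance
    means = (means - m) / torch.sqrt(v + eps)
    # Plain instance normalization with no affine parameters
    h = F.instance_norm(x, eps=eps)
    params = F.embedding(y, embed_weight)         # (N, 3C) or (N, 2C)
    if bias:
        gamma, alpha, beta = params.chunk(3, dim=-1)
    else:
        gamma, alpha = params.chunk(2, dim=-1)
        beta = torch.zeros_like(gamma)
    # Re-inject the standardized channel means, then apply class-specific scale and shift
    h = h + means[..., None, None] * alpha[..., None, None]
    return gamma.view(n, c, 1, 1) * h + beta.view(n, c, 1, 1)


x = torch.randn(8, 64, 32, 32)
y = torch.randint(0, 10, (8,))
out = instance_norm_plus(x, y, torch.randn(10, 3 * 64))
```

With gamma = 1, alpha = 0, and beta = 0 for every class, this sketch reduces to plain instance normalization. That makes it a convenient sanity check.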
- Parameters:
- x (torch.Tensor) – The input tensor of shape (N, C, H, W), where N is the batch size, C the number of channels (must equal num_features), H the height, and W the width.
- y (torch.Tensor) – Integer class-label tensor of shape (N,); each element is the class index for the corresponding sample in x.
- Returns: The output tensor after applying Conditional Instance Normalization, with the same shape as the input tensor x.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalInstanceNorm2dPlus
>>> norm_layer = ConditionalInstanceNorm2dPlus(num_features=64, num_classes=10)
>>> input_tensor = torch.randn(8, 64, 32, 32)  # example input
>>> class_labels = torch.randint(0, 10, (8,))  # random class labels
>>> output_tensor = norm_layer(input_tensor, class_labels)
>>> print(output_tensor.shape)
torch.Size([8, 64, 32, 32])
NOTE
The input tensor x must have 4 dimensions, and the class label y must be a 1-dimensional tensor with class indices ranging from 0 to num_classes - 1.
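To fail early with a clearer message than the embedding lookup gives, inputs can be checked against the NOTE's constraints before the call. The sketch below is hypothetical. The function check_inputs and its choice of ValueError/TypeError are illustrative assumptions, not behavior of the ESPnet class.

```python
import torch


def check_inputs(x: torch.Tensor, y: torch.Tensor, num_features: int, num_classes: int) -> None:
    """Hypothetical pre-check for the shape and label constraints of forward()."""
    if x.dim() != 4:
        raise ValueError(f"x must be 4-D (N, C, H, W), got shape {tuple(x.shape)}")
    if x.shape[1] != num_features:
        raise ValueError(f"expected {num_features} channels, got {x.shape[1]}")
    if y.dim() != 1 or y.shape[0] != x.shape[0]:
        raise ValueError(f"y must have shape ({x.shape[0]},), got {tuple(y.shape)}")
    if y.dtype not in (torch.int32, torch.int64):
        raise TypeError(f"y must be an integer tensor, got {y.dtype}")
    # Class indices must address an existing embedding row
    if y.numel() > 0 and (y.min() < 0 or y.max() >= num_classes):
        raise ValueError(f"class indices must lie in [0, {num_classes - 1}]")
```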
- Raises:
- IndexError – If a class index in y lies outside [0, num_classes - 1], i.e. is out of bounds for the embedding layer (when running on CPU).
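The IndexError comes from the embedding lookup itself. The snippet below reproduces it with a plain stand-in nn.Embedding sized for 10 classes on CPU. The stand-in, and its width of 3 * 64 columns, are illustrative assumptions; the layer class itself is not used here.

```python
import torch
import torch.nn as nn

# Stand-in embedding sized for 10 classes (width chosen for illustration only)
embed = nn.Embedding(10, 3 * 64)
caught = None
try:
    embed(torch.tensor([0, 10]))  # index 10 is out of range for 10 classes
except IndexError as err:
    caught = err
```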