espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalBatchNorm2d
class espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalBatchNorm2d(num_features, num_classes, bias=True)
Bases: Module
Conditional Batch Normalization layer.
This layer applies batch normalization whose scale and shift parameters are conditioned on the provided class labels, so a different affine transform is learned for each class. This is useful in tasks where the input data falls into distinct categories.
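As a rough illustration of the idea (hypothetical names, not the verbatim ESPnet source), the layer can be built from a plain BatchNorm2d with affine=False plus an nn.Embedding that stores one set of scale/shift parameters per class:

import torch.nn as nn

class ConditionalBatchNorm2dSketch(nn.Module):
    # Minimal sketch: per-class scale/shift on top of affine-free batch norm.
    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        # Normalization without its own learnable affine parameters;
        # scale and shift come from the class-conditional embedding instead.
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # One embedding row per class: [gamma | beta] when bias=True,
        # otherwise just gamma (the 2 * num_features layout is an assumption).
        embed_dim = num_features * 2 if bias else num_features
        self.embed = nn.Embedding(num_classes, embed_dim)

The corresponding forward computation is sketched under forward() below.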
num_features
Number of features (channels) in the input tensor.
- Type: int
bias
Whether to include bias parameters in the normalization.
- Type: bool
bn
Batch normalization layer without affine parameters.
- Type: nn.BatchNorm2d
embed
Embedding layer to learn scaling and bias parameters based on class labels.
- Type: nn.Embedding
Parameters:
- num_features (int) – Number of features (channels) in the input tensor.
- num_classes (int) – Number of classes for the conditional embedding.
- bias (bool) – If True, includes bias in the normalization.
Returns: The normalized output tensor.
Return type: Tensor
####### Examples
>>> import torch
>>> num_features = 16
>>> num_classes = 10
>>> batch_norm = ConditionalBatchNorm2d(num_features, num_classes)
>>> x = torch.randn(8, num_features, 32, 32) # Batch of 8 images
>>> y = torch.randint(0, num_classes, (8,)) # Random class labels
>>> output = batch_norm(x, y)
NOTE
The scaling parameters (gamma) are initialized from a normal distribution with mean 1 and standard deviation 0.02, while the bias parameters (beta) are initialized to zero.
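Continuing the sketch under the class description (the [gamma | beta] column layout of the embedding table is an assumption, not taken from the ESPnet source), this initialization could be written as follows inside __init__:

if bias:
    # Scale (gamma): normal with mean 1 and std 0.02; shift (beta): zeros.
    self.embed.weight.data[:, :num_features].normal_(1, 0.02)
    self.embed.weight.data[:, num_features:].zero_()
else:
    self.embed.weight.data.normal_(1, 0.02)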
- Raises: ValueError – If the input tensor dimensions do not match the expected shape for batch normalization.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x, y)
Applies the conditional batch normalization to the input tensor.
This method normalizes the input tensor x using batch normalization and applies a conditional scaling and shifting based on the class embeddings indexed by y. If bias is enabled, the method retrieves both scale (gamma) and bias (beta) parameters from the embedding layer and applies them to the normalized output.
- Parameters:
- x (torch.Tensor) – The input tensor of shape (N, C, H, W), where N is the batch size, C is the number of channels, H is the height, and W is the width.
- y (torch.Tensor) – The tensor containing class indices of shape (N,) for which the normalization parameters will be conditioned.
- Returns: The output tensor after applying conditional batch normalization, having the same shape as the input tensor x.
- Return type: torch.Tensor
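For orientation, a minimal sketch of this computation, continuing the hypothetical module structure outlined under the class description (the chunk/view layout is an assumption, not the verbatim ESPnet code):

def forward(self, x, y):
    out = self.bn(x)  # normalize without affine parameters
    if self.bias:
        # Look up per-class gamma/beta and broadcast over H and W.
        gamma, beta = self.embed(y).chunk(2, dim=1)
        out = (gamma.view(-1, self.num_features, 1, 1) * out
               + beta.view(-1, self.num_features, 1, 1))
    else:
        gamma = self.embed(y)
        out = gamma.view(-1, self.num_features, 1, 1) * out
    return out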
####### Examples
>>> import torch
>>> c_bn = ConditionalBatchNorm2d(num_features=64, num_classes=10)
>>> input_tensor = torch.randn(32, 64, 8, 8) # Batch of 32 images
>>> class_indices = torch.randint(0, 10, (32,)) # Random class indices
>>> output_tensor = c_bn(input_tensor, class_indices)
>>> output_tensor.shape
torch.Size([32, 64, 8, 8])
NOTE
The y tensor must contain valid class indices within the range [0, num_classes - 1].
- Raises: RuntimeError – If the input tensor x does not have the expected shape, or if the class indices in y are out of bounds.