espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalVarianceNorm2d
class espnet2.enh.layers.ncsnpp_utils.normalization.ConditionalVarianceNorm2d(num_features, num_classes, bias=False)
Bases: Module
Conditional Variance Normalization layer.
This layer normalizes each channel of the input tensor by its variance, computed over the spatial dimensions, and then scales the result by learnable per-channel factors that depend on the class of each input. The scale factors are looked up from a class embedding.
num_features
The number of features (channels) in the input.
- Type: int
bias
A flag indicating whether to include a bias term in the normalization. The class-dependent scale is learned by the embedding layer embed.
- Type: bool
embed
An embedding layer to map class indices to scale parameters.
- Type: nn.Embedding
Parameters:
- num_features (int) – Number of features (channels) in the input tensor.
- num_classes (int) – Number of classes for conditional normalization.
- bias (bool , optional) – If True, include a bias term. Defaults to False.
Returns: The normalized output tensor.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalVarianceNorm2d
>>> model = ConditionalVarianceNorm2d(num_features=3, num_classes=10)
>>> x = torch.randn(4, 3, 32, 32) # Batch of 4, 3 channels, 32x32
>>> y = torch.tensor([0, 1, 2, 3]) # Class indices for the batch
>>> output = model(x, y)
>>> output.shape
torch.Size([4, 3, 32, 32])
NOTE
The variance is computed across the spatial dimensions (height, width) of the input tensor. A small constant (1e-5) is added to the variance for numerical stability during the normalization.
- Raises: ValueError – If the input tensor does not have the expected number of dimensions.
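As a rough illustration of the computation described in the note above, the following sketch reproduces the normalization with plain PyTorch operations. It is not the ESPnet implementation: the function name `variance_norm_sketch`, the `eps` argument, and the choice to divide by the square root of the variance plus the constant (without subtracting the mean) are assumptions made for this example.

```python
import torch


def variance_norm_sketch(x, gamma, eps=1e-5):
    """Hypothetical stand-in for the layer's forward pass (assumed form)."""
    # Variance over the spatial dimensions (H, W), kept as (N, C, 1, 1).
    var = torch.var(x, dim=(2, 3), keepdim=True)
    # Assumption: divide by sqrt(var + eps); the mean is not subtracted.
    h = x / torch.sqrt(var + eps)
    # gamma has shape (N, C): one scale per sample and channel.
    return gamma.view(x.shape[0], x.shape[1], 1, 1) * h


x = torch.randn(4, 3, 32, 32)
gamma = torch.ones(4, 3)  # unit scales, to isolate the normalization step
out = variance_norm_sketch(x, gamma)
out_var = torch.var(out, dim=(2, 3))  # per-sample, per-channel variance
```

With unit scales, `out_var` should be close to 1 for every sample and channel, and `out` keeps the shape of `x`.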
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x, y)
Apply the Conditional Variance Normalization to the input tensor.
This method computes the variance of the input tensor x over the spatial dimensions (height and width), normalizes x with that variance, and multiplies the result by per-channel scale factors looked up from the embedding using the class indices y.
- Parameters:
- x (torch.Tensor) – The input tensor of shape (N, C, H, W), where N is the batch size, C is the number of channels, H is the height, and W is the width.
- y (torch.Tensor) – The tensor of class indices of shape (N,) that determines the scaling factors for each input in the batch.
- Returns: The output tensor of shape (N, C, H, W) after applying the Conditional Variance Normalization.
- Return type: torch.Tensor
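The way class indices of shape (N,) can turn into per-channel scaling factors for a tensor of shape (N, C, H, W) may be easier to see in isolation. The sketch below uses a standalone `nn.Embedding` and broadcasting; the variable names and sizes are illustrative, and the reshape to (N, C, 1, 1) is an assumption about how the lookup is applied, not a quote from the ESPnet source.

```python
import torch
from torch import nn

num_classes, num_features = 10, 64  # illustrative sizes

# One learnable scale vector of length C per class.
embed = nn.Embedding(num_classes, num_features)

y = torch.tensor([0, 3, 3, 7])               # class indices, shape (N,)
gamma = embed(y)                             # shape (N, C)
gamma = gamma.view(-1, num_features, 1, 1)   # shape (N, C, 1, 1)

h = torch.randn(4, num_features, 32, 32)     # stands in for the normalized input
out = gamma * h                              # broadcasts over H and W
```

Samples that share a class index (the second and third here) receive identical scale vectors, and each scale is applied uniformly across the spatial dimensions.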
####### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalVarianceNorm2d
>>> model = ConditionalVarianceNorm2d(num_features=64, num_classes=10)
>>> x = torch.randn(8, 64, 32, 32) # A batch of 8 images
>>> y = torch.randint(0, 10, (8,)) # Random class indices for batch
>>> output = model(x, y)
>>> print(output.shape)
torch.Size([8, 64, 32, 32])
NOTE
A small constant (1e-5) is added to the computed variance to avoid division by zero. The output has the same shape as the input tensor.
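The effect of the small constant can be seen on an input whose channels are spatially constant, so that the variance is exactly zero. This is a minimal sketch that assumes the constant is added to the variance before taking the square root; the variable names are illustrative only.

```python
import torch

eps = 1e-5

# Every spatial position holds the same value, so the variance is zero.
x = torch.full((1, 2, 4, 4), 3.0)
var = torch.var(x, dim=(2, 3), keepdim=True)

without_eps = x / torch.sqrt(var)       # division by zero gives inf
with_eps = x / torch.sqrt(var + eps)    # stays finite, although large
```

The stabilized result is finite but large in magnitude, so inputs with near-constant channels can still produce big activations.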