espnet2.enh.layers.ncsnpp_utils.layers.CondMSFBlock
class espnet2.enh.layers.ncsnpp_utils.layers.CondMSFBlock(in_planes, features, num_classes, normalizer)
Bases: Module
Conditional Multi-Scale Fusion Block.
This block is designed for multi-scale feature fusion in a conditional manner, typically used in tasks such as image generation and enhancement. It combines multiple input feature maps using convolutions and normalizes them based on class information.
convs
List of convolutional layers for each input plane.
- Type: nn.ModuleList
norms
List of normalization layers for each input plane.
- Type: nn.ModuleList
features
The number of output features for the block.
- Type: int
normalizer
A callable normalization function.
- Type: callable
Parameters:
- in_planes (list or tuple) – A list or tuple containing the number of input channels for each input feature map.
- features (int) – The number of output features after fusion.
- num_classes (int) – The number of classes for conditional normalization.
- normalizer (callable) – A normalization layer or function to be applied after convolutions.
Forward:
- xs (list of tensors) – List of input feature maps, each of shape (B, C, H, W), where B is the batch size, C is the number of channels of that map (matching the corresponding entry of in_planes), H is the height, and W is the width.
- y (tensor) – A tensor containing class information, used for conditional normalization.
- shape (tuple) – The target shape (H, W) for the output feature map.
- Returns: A tensor of shape (B, features, height, width) containing the fused feature map.
- Return type: tensor
####### Examples
>>> cond_msf = CondMSFBlock([32, 64], features=128, num_classes=10,
... normalizer=ConditionalInstanceNorm2dPlus)
>>> input_features = [torch.randn(8, 32, 64, 64), torch.randn(8, 64, 64, 64)]
>>> class_info = torch.randint(0, 10, (8,))
>>> output = cond_msf(input_features, class_info, (32, 32))
>>> output.shape
torch.Size([8, 128, 32, 32])
NOTE
The normalization layers expect the input tensors to be of shape (B, C, H, W). Ensure that the input tensors are appropriately shaped before passing them to the forward method.
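To make the normalizer argument more concrete, here is a minimal sketch of a class-conditional instance normalization layer. It is not the ESPnet `ConditionalInstanceNorm2dPlus`. The class name `ToyConditionalInstanceNorm2d` is hypothetical, and the constructor signature `(num_features, num_classes, bias=True)` is an assumption about how the block calls its normalizer, which this page does not confirm.

```python
import torch
import torch.nn as nn


class ToyConditionalInstanceNorm2d(nn.Module):
    """Hypothetical class-conditional instance norm (illustration only)."""

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        # Parameter-free instance norm over (H, W) for each sample and channel.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # One embedding row per class: a scale (and optionally a shift) per channel.
        width = 2 * num_features if bias else num_features
        self.embed = nn.Embedding(num_classes, width)
        # Start out as a plain instance norm: scale = 1, shift = 0.
        with torch.no_grad():
            self.embed.weight[:, :num_features].fill_(1.0)
            if bias:
                self.embed.weight[:, num_features:].zero_()

    def forward(self, x, y):
        # x: (B, C, H, W) feature map, y: (B,) integer class indices.
        h = self.norm(x)
        params = self.embed(y)  # (B, C) or (B, 2C)
        gamma = params[:, : self.num_features, None, None]
        out = gamma * h
        if self.bias:
            beta = params[:, self.num_features :, None, None]
            out = out + beta
        return out


x = torch.randn(8, 32, 16, 16)
y = torch.randint(0, 10, (8,))
layer = ToyConditionalInstanceNorm2d(32, num_classes=10)
out = layer(x, y)
print(out.shape)  # torch.Size([8, 32, 16, 16])
```

The layer indexes its per-class parameters with `y`, so in this sketch `y` must hold integer class indices of shape (B,), and `x` must be four-dimensional for the broadcast over (H, W) to work.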
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(xs, y, shape)
Performs the forward pass of the Conditional Multi-Scale Fusion Block.
This method takes a list of input feature maps xs and the conditioning tensor y, applies the per-input conditional normalization and convolution layers (norms and convs), and fuses the results into a single feature map whose spatial size is given by shape.
- Parameters:
- xs (list of torch.Tensor) – The input feature maps, each of shape (B, C, H, W), where B is the batch size, C is the number of channels of that map, and H and W are its height and width.
- y (torch.Tensor) – The conditioning tensor carrying class information for conditional normalization. The examples on this page pass class indices of shape (B,).
- shape (tuple) – The target spatial shape (H, W) of the output feature map.
- Returns: The fused feature map, of shape (B, features, H, W), where (H, W) is the target shape.
- Return type: torch.Tensor
####### Examples
>>> block = CondMSFBlock([32, 64], features=128, num_classes=10,
... normalizer=ConditionalInstanceNorm2dPlus)
>>> xs = [torch.randn(8, 32, 64, 64), torch.randn(8, 64, 64, 64)] # Two feature maps, batch of 8
>>> y = torch.randint(0, 10, (8,)) # Batch of 8, class indices
>>> output = block(xs, y, (32, 32))
>>> print(output.shape)
torch.Size([8, 128, 32, 32])
NOTE
This block uses conditional normalization, which allows the model to adapt its parameters based on the class information provided by the conditioning tensor y.
- Raises: RuntimeError – If the shapes of the input feature maps xs and the conditioning tensor y do not match the expected dimensions.
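The forward pass can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the ESPnet implementation. It assumes each input is conditionally normalized, passed through a 3x3 convolution, bilinearly resized to the target shape, and summed with the others. The names `ToyMSFBlock` and `ClassScale` are hypothetical, and `ClassScale` is only a stand-in for a real conditional normalizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassScale(nn.Module):
    """Stand-in conditional normalizer: per-class, per-channel scaling only."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.embed = nn.Embedding(num_classes, num_features)
        nn.init.ones_(self.embed.weight)

    def forward(self, x, y):
        # y: (B,) class indices -> (B, C, 1, 1) scale broadcast over (H, W).
        return x * self.embed(y)[:, :, None, None]


class ToyMSFBlock(nn.Module):
    """Hypothetical multi-scale fusion block (assumed behaviour, see lead-in)."""

    def __init__(self, in_planes, features, num_classes, normalizer):
        super().__init__()
        self.features = features
        # One normalizer and one 3x3 convolution per input feature map.
        self.norms = nn.ModuleList([normalizer(c, num_classes) for c in in_planes])
        self.convs = nn.ModuleList(
            [nn.Conv2d(c, features, kernel_size=3, padding=1) for c in in_planes]
        )

    def forward(self, xs, y, shape):
        fused = torch.zeros(
            xs[0].shape[0], self.features, *shape,
            device=xs[0].device, dtype=xs[0].dtype,
        )
        for x, norm, conv in zip(xs, self.norms, self.convs):
            h = conv(norm(x, y))  # (B, features, H_i, W_i)
            # Resize every branch to the common target resolution before summing.
            h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True)
            fused = fused + h
        return fused


block = ToyMSFBlock([32, 64], features=128, num_classes=10, normalizer=ClassScale)
xs = [torch.randn(8, 32, 64, 64), torch.randn(8, 64, 32, 32)]
y = torch.randint(0, 10, (8,))
out = block(xs, y, (64, 64))
print(out.shape)  # torch.Size([8, 128, 64, 64])
```

Resizing each branch to `shape` before accumulation is what lets inputs of different spatial resolution be fused. In this sketch the second input is upsampled from 32x32 to 64x64.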