espnet2.enh.layers.ncsnpp_utils.layers.CondCRPBlock
class espnet2.enh.layers.ncsnpp_utils.layers.CondCRPBlock(features, n_stages, num_classes, normalizer, act=ReLU())
Bases: Module
Conditional Residual Processing Block.
This class implements a conditional residual processing block. Each stage applies conditional normalization, average pooling, and a convolutional layer, with the normalization conditioned on a class input so that features can adapt per class.
convs
List of convolutional layers.
- Type: nn.ModuleList
norms
List of normalization layers.
- Type: nn.ModuleList
normalizer
Function used for normalization.
- Type: callable
n_stages
Number of stages in the block.
- Type: int
pool
Average pooling layer.
- Type: nn.AvgPool2d
act
Activation function applied in the block.
- Type: callable
Parameters:
- features (int) – Number of input and output features for the convolutional layers.
- n_stages (int) – Number of stages (layers) to be applied in the block.
- num_classes (int) – Number of classes for conditional normalization.
- normalizer (callable) – Callable used to build the conditional normalization layers (for example, ConditionalInstanceNorm2dPlus, as in the example below).
- act (callable , optional) – Activation function (default is nn.ReLU()).
####### Examples
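>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import CondCRPBlock
>>> # Module path below is assumed; adjust if the normalizer lives elsewhere.
>>> from espnet2.enh.layers.ncsnpp_utils.normalization import ConditionalInstanceNorm2dPlus
>>> features = 64
>>> n_stages = 3
>>> num_classes = 10
>>> normalizer = ConditionalInstanceNorm2dPlus
>>> block = CondCRPBlock(features, n_stages, num_classes, normalizer)
>>> x = torch.randn(1, 64, 32, 32)  # Example input: (batch, features, height, width)
>>> y = torch.randint(0, num_classes, (1,))  # Example class labels: (batch,)
>>> output = block(x, y)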
Forward: The forward method takes an input tensor x and a conditional input y, applies the activation function to x, passes the result through the normalization, pooling, and convolution layers, and returns the output tensor.
- Raises: ValueError – If the input tensor dimensions do not match the expected dimensions.
NOTE
Ensure that the input tensor x and the conditional tensor y are appropriately shaped to match the model’s expectations.
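The shape requirements in the note can be checked before calling the block. The following is a minimal sketch, not part of ESPnet: the helper name `check_cond_inputs` is hypothetical, and it assumes that y holds integer class labels of shape (batch_size,), as in the examples on this page.

```python
import torch


def check_cond_inputs(x, y, num_classes):
    """Hypothetical pre-flight check for (x, y); not an ESPnet function."""
    # x is expected to be (batch_size, num_features, height, width).
    if x.dim() != 4:
        raise ValueError(f"x must be 4-D, got shape {tuple(x.shape)}")
    # Assumption: y is a 1-D tensor with one class label per batch item.
    if y.dim() != 1 or y.shape[0] != x.shape[0]:
        raise ValueError(
            f"y must have shape ({x.shape[0]},), got {tuple(y.shape)}"
        )
    # Class labels are used as indices, so they must be integers in range.
    if torch.is_floating_point(y):
        raise ValueError("y must contain integer class labels")
    if y.numel() > 0 and (y.min().item() < 0 or y.max().item() >= num_classes):
        raise ValueError(f"labels in y must lie in [0, {num_classes})")


x = torch.randn(8, 64, 32, 32)
y = torch.randint(0, 10, (8,))
check_cond_inputs(x, y, num_classes=10)  # passes silently
```

Running such a check first tends to give a clearer error message than a shape mismatch raised deep inside a normalization layer.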
forward(x, y)
Forward pass through the Conditional CRP Block.
This method applies a series of conditional normalization and convolution operations to the input tensor x, conditioned on the input tensor y. The operations are repeated over a specified number of stages, with intermediate pooling applied at each stage.
- Parameters:
- x (torch.Tensor) – The input tensor of shape (batch_size, num_features, height, width).
- y (torch.Tensor) – The conditional input tensor of integer class labels in [0, num_classes), of shape (batch_size,), as in the example below.
- Returns: The output tensor of the same shape as x, which incorporates the results of the convolutional operations.
- Return type: torch.Tensor
####### Examples
>>> cond_crp_block = CondCRPBlock(features=64, n_stages=3,
... num_classes=10,
... normalizer=ConditionalInstanceNorm2dPlus)
>>> x = torch.randn(8, 64, 32, 32) # Batch of 8, 64 features
>>> y = torch.randint(0, 10, (8,)) # Random class labels
>>> output = cond_crp_block(x, y)
>>> output.shape
torch.Size([8, 64, 32, 32])
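To make the per-stage order of operations concrete, here is an illustrative sketch of a block with the same interface. It is not the ESPnet source: `ToyConditionalNorm` and `SketchCondCRPBlock` are hypothetical names, and the pooling kernel size, the 3x3 convolution, and the residual sum at the end of each stage are assumptions chosen so that the output keeps the shape of x.

```python
import torch
import torch.nn as nn


class ToyConditionalNorm(nn.Module):
    """Hypothetical stand-in for a conditional normalizer (not the ESPnet class)."""

    def __init__(self, features, num_classes):
        super().__init__()
        self.norm = nn.InstanceNorm2d(features, affine=False)
        # One learned per-channel scale and shift for each class label.
        self.scale = nn.Embedding(num_classes, features)
        self.shift = nn.Embedding(num_classes, features)
        nn.init.ones_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)

    def forward(self, x, y):
        scale = self.scale(y)[:, :, None, None]  # (batch, features, 1, 1)
        shift = self.shift(y)[:, :, None, None]
        return self.norm(x) * scale + shift


class SketchCondCRPBlock(nn.Module):
    """Illustrative version of the described stages; details are assumed."""

    def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()
        self.norms = nn.ModuleList(
            [normalizer(features, num_classes) for _ in range(n_stages)]
        )
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False)
                for _ in range(n_stages)
            ]
        )
        self.n_stages = n_stages
        # Assumed settings: stride 1 plus padding keeps height and width fixed.
        self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x, y):
        x = self.act(x)
        path = x
        for i in range(self.n_stages):
            path = self.norms[i](path, y)  # normalize, conditioned on y
            path = self.pool(path)         # average pooling, shape unchanged
            path = self.convs[i](path)     # convolution, shape unchanged
            x = path + x                   # assumed residual accumulation
        return x


block = SketchCondCRPBlock(16, 2, 10, ToyConditionalNorm)
x = torch.randn(4, 16, 8, 8)
y = torch.randint(0, 10, (4,))
out = block(x, y)
print(out.shape)  # same shape as x
```

Because every operation in the loop preserves the spatial size and channel count, the number of stages can be changed without affecting the output shape.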