espnet2.enh.layers.ncsnpp_utils.layers.CondRefineBlock
class espnet2.enh.layers.ncsnpp_utils.layers.CondRefineBlock(in_planes, features, num_classes, normalizer, act=ReLU(), start=False, end=False)
Bases: Module
Conditional Refine Block for processing input features with class conditioning.
This block applies a series of conditional residual units and pooling layers to refine the input features while taking class information into account. It is particularly useful when the output must be conditioned on class labels, as the feature representation is refined according to the class information (a sketch of the internal wiring appears after the parameter list).
n_blocks
Number of input blocks, i.e. len(in_planes).
- Type: int
adapt_convs
A list of conditional residual blocks for adapting input features.
- Type: nn.ModuleList
output_convs
A conditional residual block for producing the output features.
- Type: CondRCUBlock
msf
A multi-scale feature block to combine features from multiple input sources.
- Type: CondMSFBlock
crp
A conditional CRP block for pooling and refining features.
- Type: CondCRPBlock
Parameters:
- in_planes (tuple or list) – Number of input channels for each block.
- features (int) – Number of output channels for the final output.
- num_classes (int) – Number of classes for conditional normalization.
- normalizer (callable) – Normalization function to be applied.
- act (nn.Module, optional) – Activation function to use. Defaults to nn.ReLU().
- start (bool, optional) – Flag indicating if this is the start block. Defaults to False.
- end (bool, optional) – Flag indicating if this is the end block. Defaults to False.
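The attributes above are wired together roughly as follows. This is a condensed sketch assuming the sibling blocks defined in this module (CondRCUBlock, CondMSFBlock, CondCRPBlock); the per-stage unit counts are illustrative, not the verbatim ESPnet source.

import torch.nn as nn

from espnet2.enh.layers.ncsnpp_utils.layers import (
    CondCRPBlock,
    CondMSFBlock,
    CondRCUBlock,
)


class CondRefineBlockSketch(nn.Module):
    """Condensed sketch of CondRefineBlock's constructor wiring."""

    def __init__(self, in_planes, features, num_classes, normalizer,
                 act=nn.ReLU(), start=False, end=False):
        super().__init__()
        assert isinstance(in_planes, (tuple, list))
        self.n_blocks = len(in_planes)
        # One adapting residual unit per input feature map.
        self.adapt_convs = nn.ModuleList(
            CondRCUBlock(planes, 2, 2, num_classes, normalizer, act)
            for planes in in_planes
        )
        # Output residual unit; the end block applies extra refinement stages.
        self.output_convs = CondRCUBlock(
            features, 3 if end else 1, 2, num_classes, normalizer, act
        )
        # The start block has a single input, so multi-scale fusion is skipped.
        if not start:
            self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
        self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)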
####### Examples
>>> cond_refine_block = CondRefineBlock(
... in_planes=(32, 64),
... features=128,
... num_classes=10,
... normalizer=ConditionalInstanceNorm2dPlus
... )
>>> x = (torch.randn(1, 32, 64, 64), torch.randn(1, 64, 32, 32))
>>> y = torch.randint(0, 10, (1,))
>>> output = cond_refine_block(x, y, output_shape=(32, 32))
NOTE
The input xs should be a tuple or list of feature maps to be processed, and y should be a tensor containing the class labels.
- Raises: AssertionError – If xs is not a tuple or list.
forward(xs, y, output_shape)
Refine the input feature maps with class conditioning.

Each feature map in xs is first adapted by its own conditional residual unit. Unless this is the start block (single input), the adapted features are then combined by the multi-scale feature block at the requested output shape, pooled and refined by the conditional CRP block, and finally passed through the output residual unit.

Parameters:
- xs (tuple or list of Tensor) – Input feature maps, one per entry in in_planes.
- y (Tensor) – Class labels used for conditional normalization.
- output_shape (tuple) – Target spatial size (height, width) of the output.

Returns: Refined feature tensor after processing.
Return type: Tensor
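The data flow through forward is, in outline, as follows. This is a minimal sketch consistent with the attributes listed above, not the verbatim ESPnet source.

def forward(self, xs, y, output_shape):
    assert isinstance(xs, (tuple, list))
    # Adapt each input feature map with its own conditional residual unit.
    hs = [adapt(x, y) for adapt, x in zip(self.adapt_convs, xs)]
    # Fuse the adapted maps at the requested spatial size; a start block
    # has a single input and skips the multi-scale fusion stage.
    h = self.msf(hs, y, output_shape) if self.n_blocks > 1 else hs[0]
    # Conditional residual pooling, then the output residual unit.
    h = self.crp(h, y)
    return self.output_convs(h, y)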
####### Examples
>>> cond_refine_block = CondRefineBlock(
... in_planes=(32, 64),
... features=128,
... num_classes=10,
... normalizer=ConditionalInstanceNorm2dPlus
... )
>>> x = (torch.randn(1, 32, 64, 64), torch.randn(1, 64, 64, 64))
>>> y = torch.randint(0, 10, (1,))
>>> output_shape = (32, 32)
>>> output = cond_refine_block(x, y, output_shape)
>>> output.shape
torch.Size([1, 128, 32, 32])
NOTE
Ensure that the input tensors are compatible with the expected dimensions: each entry of xs must have the channel count given by the corresponding entry of in_planes, and the spatial sizes must be valid for resizing to output_shape.
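As a further hedged illustration (the values below are hypothetical), a start block receives a single feature map and skips the multi-scale fusion stage, so its input is expected to already carry features channels:

>>> start_block = CondRefineBlock(
...     in_planes=(128,),
...     features=128,
...     num_classes=10,
...     normalizer=ConditionalInstanceNorm2dPlus,
...     start=True,
... )
>>> x = (torch.randn(1, 128, 16, 16),)
>>> y = torch.randint(0, 10, (1,))
>>> output = start_block(x, y, output_shape=(16, 16))
>>> output.shape
torch.Size([1, 128, 16, 16])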