espnet2.enh.layers.ncsnpp_utils.layers.MSFBlock
class espnet2.enh.layers.ncsnpp_utils.layers.MSFBlock(in_planes, features)
Bases: Module
Multi-Scale Fusion Block for feature aggregation across multiple scales.
This block takes multiple input feature maps, which may differ in channel count and spatial size, and aggregates them into a single output feature map with the specified number of output channels. It applies a 3x3 convolution to each input feature map, interpolates each result to a common output shape, and sums the results.
convs
A list of convolutional layers for each input feature map.
Type: nn.ModuleList
features
The number of output features for the block.
Type: int
Parameters:
- in_planes (list or tuple) – A list or tuple of integers representing the number of input channels for each feature map.
- features (int) – The number of output channels after the convolution.
Returns: A tensor containing the aggregated output feature map.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import MSFBlock
>>> msf_block = MSFBlock([64, 128, 256], features=128)
>>> input_features = [torch.randn(1, 64, 32, 32),
...                   torch.randn(1, 128, 16, 16),
...                   torch.randn(1, 256, 8, 8)]
>>> output = msf_block(input_features, shape=(32, 32))
>>> output.shape
torch.Size([1, 128, 32, 32])
NOTE
The input feature maps should be provided as a list or tuple, and all feature maps should be of the same batch size.
Raises: AssertionError – If in_planes is not a list or tuple.
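The note and the assertion above describe what the block expects from its inputs. As a rough illustration, a caller could check a batch of feature maps before passing them in. The helper name check_msf_inputs is hypothetical and is not part of ESPnet. The sketch assumes 4-D tensors laid out as (batch, channels, height, width), as in the examples, and one feature map per entry in in_planes.

```python
import torch


def check_msf_inputs(xs, in_planes):
    """Hypothetical pre-flight check for MSFBlock inputs (not part of ESPnet)."""
    if not isinstance(xs, (list, tuple)):
        raise TypeError("xs must be a list or tuple of tensors")
    # Assumption: one feature map per entry in in_planes.
    if len(xs) != len(in_planes):
        raise ValueError(f"expected {len(in_planes)} feature maps, got {len(xs)}")
    batch = xs[0].shape[0]
    for i, (x, channels) in enumerate(zip(xs, in_planes)):
        # Assumption: tensors are laid out as (batch, channels, height, width).
        if x.dim() != 4:
            raise ValueError(f"feature map {i} must be 4-D, got {x.dim()}-D")
        if x.shape[0] != batch:
            raise ValueError(f"feature map {i} has batch size {x.shape[0]}, expected {batch}")
        if x.shape[1] != channels:
            raise ValueError(f"feature map {i} has {x.shape[1]} channels, expected {channels}")


xs = [torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16)]
check_msf_inputs(xs, in_planes=[64, 128])  # passes silently for consistent inputs
```

Spatial sizes are deliberately left unchecked, since the block resizes every map to a common shape.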
forward(xs, shape)
Fuse multiple input feature maps into a single output feature map.
This method applies a 3x3 convolution to each feature map in xs, resizes each result to shape with bilinear interpolation, and sums the resized maps into a single output feature map.
Parameters:
- xs (list or tuple of torch.Tensor) – The input feature maps, one per entry in in_planes, each with the matching number of channels.
- shape (tuple of int) – The spatial size (height, width) of the output feature map.
Returns: The fused feature map, with features channels and spatial size shape.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import MSFBlock
>>> msf_block = MSFBlock([64, 128], features=256)
>>> input_tensor1 = torch.randn(1, 64, 32, 32)
>>> input_tensor2 = torch.randn(1, 128, 32, 32)
>>> output = msf_block([input_tensor1, input_tensor2], shape=(64, 64))
>>> output.shape
torch.Size([1, 256, 64, 64])
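The convolve, resize, and sum steps described for forward can be sketched as a standalone module. This is an illustrative stand-in under stated assumptions, not the ESPnet implementation. The class name MSFBlockSketch is hypothetical, a plain nn.Conv2d with padding=1 stands in for whatever 3x3 convolution helper the library uses, and align_corners=False is an assumed interpolation setting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MSFBlockSketch(nn.Module):
    """Illustrative stand-in for a multi-scale fusion block (not the ESPnet class)."""

    def __init__(self, in_planes, features):
        super().__init__()
        assert isinstance(in_planes, (list, tuple))
        self.features = features
        # One 3x3 convolution per input; padding=1 preserves the spatial size.
        self.convs = nn.ModuleList(
            [nn.Conv2d(c, features, kernel_size=3, padding=1) for c in in_planes]
        )

    def forward(self, xs, shape):
        out = None
        for conv, x in zip(self.convs, xs):
            h = conv(x)  # project to `features` channels
            # Resize to the common output size (align_corners is an assumption).
            h = F.interpolate(h, size=shape, mode="bilinear", align_corners=False)
            out = h if out is None else out + h  # element-wise sum across scales
        return out


block = MSFBlockSketch([64, 128, 256], features=128)
xs = [
    torch.randn(1, 64, 32, 32),
    torch.randn(1, 128, 16, 16),
    torch.randn(1, 256, 8, 8),
]
out = block(xs, shape=(32, 32))
print(out.shape)
```

Because every branch is projected to the same channel count before resizing, the maps can be summed directly, which keeps the output size independent of the number of inputs.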