espnet2.enh.layers.ncsnpp_utils.layerspp.Combine
class espnet2.enh.layers.ncsnpp_utils.layerspp.Combine(dim1, dim2, method='cat')
Bases: Module
Combine information from skip connections.
This module combines two input tensors using one of two methods: concatenation (the default) or addition. It first applies a 1x1 convolution that maps the first input tensor from dim1 to dim2 channels, then combines the result with the second input tensor.
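The combine step can be sketched in plain PyTorch. This is a minimal illustration, not the ESPnet source. The class name `CombineSketch` is hypothetical, and the sketch assumes that `Conv_0` behaves like `nn.Conv2d(dim1, dim2, kernel_size=1)`, which is consistent with the second input having to match the convolution's output channels.

```python
import torch
import torch.nn as nn


class CombineSketch(nn.Module):
    """Hypothetical re-implementation of the Combine logic (illustration only)."""

    def __init__(self, dim1: int, dim2: int, method: str = "cat"):
        super().__init__()
        # Assumption: a plain 1x1 convolution mapping dim1 -> dim2 channels.
        self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1)
        self.method = method

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        h = self.Conv_0(x)  # (B, dim1, H, W) -> (B, dim2, H, W)
        if self.method == "cat":
            # Stack along the channel axis: (B, 2 * dim2, H, W)
            return torch.cat([h, y], dim=1)
        if self.method == "sum":
            # Element-wise addition: (B, dim2, H, W)
            return h + y
        # The method is only validated here, when forward runs.
        raise ValueError(f"Method {self.method} not recognized.")


x = torch.randn(2, 64, 16, 16)
y = torch.randn(2, 32, 16, 16)
out_cat = CombineSketch(64, 32, "cat")(x, y)
out_sum = CombineSketch(64, 32, "sum")(x, y)
print(out_cat.shape)  # torch.Size([2, 64, 16, 16])
print(out_sum.shape)  # torch.Size([2, 32, 16, 16])
```

With "cat", the output has 2 * dim2 channels, so whatever layer follows must be built for that width. With "sum", the width stays at dim2.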
Conv_0
A 1x1 convolution layer applied to the first input tensor, mapping dim1 channels to dim2 channels.
- Type: nn.Module
method
The method used to combine the inputs: “cat” for concatenation or “sum” for addition.
- Type: str
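Because `Conv_0` uses a 1x1 kernel, it mixes channels independently at each spatial position and never looks at neighboring pixels. The check below illustrates this in plain PyTorch. It assumes `Conv_0` is equivalent to `nn.Conv2d(dim1, dim2, kernel_size=1)`, which is an assumption about the implementation rather than something stated here. Under that assumption, the convolution output equals a per-pixel matrix multiply with the squeezed weight plus the bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim1, dim2 = 64, 32
# Assumed stand-in for Conv_0: 1x1 kernel, stride 1, no padding.
conv = nn.Conv2d(dim1, dim2, kernel_size=1)
x = torch.randn(1, dim1, 8, 8)

with torch.no_grad():
    out_conv = conv(x)
    # The weight has shape (dim2, dim1, 1, 1); drop the 1x1 spatial dims.
    weight = conv.weight[:, :, 0, 0]  # (dim2, dim1)
    # out[b, o, h, w] = sum_c weight[o, c] * x[b, c, h, w] + bias[o]
    out_einsum = torch.einsum("oc,bchw->bohw", weight, x) + conv.bias.view(1, -1, 1, 1)

print(out_conv.shape)  # torch.Size([1, 32, 8, 8])
print(torch.allclose(out_conv, out_einsum, atol=1e-5))  # True
```

The spatial size (8, 8) is unchanged, which is why the convolved first input can be concatenated with, or added to, a second input of the same height and width.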
Parameters:
- dim1 (int) – The number of input channels of the first tensor.
- dim2 (int) – The number of channels of the second tensor. This is also the number of output channels of Conv_0.
- method (str, optional) – The method used to combine the inputs. Defaults to “cat”.
Returns: The combined output tensor, with shape (B, 2 * dim2, H, W) for “cat” or (B, dim2, H, W) for “sum”.
Return type: torch.Tensor
Raises: ValueError – If the method is not recognized. The check happens when forward is called, not at construction.
Examples:
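>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layerspp import Combine
>>> combine = Combine(dim1=64, dim2=32, method="cat")
>>> x = torch.randn(1, 64, 16, 16)  # First input tensor (dim1 channels)
>>> y = torch.randn(1, 32, 16, 16)  # Second input tensor (dim2 channels)
>>> output = combine(x, y)  # Concatenates Conv_0(x) and y
>>> output.shape
torch.Size([1, 64, 16, 16])
>>> combine_sum = Combine(dim1=64, dim2=32, method="sum")
>>> output_sum = combine_sum(x, y)  # Adds Conv_0(x) and y
>>> output_sum.shape
torch.Size([1, 32, 16, 16])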
forward(x, y)
Combines two inputs using a specified method.
It applies the 1x1 convolution Conv_0 to the first input x. It then combines the result with the second input y by concatenation or summation, depending on method.
- Parameters:
  - x (torch.Tensor) – The first input tensor, which is passed through the 1x1 convolution Conv_0. Its shape should be (B, C1, H, W), where B is the batch size, C1 = dim1 is the number of channels, and H and W are the height and width.
  - y (torch.Tensor) – The second input tensor, combined with the convolved x. Its shape should be (B, C2, H, W), where C2 = dim2 matches the output channels of Conv_0. H and W should match those of x, because the 1x1 convolution preserves the spatial size.
- Returns: The combined output tensor. Because Conv_0 maps x to (B, C2, H, W), the output shape is (B, 2 * C2, H, W) if method is “cat” and (B, C2, H, W) if method is “sum”.
- Return type: torch.Tensor
- Raises: ValueError – If the specified method is not recognized.
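The shape contract of forward can be captured in a small helper. `combine_output_shape` is a hypothetical utility, not part of ESPnet. It encodes the rule that Conv_0 maps x from (B, dim1, H, W) to (B, dim2, H, W) and leaves H and W unchanged. It also requires y to already have shape (B, dim2, H, W).

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def combine_output_shape(x_shape: Shape, y_shape: Shape, dim1: int, dim2: int,
                         method: str = "cat") -> Shape:
    """Hypothetical helper: predict the (B, C, H, W) output shape of Combine."""
    b, c1, h, w = x_shape
    if c1 != dim1:
        raise ValueError(f"x has {c1} channels, expected dim1={dim1}")
    if tuple(y_shape) != (b, dim2, h, w):
        raise ValueError(f"y must have shape {(b, dim2, h, w)}, got {y_shape}")
    # Conv_0 maps x to (B, dim2, H, W); the 1x1 kernel keeps H and W.
    if method == "cat":
        return (b, 2 * dim2, h, w)  # Conv_0(x) and y stacked on the channel axis
    if method == "sum":
        return (b, dim2, h, w)  # element-wise Conv_0(x) + y
    raise ValueError(f"Method {method} not recognized.")


print(combine_output_shape((1, 64, 8, 8), (1, 32, 8, 8), 64, 32, "cat"))  # (1, 64, 8, 8)
print(combine_output_shape((1, 64, 8, 8), (1, 32, 8, 8), 64, 32, "sum"))  # (1, 32, 8, 8)
```

The helper is deliberately strict. It requires y to match exactly, although the addition in the "sum" case would also accept shapes that broadcast.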
Examples:
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layerspp import Combine
>>> combine_layer = Combine(dim1=64, dim2=32, method='cat')
>>> x = torch.randn(1, 64, 8, 8)  # (B, C1, H, W) with C1 = dim1
>>> y = torch.randn(1, 32, 8, 8)  # (B, C2, H, W) with C2 = dim2
>>> output = combine_layer(x, y)  # Conv_0(x) has 32 channels, concatenated with y
>>> output.shape
torch.Size([1, 64, 8, 8])
>>> combine_layer = Combine(dim1=64, dim2=32, method='sum')
>>> output = combine_layer(x, y)  # Conv_0(x) + y
>>> output.shape
torch.Size([1, 32, 8, 8])