espnet2.enh.layers.tcndenseunet.Conv2DActNorm
class espnet2.enh.layers.tcndenseunet.Conv2DActNorm(in_channels, out_channels, ksz=(3, 3), stride=(1, 2), padding=(1, 0), upsample=False, activation=<class 'torch.nn.modules.activation.ELU'>)
Bases: Module
Conv2DActNorm is a building block that applies a 2D convolution followed by an
activation function and instance normalization.
It combines a 2D convolution (a transposed convolution when upsampling), an activation function, and group normalization into a reusable component for networks that process image-like inputs such as spectrograms. Group normalization acts as instance normalization when each channel forms its own group.
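Group normalization with as many groups as channels normalizes each channel of each sample independently, which is exactly instance normalization. The pure-Python sketch below illustrates this on a single sample. The helper names `group_norm` and `instance_norm` are hypothetical, the learnable affine scale and shift are omitted, and `eps=1e-5` is an assumption.

```python
import math


def group_norm(x, num_groups, eps=1e-5):
    """Normalize one sample x (a list of channels, each a flat list of values).

    Channels are split into num_groups contiguous groups, and each group is
    normalized with its own mean and biased variance (no affine parameters).
    """
    num_channels = len(x)
    assert num_channels % num_groups == 0
    per_group = num_channels // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * per_group:(g + 1) * per_group]
        values = [v for ch in group for v in ch]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        std = math.sqrt(var + eps)
        out.extend([[(v - mean) / std for v in ch] for ch in group])
    return out


def instance_norm(x, eps=1e-5):
    """Normalize every channel of one sample on its own."""
    out = []
    for ch in x:
        mean = sum(ch) / len(ch)
        var = sum((v - mean) ** 2 for v in ch) / len(ch)
        out.append([(v - mean) / math.sqrt(var + eps) for v in ch])
    return out


# One sample with 3 channels, each holding 4 (flattened) time-frequency values.
sample = [[1.0, 2.0, 3.0, 4.0], [10.0, 0.0, -10.0, 5.0], [0.5, 0.5, 1.5, 1.5]]

gn = group_norm(sample, num_groups=len(sample))  # one group per channel
inorm = instance_norm(sample)
print(all(abs(a - b) < 1e-12 for ca, cb in zip(gn, inorm) for a, b in zip(ca, cb)))
# → True
```

With `num_groups=1`, the same helper instead normalizes all channels jointly, which no longer matches instance normalization.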
Attributes:
- layer (torch.nn.Sequential) – Sequential container holding the convolution, activation, and normalization layers, applied in that order.
Parameters:
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- ksz (tuple) – Kernel size for the convolution. Default is (3, 3).
- stride (tuple) – Stride for the convolution. Default is (1, 2).
- padding (tuple) – Padding for the convolution. Default is (1, 0).
- upsample (bool) – If True, uses transposed convolution for upsampling. Default is False.
- activation (type) – Activation module class (not an instance) applied after the convolution. Default is torch.nn.ELU.
Returns: The output tensor after applying convolution, activation, and normalization.
Return type: torch.Tensor
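Because the activation and normalization preserve the tensor shape, the output size depends only on ksz, stride, padding, and upsample. The pure-Python sketch below applies PyTorch's documented output-size formulas for torch.nn.Conv2d and torch.nn.ConvTranspose2d. The helper `conv2d_out_size` is hypothetical, and dilation 1 and output_padding 0 are assumptions.

```python
def conv2d_out_size(size, ksz=(3, 3), stride=(1, 2), padding=(1, 0), upsample=False):
    """Return the (height, width) produced by one Conv2DActNorm-style layer.

    Assumes dilation=1 and output_padding=0; activation and normalization
    do not change the shape.
    """
    out = []
    for n, k, s, p in zip(size, ksz, stride, padding):
        if upsample:
            # torch.nn.ConvTranspose2d output size
            out.append((n - 1) * s - 2 * p + (k - 1) + 1)
        else:
            # torch.nn.Conv2d output size
            out.append((n + 2 * p - (k - 1) - 1) // s + 1)
    return tuple(out)


print(conv2d_out_size((64, 32)))                 # → (64, 15)
print(conv2d_out_size((64, 15), upsample=True))  # → (64, 31)
```

With the default arguments, the first axis (stride 1, padding 1) keeps its length and the second axis is roughly halved on the way down and roughly doubled on the way up.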
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import Conv2DActNorm
>>> conv_layer = Conv2DActNorm(1, 16)
>>> input_tensor = torch.randn(1, 1, 64, 32)  # (batch, channels, time, frequency)
>>> output_tensor = conv_layer(input_tensor)
>>> output_tensor.shape  # time axis kept, frequency axis downsampled by stride 2
torch.Size([1, 16, 64, 15])
- Raises:
- ValueError – If input parameters are invalid (e.g., negative channel sizes).
forward(inp)
Forward pass of Conv2DActNorm: applies the convolution, activation, and normalization layers in sequence.
- Parameters: inp (torch.Tensor) – 4D input tensor of shape [B, C_in, T, F], where B is the batch size, C_in equals in_channels, and T and F are the two spatial axes (e.g., frames and frequency bins).
- Returns: Output tensor of shape [B, C_out, T', F'], where C_out equals out_channels and T', F' are determined by ksz, stride, padding, and upsample.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import Conv2DActNorm
>>> up_layer = Conv2DActNorm(16, 1, upsample=True)  # transposed convolution
>>> x = torch.randn(1, 16, 64, 15)
>>> up_layer(x).shape
torch.Size([1, 1, 64, 31])
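Since Conv2DActNorm only chains a convolution, an activation, and a normalization, chaining a downsampling block and an upsampling block with the same ksz, stride, and padding (as in an encoder/decoder pair) restores the frequency width only when that width is odd. The sketch below traces this along the frequency axis using the standard PyTorch size formulas. The helpers `down` and `up` are hypothetical, and output_padding 0 is an assumption.

```python
def down(width, k=3, s=2, p=0):
    # torch.nn.Conv2d output width along the frequency axis (dilation 1)
    return (width + 2 * p - k) // s + 1


def up(width, k=3, s=2, p=0):
    # torch.nn.ConvTranspose2d output width, assuming output_padding=0
    return (width - 1) * s - 2 * p + k


for width in (32, 33, 257):
    restored = up(down(width))
    print(width, down(width), restored, restored == width)
# 32 15 31 False
# 33 16 33 True
# 257 128 257 True
```

For example, 257 frequency bins (a 512-point STFT) survive the round trip, whereas an even width loses one bin unless extra padding is added.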