espnet2.enh.layers.tcndenseunet.DenseBlock
class espnet2.enh.layers.tcndenseunet.DenseBlock(in_channels, out_channels, num_freqs, pre_blocks=2, freq_proc_blocks=1, post_blocks=2, ksz=(3, 3), activation=<class 'torch.nn.modules.activation.ELU'>, hid_chans=32)
Bases: Module
Single DenseNet block as used in the iNeuBe model.
This class implements a DenseNet block built from multiple 2-D convolutional layers: pre_blocks dense blocks, then freq_proc_blocks point-wise convolution blocks over the frequency axis, then post_blocks dense blocks. Input tensors are expected in the format [batch, image_channels, frames, freq].
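The dense connectivity can be sketched as follows. This is a simplified, hypothetical `MiniDenseBlock`, not the ESPnet implementation: it assumes plain `torch.nn.Conv2d` plus activation layers, omits normalization and the frequency-axis block, and uses a single `num_layers` count instead of the pre/post split. Each layer sees the channel-wise concatenation of the block input and all earlier outputs.

```python
import torch


class MiniDenseBlock(torch.nn.Module):
    """Hypothetical simplified dense block (not the ESPnet implementation)."""

    def __init__(self, in_channels, out_channels, num_layers=3, hid_chans=32,
                 ksz=(3, 3), activation=torch.nn.ELU):
        super().__init__()
        pad = (ksz[0] // 2, ksz[1] // 2)  # "same" padding for odd kernel sizes
        self.layers = torch.nn.ModuleList()
        for k in range(num_layers):
            # Layer k sees the block input plus the k hidden outputs before it.
            c_in = in_channels + hid_chans * k
            # Only the last layer maps to out_channels.
            c_out = out_channels if k == num_layers - 1 else hid_chans
            self.layers.append(
                torch.nn.Sequential(
                    torch.nn.Conv2d(c_in, c_out, ksz, padding=pad),
                    activation(),
                )
            )

    def forward(self, x):
        # x: [batch, channels, frames, freqs]
        feats = [x]
        for layer in self.layers:
            out = layer(torch.cat(feats, dim=1))  # concatenate on the channel axis
            feats.append(out)
        return out  # output of the last layer only


block = MiniDenseBlock(in_channels=4, out_channels=8)
y = block(torch.randn(2, 4, 10, 17))
print(y.shape)  # torch.Size([2, 8, 10, 17])
```

Returning only the last layer's output keeps the block's output width at `out_channels` no matter how many layers are stacked, while the padding preserves the frames and frequency dimensions.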
- Parameters:
- in_channels (int) – Number of input channels (image axis).
- out_channels (int) – Number of output channels (image axis).
- num_freqs (int) – Number of complex frequency bins in the input STFT image-like tensor, whose shape is [batch, image_channels, frames, freqs].
- pre_blocks (int) – Number of dense blocks before the point-wise convolution block over the frequency axis (default: 2).
- freq_proc_blocks (int) – Number of frequency-axis processing blocks (default: 1).
- post_blocks (int) – Number of dense blocks after the point-wise convolution block over the frequency axis (default: 2).
- ksz (tuple) – Kernel size of the DenseNet Conv2D layers (default: (3, 3)).
- activation (callable) – Activation function used throughout the iNeuBe model. Any activation supported by torch (e.g., ReLU or ELU) can be used (default: torch.nn.ELU).
- hid_chans (int) – Number of hidden channels in DenseNet Conv2D (default: 32).
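One way to realize a point-wise convolution over the frequency axis is sketched below, assuming the frequency bins are moved onto the channel axis so that a 1x1 `Conv2d` mixes all `num_freqs` bins at each (frame, channel) position. Under that assumption the layer must know `num_freqs` when it is built. The class `FreqPointwiseConv` is hypothetical; ESPnet's internal frequency-axis block may differ in detail (for example, normalization).

```python
import torch


class FreqPointwiseConv(torch.nn.Module):
    """Hypothetical 1x1 convolution that mixes frequency bins."""

    def __init__(self, in_channels, num_freqs, hid_chans=32, activation=torch.nn.ELU):
        super().__init__()
        # 1x1 bottleneck over the image-channel axis.
        self.bottleneck = torch.nn.Conv2d(in_channels, hid_chans, kernel_size=1)
        # 1x1 conv whose "channels" are the frequency bins.
        self.freq_mix = torch.nn.Conv2d(num_freqs, num_freqs, kernel_size=1)
        self.act = activation()

    def forward(self, x):
        # x: [batch, channels, frames, freqs]
        h = self.act(self.bottleneck(x))  # [B, hid_chans, T, F]
        h = h.permute(0, 3, 2, 1)         # [B, F, T, hid_chans]
        h = self.act(self.freq_mix(h))    # mix all F bins at each (t, c)
        return h.permute(0, 3, 2, 1)      # back to [B, hid_chans, T, F]


layer = FreqPointwiseConv(in_channels=16, num_freqs=65)
out = layer(torch.randn(2, 16, 10, 65))
print(out.shape)  # torch.Size([2, 32, 10, 65])
```

Because the frequency bins act as channels here, an input whose last dimension differs from `num_freqs` fails with a shape error.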
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import DenseBlock
>>> dense_block = DenseBlock(
...     in_channels=64,
...     out_channels=128,
...     num_freqs=257,
...     pre_blocks=2,
...     freq_proc_blocks=1,
...     post_blocks=2,
...     ksz=(3, 3),
...     activation=torch.nn.ReLU,
...     hid_chans=32,
... )
>>> input_tensor = torch.randn(8, 64, 100, 257)  # [batch, channels, frames, freqs]
>>> output = dense_block(input_tensor)
>>> output.shape
torch.Size([8, 128, 100, 257])
- Raises: AssertionError – If pre_blocks or post_blocks is less than 1.
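In a DenseNet block each layer consumes the concatenation of the block input and all previous layer outputs, so the input width of the DenseBlock layers grows by hid_chans per layer. The sketch below tabulates those widths under a stated assumption: every pre, frequency-processing, and post layer except the last emits hid_chans channels, and the last emits out_channels. This mirrors a typical DenseNet layout rather than being copied from the ESPnet source, and the function name `dense_layer_widths` is hypothetical. It also repeats the documented check that pre_blocks and post_blocks are at least 1.

```python
def dense_layer_widths(in_channels, out_channels, hid_chans=32,
                       pre_blocks=2, freq_proc_blocks=1, post_blocks=2):
    """Hypothetical helper: (input, output) channels of each dense layer."""
    # Mirrors the documented AssertionError condition.
    assert pre_blocks >= 1 and post_blocks >= 1
    num_layers = pre_blocks + freq_proc_blocks + post_blocks
    widths = []
    for k in range(num_layers):
        c_in = in_channels + hid_chans * k  # block input + k earlier outputs
        c_out = out_channels if k == num_layers - 1 else hid_chans
        widths.append((c_in, c_out))
    return widths


for c_in, c_out in dense_layer_widths(64, 128):
    print(c_in, "->", c_out)
# 64 -> 32
# 96 -> 32
# 128 -> 32
# 160 -> 32
# 192 -> 128
```

With the default settings from the example above, the final layer therefore sees 192 channels and projects them to the requested 128.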
NOTE
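With the default kernel size, as in the example above, the output keeps the frames and frequency dimensions of the input, so it can be passed directly to further layers of a network for tasks such as speech enhancement or other audio signal processing.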
forward(input)
Forward pass through the DenseBlock.
Each layer is applied to the channel-wise concatenation of the block input and the outputs of all preceding layers; the output of the final post layer is returned.
Parameters: input (torch.Tensor) –
A real-valued 4D tensor of shape [B, C_in, T, F] where:
B = batch size, C_in = in_channels, T = number of frames, F = num_freqs.
Returns: Output tensor of shape [B, C_out, T, F], where C_out = out_channels.
Return type: torch.Tensor