espnet2.enh.layers.tcndenseunet.FreqWiseBlock
class espnet2.enh.layers.tcndenseunet.FreqWiseBlock(in_channels, num_freqs, out_channels, activation=torch.nn.ELU)
Bases: Module
FreqWiseBlock, see iNeuBe paper.
Block that applies a pointwise (1x1) 2D convolution along the frequency axis of an STFT-like image tensor. The input is assumed to have shape [batch, image_channels, frames, freq].
bottleneck
A convolutional layer with activation and normalization for processing input channels.
- Type: Conv2DActNorm
freq_proc
A convolutional layer with activation and normalization for processing frequency channels.
- Type: Conv2DActNorm
Parameters:
- in_channels (int) – Number of input channels (image axis).
- num_freqs (int) – Number of frequencies in the input complex STFT image-like tensor.
- out_channels (int) – Number of output channels (image axis).
- activation (callable) – Activation module class to use; defaults to torch.nn.ELU.
Returns: The output tensor after applying the frequency-wise processing.
Return type: torch.Tensor
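To make "pointwise convolution on the frequency axis" concrete, the following is a minimal sketch in plain PyTorch. It assumes the frequency axis is moved to the channel position with permute before the convolution, which is an illustrative choice and not confirmed by this page. Under that assumption, a 1x1 convolution over frequencies is the same computation as a linear map applied along the last axis at every (batch, channel, frame) position.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, chans, frames, freqs = 2, 4, 5, 6
x = torch.randn(batch, chans, frames, freqs)  # [batch, channels, frames, freq]

# Pointwise (1x1) convolution whose "channels" are the frequency bins.
conv = nn.Conv2d(freqs, freqs, kernel_size=1)

with torch.no_grad():
    # Move freq to the channel position, convolve, then move it back.
    y_conv = conv(x.permute(0, 3, 2, 1)).permute(0, 3, 2, 1)

    # The same computation written as a matrix multiply over the freq axis.
    weight = conv.weight.view(freqs, freqs)  # [out_freq, in_freq]
    y_lin = x @ weight.T + conv.bias

same = torch.allclose(y_conv, y_lin, atol=1e-5)
print(same, tuple(y_conv.shape))
```

The two results agree because a kernel of size 1 never looks at neighbouring frames or channels; it only mixes values across the axis that sits in the channel position.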
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import FreqWiseBlock
>>> block = FreqWiseBlock(in_channels=64, num_freqs=128, out_channels=32)
>>> input_tensor = torch.randn(10, 64, 100, 128) # [batch, channels, frames, freq]
>>> output_tensor = block(input_tensor)
>>> output_tensor.shape
torch.Size([10, 32, 100, 128])
NOTE
This block is designed to operate on STFT-like tensors: bottleneck acts on the image-channel axis, and freq_proc acts on the frequency axis.
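As a rough illustration of how the two attributes could fit together, here is a self-contained sketch. The names PointwiseConvActNorm and FreqWiseSketch are hypothetical, and the layer order (convolution, activation, normalization) and the use of GroupNorm are assumptions standing in for Conv2DActNorm, whose internals are not described on this page. It is not the ESPnet implementation.

```python
import torch
from torch import nn


class PointwiseConvActNorm(nn.Module):
    """Hypothetical stand-in for Conv2DActNorm: 1x1 conv, activation, norm."""

    def __init__(self, in_ch, out_ch, activation=nn.ELU):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.act = activation()
        # Assumption: a single-group GroupNorm; the real layer may differ.
        self.norm = nn.GroupNorm(1, out_ch)

    def forward(self, x):
        return self.norm(self.act(self.conv(x)))


class FreqWiseSketch(nn.Module):
    """Illustrative block: channel bottleneck, then frequency-axis processing."""

    def __init__(self, in_channels, num_freqs, out_channels, activation=nn.ELU):
        super().__init__()
        self.bottleneck = PointwiseConvActNorm(in_channels, out_channels, activation)
        self.freq_proc = PointwiseConvActNorm(num_freqs, num_freqs, activation)

    def forward(self, inp):
        out = self.bottleneck(inp)      # [batch, out_channels, frames, freq]
        out = out.permute(0, 3, 2, 1)   # [batch, freq, frames, out_channels]
        out = self.freq_proc(out)       # mixes values along the freq axis
        return out.permute(0, 3, 2, 1)  # [batch, out_channels, frames, freq]


block = FreqWiseSketch(in_channels=64, num_freqs=128, out_channels=32)
out = block(torch.randn(2, 64, 20, 128))
print(tuple(out.shape))
```

The sketch reproduces the shape behaviour shown in the example above: the image-channel count changes from in_channels to out_channels, while the frame and frequency dimensions are preserved.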
forward(inp)
Forward pass for FreqWiseBlock.
This method processes a 4D STFT-like image tensor and returns a 4D tensor with out_channels image channels and the same frame and frequency dimensions.
- Parameters:inp (torch.Tensor) – 4D tensor of shape [B, C, T, F], where B is the batch size, C is the number of image channels (in_channels), T is the number of frames, and F is the number of frequencies (num_freqs).
- Returns: 4D tensor of shape [B, out_channels, T, F], where B is the batch size, T is the number of frames, and F is the number of frequencies.
- Return type: torch.Tensor
####### Examples
>>> block = FreqWiseBlock(in_channels=64, num_freqs=128, out_channels=32)
>>> inp = torch.randn(10, 64, 100, 128) # Batch of 10, 100 frames
>>> output = block.forward(inp)
>>> print(output.shape) # Output shape should be [10, 32, 100, 128]
NOTE
The input tensor is expected to be in the shape [B, C, T, F], with C equal to in_channels and F equal to num_freqs as given at construction.