espnet2.enh.layers.tcndenseunet.TCNResBlock
class espnet2.enh.layers.tcndenseunet.TCNResBlock(in_chan, out_chan, ksz=3, stride=1, dilation=1, activation=<class 'torch.nn.modules.activation.ELU'>)
Bases: Module
Single depth-wise separable TCN block as used in iNeuBe TCN.
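The block applies group normalization and the activation function to the input features. It then applies a depth-wise convolution (one filter per channel, using the given stride and dilation) followed by a point-wise (1×1) convolution that maps in_chan to out_chan channels. Together, the two convolutions form a depth-wise separable convolution, which needs far fewer parameters than a standard convolution. The result is added back to the input as a residual connection, so out_chan must equal in_chan.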
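The efficiency claim can be made concrete with some arithmetic. The stdlib-only sketch below counts the learnable parameters of a depth-wise plus point-wise convolution pair and compares them with a single standard Conv1d that uses the same kernel size. It is illustrative arithmetic, not ESPnet code. The function names are hypothetical, both convolutions are assumed to use PyTorch's default `bias=True`, and the group normalization parameters are ignored.

```python
def depthwise_separable_params(in_chan: int, out_chan: int, ksz: int = 3) -> int:
    """Weights + biases of a depth-wise Conv1d (groups=in_chan) followed by a 1x1 Conv1d.

    Hypothetical helper; assumes bias=True on both convolutions.
    """
    depthwise = in_chan * ksz + in_chan        # one ksz-tap filter per channel, plus bias
    pointwise = in_chan * out_chan + out_chan  # 1x1 channel mixing, plus bias
    return depthwise + pointwise


def standard_conv_params(in_chan: int, out_chan: int, ksz: int = 3) -> int:
    """Weights + biases of a single full Conv1d with the same kernel size."""
    return in_chan * out_chan * ksz + out_chan


if __name__ == "__main__":
    for chans in (64, 256):
        sep = depthwise_separable_params(chans, chans)
        full = standard_conv_params(chans, chans)
        print(f"{chans} channels: separable={sep}, standard={full}, ratio={full / sep:.2f}")
```

With 64 channels and ksz=3, the separable pair has 4416 parameters and the standard convolution has 12352, roughly a 2.8× reduction. The gap widens as ksz grows, because only the depth-wise part scales with the kernel size.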
- Parameters:
- in_chan (int) – Number of input feature channels.
- out_chan (int) – Number of output feature channels.
- ksz (int, optional) – Kernel size of the depth-wise convolution. Defaults to 3.
- stride (int, optional) – Stride of the depth-wise convolution. Defaults to 1.
- dilation (int, optional) – Dilation of the depth-wise convolution. Defaults to 1.
- activation (callable, optional) – Activation module class, e.g. torch.nn.ELU or torch.nn.ReLU. Pass the class itself, not an instance or a string such as 'relu'. Defaults to torch.nn.ELU.
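The ksz, stride, and dilation arguments determine how many frames the depth-wise convolution produces. The sketch below applies the output-length formula from the torch.nn.Conv1d documentation. It also uses a hypothetical `same_padding` helper. The sketch does not reproduce ESPnet's actual padding choice, which you should check in the source. For the default ksz=3, symmetric "same" padding equals the dilation, so any dilation keeps the frame count unchanged when stride=1.

```python
def conv1d_out_len(l_in: int, ksz: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of torch.nn.Conv1d, following the formula in the PyTorch docs."""
    return (l_in + 2 * padding - dilation * (ksz - 1) - 1) // stride + 1


def same_padding(ksz: int, dilation: int = 1) -> int:
    """Hypothetical helper: symmetric padding that preserves length for stride 1 and odd ksz."""
    if ksz % 2 == 0:
        raise ValueError("symmetric 'same' padding needs an odd kernel size")
    return dilation * (ksz - 1) // 2


if __name__ == "__main__":
    frames = 100
    for dilation in (1, 2, 4, 8):
        pad = same_padding(ksz=3, dilation=dilation)  # equals dilation when ksz == 3
        out = conv1d_out_len(frames, ksz=3, stride=1, padding=pad, dilation=dilation)
        print(f"dilation={dilation}: padding={pad}, frames {frames} -> {out}")
    # With stride 2, the same padding roughly halves the number of frames.
    print(conv1d_out_len(frames, ksz=3, stride=2, padding=1, dilation=1))
```

With stride=2, the frame count drops from 100 to 50 here. If the block adds its input back as a residual (as its name suggests), that addition only lines up when stride=1.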
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import TCNResBlock
>>> tcn_block = TCNResBlock(in_chan=64, out_chan=64)  # out_chan must equal in_chan
>>> input_tensor = torch.randn(32, 64, 100)  # [batch_size, channels, frames]
>>> output_tensor = tcn_block(input_tensor)
>>> output_tensor.shape  # [B, out_chan, F]
torch.Size([32, 64, 100])
- Returns: Output tensor of shape [B, out_chan, F] where B is the batch size, out_chan is the number of output channels, and F is the number of frames.
- Return type: torch.Tensor
NOTE
The input tensor should be 3D with shape [B, C, F] where B is the batch size, C is the number of channels, and F is the number of frames.
forward(inp)
Forward pass of the TCN residual block.
- Parameters:inp (torch.Tensor) – 3D input features of shape [B, C, F], where B is the batch size, C (equal to in_chan) is the number of channels, and F is the number of frames.
- Returns: Output features of shape [B, out_chan, F].
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import TCNResBlock
>>> block = TCNResBlock(in_chan=32, out_chan=32)
>>> inp = torch.randn(4, 32, 50)  # [B, C, F]
>>> block(inp).shape
torch.Size([4, 32, 50])
NOTE
The input must be a 3D tensor of shape [B, in_chan, F]. Unlike TCNDenseUNet.forward, this method does not take a 4D multi-channel complex STFT.