espnet2.enh.layers.tcndenseunet.TCNDenseUNet
class espnet2.enh.layers.tcndenseunet.TCNDenseUNet(n_spk=1, in_freqs=257, mic_channels=1, hid_chans=32, hid_chans_dense=32, ksz_dense=(3, 3), ksz_tcn=3, tcn_repeats=4, tcn_blocks=7, tcn_channels=384, activation=<class 'torch.nn.modules.activation.ELU'>)
Bases: Module
TCNDenseUNet block from iNeuBe.
Reference: Lu, Y. J., Cornell, S., Chang, X., Zhang, W., Li, C., Ni, Z., … & Watanabe, S. Towards Low-Distortion Multi-Channel Speech Enhancement: The ESPNet-SE Submission to the L3DAS22 Challenge. ICASSP 2022, pp. 9201-9205.
- Parameters:
- n_spk (int) – Number of output sources/speakers.
- in_freqs (int) – Number of complex STFT frequencies.
- mic_channels (int) – Number of microphone channels (only fixed-array geometry is supported).
- hid_chans (int) – Number of channels in the subsampling/upsampling conv layers.
- hid_chans_dense (int) – Number of channels in the DenseNet layers (lower this to reduce VRAM usage).
- ksz_dense (tuple) – Kernel size of the DenseNet layers used throughout iNeuBe.
- ksz_tcn (int) – Kernel size in the TCN submodule.
- tcn_repeats (int) – Number of repetitions of blocks in the TCN submodule.
- tcn_blocks (int) – Number of blocks in the TCN submodule.
- tcn_channels (int) – Number of channels in the TCN submodule.
- activation (callable) – Activation used throughout the iNeuBe model. Pass an activation class such as torch.nn.ELU (the default) or torch.nn.ReLU; any torch-supported activation can be used.
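To get a feel for how tcn_repeats, tcn_blocks and ksz_tcn interact, the sketch below computes the temporal receptive field of the TCN submodule. It assumes a Conv-TasNet-style layout in which block b of every repeat applies a single dilated convolution with dilation 2**b. This page does not state the dilation schedule, so treat the helper tcn_receptive_field (a hypothetical name) and its result as illustrative only.

```python
def tcn_receptive_field(ksz_tcn: int = 3, tcn_repeats: int = 4, tcn_blocks: int = 7) -> int:
    """Receptive field, in frames, of a stack of dilated 1-D convolutions.

    Assumption (not taken from this page): block b of every repeat contains a
    single convolution with kernel size ksz_tcn and dilation 2**b.
    """
    receptive_field = 1
    for _ in range(tcn_repeats):
        for b in range(tcn_blocks):
            # a dilated kernel spans (ksz_tcn - 1) * dilation extra frames
            receptive_field += (ksz_tcn - 1) * 2 ** b
    return receptive_field


print(tcn_receptive_field())  # 1017 frames with the default arguments
```

Under this assumption, adding a repeat grows the receptive field linearly, while adding a block roughly doubles the span of each repeat.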
n_spk
Number of output sources/speakers.
- Type: int
in_channels
Number of input frequencies.
- Type: int
mic_channels
Number of microphone channels.
- Type: int
encoder
List of encoder layers.
- Type: torch.nn.ModuleList
tcn
TCN block composed of multiple TCNResBlocks.
- Type: torch.nn.Sequential
decoder
List of decoder layers.
- Type: torch.nn.ModuleList
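The encoder's subsampling layers shrink the frequency axis, and the decoder's upsampling layers restore it. The exact layer configuration is not listed on this page. As an assumption, the sketch below takes one kernel-3, stride-1 convolution without frequency padding followed by kernel-3, stride-2 convolutions, and applies the standard torch.nn.Conv2d output-size formula to show how in_freqs=257 bins could reduce to a single bin. The helpers conv_out_size and freq_sizes are hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one axis of a convolution (torch.nn.Conv2d formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def freq_sizes(in_freqs: int = 257, kernel: int = 3) -> list[int]:
    """Frequency sizes after each assumed encoder layer.

    Assumption: one stride-1 layer without frequency padding, then stride-2
    layers until at most one frequency bin remains.
    """
    sizes = [in_freqs]
    size = conv_out_size(in_freqs, kernel, stride=1)
    sizes.append(size)
    while size > 1:
        size = conv_out_size(size, kernel, stride=2)
        sizes.append(size)
    return sizes


print(freq_sizes())  # [257, 255, 127, 63, 31, 15, 7, 3, 1]
```

Under these assumptions, 257 bins collapse exactly to one, whereas a value such as 256 does not (it ends at 2 and then 0).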
####### Examples
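>>> import torch
>>> from espnet2.enh.layers.tcndenseunet import TCNDenseUNet
>>> model = TCNDenseUNet(n_spk=2, in_freqs=257, mic_channels=1)
>>> # complex STFT input, shape [B, T, C, F] = [8, 100, 1, 257]
>>> input_tensor = torch.randn(8, 100, 1, 257, dtype=torch.cfloat)
>>> output = model(input_tensor)
>>> print(output.shape)  # complex-valued estimate; layout as documented for forward()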
forward(tf_rep)
Forward pass through the TCNDenseUNet model.
- Parameters: tf_rep (torch.Tensor) – A 4D tensor containing the multi-channel complex Short-Time Fourier Transform (STFT) of the mixture, with shape [B, T, C, F], where:
- B is the batch size,
- T is the number of frames,
- C is the number of microphone channels,
- F is the number of frequencies.
- Returns: A complex 3D tensor containing the monaural STFT of the targets, with shape [B, T, F], where:
- B is the batch size,
- T is the number of frames,
- F is the number of frequencies.
- Return type: out (torch.Tensor)
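A multi-channel waveform must first be converted into the [B, T, C, F] layout described above. The sketch below computes the expected shape for a onesided STFT using the frame-count convention of torch.stft with center=True. The values n_fft=512 and hop_length=128 are example settings (an assumption), chosen because 512 // 2 + 1 = 257 matches the default in_freqs. The helper expected_tf_rep_shape is hypothetical.

```python
def expected_tf_rep_shape(
    batch: int, num_samples: int, mic_channels: int, n_fft: int = 512, hop_length: int = 128
) -> tuple[int, int, int, int]:
    """Expected [B, T, C, F] shape of a onesided STFT computed with center=True.

    n_fft and hop_length are example values, not settings taken from this page.
    """
    padded = num_samples + 2 * (n_fft // 2)  # center=True pads n_fft // 2 samples on each side
    frames = 1 + (padded - n_fft) // hop_length
    freqs = n_fft // 2 + 1  # onesided spectrum: 512 -> 257, the default in_freqs
    return (batch, frames, mic_channels, freqs)


print(expected_tf_rep_shape(batch=8, num_samples=16000, mic_channels=1))  # (8, 126, 1, 257)
```

In PyTorch, calling torch.stft with return_complex=True on a [C, samples] waveform returns a [C, F, T] tensor, so `stft.permute(2, 0, 1).unsqueeze(0)` produces the [1, T, C, F] layout expected here.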
####### Examples
>>> import torch
>>> model = TCNDenseUNet(n_spk=2, in_freqs=257, mic_channels=1)
>>> input_tensor = torch.randn(8, 100, 1, 257, dtype=torch.cfloat)  # complex [B, T, C, F]
>>> output = model(input_tensor)
>>> print(output.shape)
NOTE
The input must be a complex-valued tensor of shape [B, T, C, F], with C equal to mic_channels and F equal to in_freqs.
- Raises:
- AssertionError – If the number of microphone channels in the input tensor does not match mic_channels.
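The documented assertion covers only the microphone count, so a pre-flight check can catch layout mistakes before the model runs. The helper check_tf_rep_shape below is hypothetical. Its microphone check mirrors the documented AssertionError, while the dimension and frequency checks are extra assumptions rather than behavior documented on this page.

```python
from typing import Sequence


def check_tf_rep_shape(shape: Sequence[int], mic_channels: int, in_freqs: int) -> None:
    """Hypothetical helper: validate a [B, T, C, F] shape before calling forward().

    The microphone check mirrors the documented AssertionError; the rank and
    frequency checks are extra assumptions.
    """
    assert len(shape) == 4, f"expected a 4D [B, T, C, F] shape, got {len(shape)} dims"
    _, _, mics, freqs = shape
    assert mics == mic_channels, f"got {mics} microphone channels, expected {mic_channels}"
    assert freqs == in_freqs, f"got {freqs} frequencies, expected {in_freqs}"


check_tf_rep_shape((8, 100, 1, 257), mic_channels=1, in_freqs=257)  # passes silently
```

Because torch.Size is a tuple subclass, a tensor's `.shape` can be passed to the helper directly.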