espnet2.enh.layers.tcn.TemporalConvNet
class espnet2.enh.layers.tcn.TemporalConvNet(N, B, H, P, X, R, C, Sc=None, out_channel=None, norm_type='gLN', causal=False, pre_mask_nonlinear='linear', mask_nonlinear='relu')
Bases: Module
Temporal Convolutional Network for speech separation.
This class implements the Temporal Convolutional Network (TCN) proposed in Luo et al., "Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation." The architecture performs speech separation by estimating masks for the individual speakers from a mixture of audio signals.
C
Number of speakers.
mask_nonlinear
Non-linear function used to generate masks.
skip_connection
Boolean indicating if skip connections are used.
out_channel
Number of output channels.
- Parameters:
- N – Number of filters in the autoencoder.
- B – Number of channels in the bottleneck 1x1-conv block.
- H – Number of channels in convolutional blocks.
- P – Kernel size in convolutional blocks.
- X – Number of convolutional blocks in each repeat.
- R – Number of repeats.
- C – Number of speakers.
- Sc – Number of channels in skip-connection paths’ 1x1-conv blocks.
- out_channel – Number of output channels; if None, N is used.
- norm_type – Type of normalization; options include “BN”, “gLN”, “cLN”.
- causal – Boolean indicating if the model is causal.
- pre_mask_nonlinear – Non-linear function before mask generation.
- mask_nonlinear – Non-linear function to generate the mask.
- Returns: Estimated masks for the speakers.
- Return type: est_mask
- Raises: ValueError – If an unsupported mask non-linear function is provided.
####### Examples
>>> model = TemporalConvNet(N=64, B=32, H=128, P=3, X=4, R=2, C=2)
>>> mixture_w = torch.randn(10, 64, 100) # Batch size 10, 64 channels, 100 length
>>> est_mask = model(mixture_w)
>>> print(est_mask.shape) # Should output: torch.Size([10, 2, 64, 100])
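The parameters X, R, and P jointly determine how much temporal context the network sees: each repeat stacks X blocks whose dilations grow as 2^0, 2^1, …, 2^(X-1). A rough sketch of the receptive-field arithmetic, assuming one dilated convolution of kernel size P per block (a simplification of the actual block, which also contains 1x1 convolutions that do not widen the context):

```python
def receptive_field(P: int, X: int, R: int) -> int:
    """Frames of context seen by a stack of dilated 1-D convolutions.

    Assumes one dilated conv (kernel size P, dilation 2**x) per block,
    X blocks per repeat, R repeats -- the layout described above.
    """
    rf = 1
    for _ in range(R):
        for x in range(X):
            rf += (P - 1) * 2 ** x  # each block widens the context window
    return rf

# With the example configuration above (P=3, X=4, R=2):
print(receptive_field(P=3, X=4, R=2))  # 61 frames of context
```

Doubling R roughly doubles the added context, while increasing X by one roughly doubles the contribution of a single repeat, which is why deep dilation stacks are the cheaper way to grow the receptive field.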
forward(mixture_w)
Perform forward pass of the Temporal Convolutional Network.
This method processes the input mixture of audio signals and estimates the masks for each speaker using the temporal convolutional network. The input is expected to be a tensor of shape [M, N, K], where:
- M is the batch size,
- N is the number of input channels (filters), and
- K is the sequence length.
- Parameters:
  - mixture_w (torch.Tensor) – A tensor of shape [M, N, K], where M is the batch size, N is the number of input channels, and K is the sequence length.
- Returns: A tensor of shape [M, C, N, K] representing the estimated masks for each speaker, where C is the number of speakers.
- Return type: torch.Tensor
- Raises: ValueError – If an unsupported mask non-linear function is specified.
####### Examples
>>> model = TemporalConvNet(N=64, B=32, H=64, P=3, X=4, R=2, C=2)
>>> mixture = torch.randn(10, 64, 100) # Example input
>>> estimated_masks = model.forward(mixture)
>>> print(estimated_masks.shape) # Output: torch.Size([10, 2, 64, 100])
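The returned masks are typically applied multiplicatively to the encoder output to obtain one representation per speaker. A minimal sketch of that step in plain Python, with shapes collapsed to a single time–frequency bin for clarity (the raw scores below are made up for illustration; this is not an ESPnet API call). It also shows why `mask_nonlinear='softmax'` is a natural choice: the C masks then sum to 1 and partition the mixture energy exactly.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of raw mask scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical raw scores for C=2 speakers at one encoder bin
scores = [1.2, -0.3]
masks = softmax(scores)

mixture_value = 0.8                                  # one entry of mixture_w
separated = [m * mixture_value for m in masks]       # per-speaker estimates

# With softmax masks, the per-speaker estimates sum back to the mixture.
assert abs(sum(masks) - 1.0) < 1e-9
assert abs(sum(separated) - mixture_value) < 1e-9
```

With `mask_nonlinear='relu'` (the default here), the masks are unconstrained non-negative gains, so the estimates need not sum back to the mixture; the separation loss has to enforce consistency instead.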