espnet2.enh.separator.tcn_separator.TCNSeparator
class espnet2.enh.separator.tcn_separator.TCNSeparator(input_dim: int, num_spk: int = 2, predict_noise: bool = False, layer: int = 8, stack: int = 3, bottleneck_dim: int = 128, hidden_dim: int = 512, kernel: int = 3, causal: bool = False, norm_type: str = 'gLN', nonlinear: str = 'relu', pre_mask_nonlinear: str = 'prelu', masking: bool = True)
Bases: AbsSeparator
TCNSeparator is a temporal convolutional network (TCN)-based separator for speech enhancement. It uses a TemporalConvNet to separate the signals of multiple speakers in the input mixture and can optionally estimate the noise as well.
num_spk
Number of speakers in the input audio.
- Type: int
predict_noise
Flag to indicate whether to output the estimated noise signal.
- Type: bool
masking
Flag to choose between masking and mapping-based methods.
- Type: bool
Parameters:
- input_dim (int) – Input feature dimension.
- num_spk (int, optional) – Number of speakers (default is 2).
- predict_noise (bool, optional) – Whether to output the estimated noise signal (default is False).
- layer (int, optional) – Number of layers in each stack (default is 8).
- stack (int, optional) – Number of stacks (default is 3).
- bottleneck_dim (int, optional) – Bottleneck dimension (default is 128).
- hidden_dim (int, optional) – Number of convolution channels (default is 512).
- kernel (int, optional) – Kernel size (default is 3).
- causal (bool, optional) – Whether to use causal convolutions (default is False).
- norm_type (str, optional) – Normalization type, one of ‘BN’, ‘gLN’, ‘cLN’ (default is ‘gLN’).
- nonlinear (str, optional) – Nonlinear function for mask estimation, one of ‘relu’, ‘tanh’, ‘sigmoid’, ‘linear’ (default is ‘relu’).
- pre_mask_nonlinear (str, optional) – Nonlinear function applied before the mask network (default is ‘prelu’).
- masking (bool, optional) – Whether to use the masking-based (True) or mapping-based (False) method (default is True).
Raises: ValueError – If the specified nonlinear function is not supported.
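The layer, stack, and kernel parameters together determine how many frames the TCN can see at once, which is presumably related to the receptive_field length of the forward_streaming buffer. The following sketch computes that receptive field under an assumed Conv-TasNet-style schedule in which layer x of every stack uses dilation 2**x; whether ESPnet's TemporalConvNet uses exactly this schedule is an assumption, and tcn_receptive_field is a hypothetical helper, not part of ESPnet.

```python
def tcn_receptive_field(layer: int = 8, stack: int = 3, kernel: int = 3) -> int:
    """Receptive field, in frames, of a Conv-TasNet-style TCN (hypothetical helper).

    Assumption: within each stack, layer x (x = 0 .. layer-1) is a dilated
    convolution with dilation 2**x, and the whole stack is repeated `stack` times.
    """
    field = 1  # a single frame before any convolution
    for _ in range(stack):
        for x in range(layer):
            dilation = 2 ** x
            # a kernel of size k with dilation d widens the context by (k - 1) * d
            field += (kernel - 1) * dilation
    return field


if __name__ == "__main__":
    print(tcn_receptive_field())  # 1531 with layer=8, stack=3, kernel=3
```

Under this assumption, increasing layer grows the receptive field exponentially, whereas increasing stack grows it only linearly.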
######### Examples
>>> import torch
>>> from espnet2.enh.separator.tcn_separator import TCNSeparator
>>> # Initialize TCNSeparator
>>> separator = TCNSeparator(input_dim=256, num_spk=2)
>>> # Forward pass with input tensor and lengths
>>> input_tensor = torch.randn(4, 100, 256)  # [Batch, Time, Feature]
>>> ilens = torch.tensor([100, 100, 100, 100])  # input lengths
>>> masked, ilens, others = separator(input_tensor, ilens)
>>> # Forward streaming
>>> streaming_output, buffer, others_streaming = separator.forward_streaming(
...     input_frame=torch.randn(4, 1, 256)  # [Batch, 1, Feature]
... )
Temporal Convolution Separator
- Parameters:
- input_dim – input feature dimension
- num_spk – number of speakers
- predict_noise – whether to output the estimated noise signal
- layer – int, number of layers in each stack.
- stack – int, number of stacks
- bottleneck_dim – bottleneck dimension
- hidden_dim – number of convolution channels
- kernel – int, kernel size.
- causal – bool, whether to use causal convolutions, default False.
- norm_type – str, choose from ‘BN’, ‘gLN’, ‘cLN’
- nonlinear – the nonlinear function for mask estimation, select from ‘relu’, ‘tanh’, ‘sigmoid’, ‘linear’
- pre_mask_nonlinear – the non-linear function before masknet
- masking – whether to use the masking or mapping based method
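The nonlinear argument selects the element-wise function that turns the network's raw scores into mask values, and any name outside the supported set is rejected with a ValueError. The following stdlib-only sketch mirrors that behaviour on a plain list of scores; make_mask_nonlinear and apply_mask_nonlinear are hypothetical helpers, not ESPnet functions, and the separator itself presumably applies the corresponding PyTorch operations to tensors.

```python
import math
from typing import Callable, Dict, List


def _sigmoid(v: float) -> float:
    # numerically stable logistic function
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


# Hypothetical lookup table covering the documented choices for `nonlinear`.
MASK_NONLINEARS: Dict[str, Callable[[float], float]] = {
    "sigmoid": _sigmoid,
    "relu": lambda v: max(0.0, v),
    "tanh": math.tanh,
    "linear": lambda v: v,
}


def make_mask_nonlinear(name: str) -> Callable[[float], float]:
    """Return the element-wise mask function, or raise ValueError if unsupported."""
    if name not in MASK_NONLINEARS:
        raise ValueError(f"Unsupported nonlinear: {name!r}")
    return MASK_NONLINEARS[name]


def apply_mask_nonlinear(name: str, scores: List[float]) -> List[float]:
    """Apply the selected function to every raw score."""
    fn = make_mask_nonlinear(name)
    return [fn(s) for s in scores]


if __name__ == "__main__":
    print(apply_mask_nonlinear("relu", [-1.0, 0.5]))  # [0.0, 0.5]
    print(apply_mask_nonlinear("sigmoid", [0.0]))  # [0.5]
```

Validating the name once, when the function is looked up, means an unsupported choice fails immediately instead of on the first forward pass.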
forward(input: Tensor | ComplexTensor, ilens: Tensor, additional: Dict | None = None) → Tuple[List[Tensor | ComplexTensor], Tensor, OrderedDict]
Forward pass for the Temporal Convolution Separator.
This method processes the input features through the Temporal Convolution Network (TCN) to generate masked output signals for each speaker, along with additional predicted data, such as noise estimates if enabled.
- Parameters:
- input (Union[torch.Tensor, ComplexTensor]) – Encoded feature of shape [B, T, N], where B is the batch size, T is the number of frames, and N is the feature dimension. The input can be a real or a complex tensor.
- ilens (torch.Tensor) – Input lengths of shape [B], giving the actual length of each sequence in the batch.
- additional (Optional[Dict]) – Additional data for the model; not used in this implementation. Defaults to None.
- Returns:
- masked (List[Union[torch.Tensor, ComplexTensor]]): List of masked outputs, one per speaker, each of shape (B, T, N).
- ilens (torch.Tensor): Input lengths of shape (B,), returned as passed.
- others (OrderedDict): Other predicted data, e.g. the masks:
OrderedDict[
‘mask_spk1’: torch.Tensor(Batch, Frames, Freq),
‘mask_spk2’: torch.Tensor(Batch, Frames, Freq),
…
‘mask_spkn’: torch.Tensor(Batch, Frames, Freq),
]
- Return type: Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]
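With masking=True, each separated output is the input feature multiplied element-wise by one speaker's mask; with masking=False (mapping), the network output is, in the usual sense of mapping-based separation, taken directly as the estimate. The following stdlib-only sketch illustrates both cases for a single utterance stored as nested lists of shape [T][N], and builds an others dictionary with the documented ‘mask_spk1’, ‘mask_spk2’, … keys. The separate function is hypothetical; ComplexTensor handling and the optional noise output are deliberately left out.

```python
from collections import OrderedDict
from typing import List, Tuple

Matrix = List[List[float]]  # one utterance as [T][N] features


def separate(
    feature: Matrix, outputs: List[Matrix], masking: bool = True
) -> Tuple[List[Matrix], "OrderedDict[str, Matrix]"]:
    """Turn per-speaker network outputs into separated features (hypothetical helper).

    outputs[i] is the already-activated network output for speaker i and
    has the same [T][N] shape as `feature`.
    """
    if masking:
        # masking: multiply the input feature by each speaker's mask element-wise
        separated = [
            [[x * m for x, m in zip(f_row, m_row)] for f_row, m_row in zip(feature, out)]
            for out in outputs
        ]
    else:
        # mapping: the network output itself is the estimate for each speaker
        separated = list(outputs)
    # expose the raw outputs under the documented key pattern
    others = OrderedDict((f"mask_spk{i + 1}", out) for i, out in enumerate(outputs))
    return separated, others


if __name__ == "__main__":
    feature = [[1.0, 2.0], [3.0, 4.0]]
    masks = [[[0.5, 0.5], [1.0, 0.0]], [[0.5, 0.5], [0.0, 1.0]]]
    separated, others = separate(feature, masks)
    print(separated[0])  # [[0.5, 1.0], [3.0, 0.0]]
    print(list(others))  # ['mask_spk1', 'mask_spk2']
```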
######### Examples
Example of using the forward method
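>>> import torch
>>> from espnet2.enh.separator.tcn_separator import TCNSeparator
>>> separator = TCNSeparator(input_dim=128, num_spk=2)
>>> input_tensor = torch.randn(10, 50, 128)  # [Batch, Time, Feature]
>>> ilens = torch.tensor([50] * 10)  # one length per utterance
>>> masked, ilens, others = separator.forward(input_tensor, ilens)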
NOTE
Ensure that input has shape [B, T, N], with N equal to the input_dim given at construction, and that ilens has shape [B].
- Raises:
- ValueError – If the nonlinear activation function is not one of the supported types: ‘sigmoid’, ‘relu’, ‘tanh’, or ‘linear’.
forward_streaming(input_frame: Tensor, buffer=None)
Forward streaming for the Temporal Convolution Separator.
This method processes an input frame for real-time audio separation. It maintains a buffer to accommodate the temporal context required by the temporal convolution network. The input is rolled into the buffer to simulate streaming input.
- Parameters:
- input_frame (torch.Tensor) – The input audio frame of shape (B, 1, N) where B is the batch size and N is the number of features.
- buffer (torch.Tensor , optional) – The buffer containing past frames of shape (B, receptive_field, N). If None, a new buffer will be initialized.
- Returns: Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]
- masked (List[Union[torch.Tensor, ComplexTensor]]): List of masked output tensors, one per speaker.
- buffer (torch.Tensor): The updated buffer after processing the input frame.
- others (OrderedDict): Additional outputs, such as the mask for each speaker.
######### Examples
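>>> import torch
>>> from espnet2.enh.separator.tcn_separator import TCNSeparator
>>> separator = TCNSeparator(input_dim=64, num_spk=2)
>>> input_frame = torch.randn(4, 1, 64)  # [Batch, 1, Feature], batch size of 4
>>> masked_output, updated_buffer, additional_outputs = (
...     separator.forward_streaming(input_frame)
... )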
NOTE
This method relies on the TemporalConvNet created when the separator is constructed, and it requires the input frame to have shape (B, 1, N), with N equal to input_dim.
- Raises: ValueError – If the input_frame does not have the correct shape.
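forward_streaming keeps a fixed-length history of past frames: each call shifts the buffer by one frame and places the new frame at the end, so the temporal convolution network always sees the context it needs. The following stdlib-only sketch illustrates only that buffer update for a single batch element. It assumes a zero-initialised buffer of receptive_field frames when none is given; push_frame is a hypothetical helper, and the network call itself is omitted.

```python
from typing import List, Optional

Frame = List[float]  # one frame of N features


def push_frame(
    frame: Frame, buffer: Optional[List[Frame]], receptive_field: int
) -> List[Frame]:
    """Shift a history buffer left by one frame and append `frame` (hypothetical helper)."""
    if buffer is None:
        # assumption: a fresh buffer starts as receptive_field frames of zeros
        buffer = [[0.0] * len(frame) for _ in range(receptive_field)]
    if len(buffer) != receptive_field:
        raise ValueError("buffer must hold exactly receptive_field frames")
    # drop the oldest frame and append the newest one, which matches rolling
    # the buffer by -1 along time and then overwriting the last position
    return buffer[1:] + [list(frame)]


if __name__ == "__main__":
    buffer = None
    for t in range(4):
        buffer = push_frame([float(t), float(t)], buffer, receptive_field=3)
    print(buffer)  # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
```

Returning the updated buffer instead of mutating hidden state lets the caller pass it back on the next call, matching the buffer argument and return value documented above.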
property num_spk