espnet2.enh.separator.uses_separator.USESSeparator
class espnet2.enh.separator.uses_separator.USESSeparator(input_dim: int, num_spk: int = 2, enc_channels: int = 256, bottleneck_size: int = 64, num_blocks: int = 6, num_spatial_blocks: int = 3, ref_channel: int | None = None, segment_size: int = 64, memory_size: int = 20, memory_types: int = 1, rnn_type: str = 'lstm', bidirectional: bool = True, hidden_size: int = 128, att_heads: int = 4, dropout: float = 0.0, norm_type: str = 'cLN', activation: str = 'relu', ch_mode: str | List[str] = 'att', ch_att_dim: int = 256, eps: float = 1e-05, additional: dict = {})
Bases: AbsSeparator
Unconstrained Speech Enhancement and Separation (USES) Network.
This class implements the USES architecture for speech enhancement and separation tasks. It is designed to handle various input conditions and is capable of separating multiple speakers from a mixture.
Reference: [1] W. Zhang, K. Saijo, Z.-Q. Wang, S. Watanabe, and Y. Qian, “Toward Universal Speech Enhancement for Diverse Input Conditions,” in Proc. ASRU, 2023.
Examples:

```python
separator = USESSeparator(input_dim=80, num_spk=2)
input_tensor = torch.randn(10, 64, 80)  # example input, shape (B, T, F)
ilens = torch.tensor([64] * 10)  # example input lengths
outputs, lengths, others = separator(input_tensor, ilens)
```
- Parameters:
input_dim (int) – input feature dimension. Not used as the model is independent of the input size.
num_spk (int) – number of speakers.
enc_channels (int) – feature dimension after the Conv1D encoder.
bottleneck_size (int) – dimension of the bottleneck feature. Must be a multiple of att_heads.
num_blocks (int) – number of processing blocks.
num_spatial_blocks (int) – number of processing blocks with channel modeling.
ref_channel (Optional[int]) – index of the reference channel (used in channel modeling modules).
segment_size (int) – number of frames in each non-overlapping segment. This is used to segment long utterances into smaller chunks for efficient processing.
memory_size (int) – group size of global memory tokens. The basic use of memory tokens is to store the history information from previous segments. The memory tokens are updated by the output of the last block after processing each segment.
memory_types (int) – number of memory token groups. Each group corresponds to a different type of processing, e.g., the first group is used for denoising without dereverberation, and the second group is used for denoising with dereverberation.
rnn_type (str) – select from ‘RNN’, ‘LSTM’, and ‘GRU’.
bidirectional (bool) – whether the inter-chunk RNN layers are bidirectional.
hidden_size (int) – dimension of the hidden state.
att_heads (int) – number of attention heads.
dropout (float) – dropout ratio. Default is 0.
norm_type (str) – type of normalization to use after each inter- or intra-chunk NN block.
activation (str) – the nonlinear activation function.
ch_mode (str or List[str]) – mode of channel modeling. Select from “att” and “tac”.
ch_att_dim (int) – dimension of the channel attention.
eps (float) – epsilon for layer normalization.
additional (dict) – additional parameters for flexibility during inference.
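To make the interaction between these arguments concrete, here is a minimal configuration sketch for a hypothetical two-speaker, multi-channel setup. All keyword arguments come from the constructor signature above; the specific values are illustrative choices, not recommended defaults.

```python
import torch
from espnet2.enh.separator.uses_separator import USESSeparator

# A hypothetical multi-channel configuration; input_dim is unused by the
# model but still required by the constructor signature.
separator = USESSeparator(
    input_dim=256,
    num_spk=2,
    enc_channels=256,
    bottleneck_size=64,    # must be a multiple of att_heads (64 = 4 * 16)
    num_blocks=6,
    num_spatial_blocks=3,  # the first 3 blocks also model the channel dimension
    ref_channel=0,         # enhance with respect to the first microphone
    segment_size=64,       # process long inputs in 64-frame segments
    memory_size=20,
    memory_types=2,        # separate memory groups for w/ and w/o dereverberation
    ch_mode="tac",         # transform-average-concatenate channel modeling
)
```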
forward(input: Tensor | ComplexTensor, ilens: Tensor, additional: Dict | None = None) → Tuple[List[Tensor | ComplexTensor], Tensor, OrderedDict]
Performs the forward pass of the USESSeparator model.
This method processes the input STFT spectrum to separate the sources and produce enhanced audio signals.
Parameters:
input (torch.Tensor or ComplexTensor) –
STFT spectrum with shape [B, T, (C,) F (,2)], where:
- B is the batch size,
- T is the number of time frames,
- C is the number of microphone channels (optional),
- F is the number of frequency bins,
- 2 corresponds to the real and imaginary parts (optional if the input is a complex tensor).
ilens (torch.Tensor) – Input lengths with shape [Batch].
additional (Dict or None) –
Additional data included in the model. It can contain:
- “mode”: one of (“no_dereverb”, “dereverb”, “both”):
- “no_dereverb”: Use only the first memory group for denoising without dereverberation.
- “dereverb”: Use only the second memory group for denoising with dereverberation.
- “both”: Use both memory groups for denoising with and without dereverberation.
Returns:
masked (List[Union[torch.Tensor, ComplexTensor]]): List of tensors with shape [(B, T, F), …] for each speaker.
ilens (torch.Tensor): The input lengths tensor with shape (B,).
others (OrderedDict): A dictionary containing predicted data, e.g., masks:
OrderedDict[
‘mask_spk1’: torch.Tensor(Batch, Frames, Freq),
‘mask_spk2’: torch.Tensor(Batch, Frames, Freq),
…
‘mask_spkn’: torch.Tensor(Batch, Frames, Freq),
]
Examples
>>> separator = USESSeparator(input_dim=256)
>>> input_tensor = torch.randn(8, 100, 2, 256) # Example input
>>> ilens = torch.tensor([100] * 8) # Example input lengths
>>> outputs, lengths, additional_outputs = separator.forward(input_tensor, ilens)
>>> print(len(outputs)) # Number of separated sources
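The `additional` argument selects the memory group at inference time. A minimal sketch, assuming a separator constructed with memory_types=2 so that a dereverberation memory group exists (the multi-channel configuration sketched after the constructor arguments above qualifies):

>>> mixture = torch.randn(8, 100, 2, 256)  # [B, T, C, F] multi-channel input
>>> ilens = torch.tensor([100] * 8)
>>> outputs, lengths, others = separator(mixture, ilens, additional={"mode": "dereverb"})
>>> # Masks, if predicted, are keyed 'mask_spk1', ..., per the Returns section above.
>>> masks = [others[f"mask_spk{i + 1}"] for i in range(separator.num_spk)]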
- Raises: ValueError – If the input shape is invalid or if an unknown mode is provided in additional.
property num_spk
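A small usage sketch, reusing `separator` and `outputs` from the examples above: the property lets downstream code iterate over the per-speaker estimates without consulting the configuration.

```python
# Iterate over the separated outputs using the num_spk property.
for spk_idx in range(separator.num_spk):
    print(f"speaker {spk_idx + 1}: {outputs[spk_idx].shape}")
```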