espnet2.enh.separator.asteroid_models.AsteroidModel_Converter
class espnet2.enh.separator.asteroid_models.AsteroidModel_Converter(encoder_output_dim: int, model_name: str, num_spk: int, pretrained_path: str = '', loss_type: str = 'si_snr', **model_related_kwargs)
Bases: AbsSeparator
Convert models from Asteroid to AbsSeparator for speech separation.
This class wraps an Asteroid model so that it can be used wherever ESPnet expects an AbsSeparator. It either builds an Asteroid model from the given keyword arguments or loads a pretrained one from the Hugging Face Hub, and it operates directly on raw waveforms.
- Attributes:
- model – The instantiated Asteroid model.
- _num_spk – The number of speakers to separate.
- loss_type – The type of loss used for model training.
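The default loss type, "si_snr", refers to the scale-invariant signal-to-noise ratio. As an illustration only, the sketch below computes the commonly used SI-SNR definition on plain Python lists: both signals are made zero-mean, the estimate is projected onto the reference, and the energy of that projection is compared with the energy of the residual, in dB. The function names and the `eps` guard are assumptions of this sketch; ESPnet's own loss implementation may differ in such details.

```python
import math
from typing import Sequence


def si_snr(estimate: Sequence[float], reference: Sequence[float], eps: float = 1e-8) -> float:
    """Scale-invariant SNR (dB) of an estimated signal against a reference signal."""
    if len(estimate) != len(reference):
        raise ValueError("estimate and reference must have the same length")
    # Remove the DC offset of both signals.
    est_mean = sum(estimate) / len(estimate)
    ref_mean = sum(reference) / len(reference)
    est = [x - est_mean for x in estimate]
    ref = [x - ref_mean for x in reference]
    # Project the estimate onto the reference: s_target = (<est, ref> / ||ref||^2) * ref
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref)
    scale = dot / (ref_energy + eps)
    s_target = [scale * r for r in ref]
    # Whatever the scaled reference does not explain counts as error.
    e_noise = [e - t for e, t in zip(est, s_target)]
    target_energy = sum(t * t for t in s_target)
    noise_energy = sum(n * n for n in e_noise)
    return 10.0 * math.log10((target_energy + eps) / (noise_energy + eps))


def si_snr_loss(estimate: Sequence[float], reference: Sequence[float]) -> float:
    """Training loss: maximizing SI-SNR is the same as minimizing its negative."""
    return -si_snr(estimate, reference)
```

Because the projection absorbs any gain, multiplying the estimate by a constant leaves the score essentially unchanged, which is what makes the measure scale-invariant.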
- Parameters:
- encoder_output_dim (int) – Input feature dimension; must be 1, i.e. the raw-waveform output of the NullEncoder.
- model_name (str) – Asteroid model name, e.g., ConvTasNet, DPTNet. Refer to: https://github.com/asteroid-team/asteroid/blob/master/asteroid/models/__init__.py
- num_spk (int) – Number of speakers to separate.
- pretrained_path (str, optional) – Name of a pretrained Asteroid model on the Hugging Face Hub. If empty (the default), the model is built from model_related_kwargs instead. Refer to: https://github.com/asteroid-team/asteroid/blob/master/docs/source/readmes/pretrained_models.md and https://huggingface.co/models?filter=asteroid
- loss_type (str, optional) – Loss type for enhancement. Only "si_snr" (the default) is supported.
- **model_related_kwargs – Additional arguments specific to each Asteroid model.
- Raises:
- AssertionError – If encoder_output_dim is not 1.
- ValueError – If an unsupported loss type is specified.
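The checks listed under Raises, together with the choice between a pretrained and a freshly built model, can be sketched in plain Python. This is a hypothetical mock, not ESPnet code: `FakeConvTasNet` and `MODEL_REGISTRY` stand in for a model class in Asteroid's `asteroid.models` module, `ConverterSketch` is an illustrative name, and mapping the string "None" to `None` in `model_related_kwargs` is an assumption about how configuration values are cleaned up.

```python
from typing import Any, Dict


class FakeConvTasNet:
    """Hypothetical stand-in for an Asteroid model class; not the real Asteroid API."""

    def __init__(self, **kwargs: Any) -> None:
        # Record the keyword arguments so the sketch can show what was passed.
        self.kwargs = kwargs

    @classmethod
    def from_pretrained(cls, name: str) -> "FakeConvTasNet":
        # A real model would download weights here; this stub only records the name.
        return cls(pretrained=name)


# Hypothetical registry playing the role of the `asteroid.models` module.
MODEL_REGISTRY: Dict[str, type] = {"ConvTasNet": FakeConvTasNet}


class ConverterSketch:
    """Hypothetical mock of the constructor logic described above."""

    def __init__(
        self,
        encoder_output_dim: int,
        model_name: str,
        num_spk: int,
        pretrained_path: str = "",
        loss_type: str = "si_snr",
        **model_related_kwargs: Any,
    ) -> None:
        # Raw-waveform input: the (null) encoder must produce a single channel.
        assert encoder_output_dim == 1, encoder_output_dim
        # Assumption: the string "None" from a config file is mapped to a real None.
        kwargs = {k: None if v == "None" else v for k, v in model_related_kwargs.items()}
        model_cls = MODEL_REGISTRY[model_name]
        if pretrained_path:
            # In this sketch, keyword arguments are ignored when a pretrained name is given.
            self.model = model_cls.from_pretrained(pretrained_path)
        else:
            self.model = model_cls(**kwargs)
        self._num_spk = num_spk
        if loss_type != "si_snr":
            raise ValueError(f"Unsupported loss type: {loss_type}")
        self.loss_type = loss_type

    @property
    def num_spk(self) -> int:
        # Read-only access to the configured number of speakers.
        return self._num_spk
```

In this mock, passing `encoder_output_dim=2` trips the assertion and any `loss_type` other than "si_snr" raises `ValueError`, mirroring the Raises list.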
######### Examples
Instantiate the model with a pretrained ConvTasNet:
>>> import torch
>>> from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter
>>> model = AsteroidModel_Converter(
... model_name="ConvTasNet",
... encoder_output_dim=1,
... num_spk=2,
... loss_type="si_snr",
... pretrained_path="mpariente/ConvTasNet_WHAM!_sepclean"
... )
Process a mixture of audio signals:
>>> mixture = torch.randn(3, 16000)
>>> output, ilens, masks = model(mixture)
Access the number of speakers:
>>> num_speakers = model.num_spk
NOTE
Ensure that Asteroid is installed. Refer to: https://github.com/asteroid-team/asteroid
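One way to confirm this before constructing the converter is to look the package up with the standard library. The helper name `is_installed` is hypothetical; `importlib.util.find_spec` is the standard-library call that performs the lookup without importing the package.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if a top-level package can be found without importing it."""
    # find_spec returns None when no module with this top-level name is installed.
    return importlib.util.find_spec(package) is not None


if __name__ == "__main__":
    if not is_installed("asteroid"):
        print("Asteroid not found; install it before building AsteroidModel_Converter.")
```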
Class that converts Asteroid models into AbsSeparator instances.
- Parameters:
- encoder_output_dim – Input feature dimension; must be 1, the output of the NullEncoder.
- num_spk – Number of speakers.
- loss_type – Loss type for enhancement.
- model_name – Asteroid model name, e.g., ConvTasNet, DPTNet. Refer to: https://github.com/asteroid-team/asteroid/blob/master/asteroid/models/__init__.py
- pretrained_path – Name of a pretrained Asteroid model on the Hugging Face Hub. Refer to: https://github.com/asteroid-team/asteroid/blob/master/docs/source/readmes/pretrained_models.md and https://huggingface.co/models?filter=asteroid
- model_related_kwargs – Additional arguments for the specific Asteroid model.
forward(input: Tensor, ilens: Tensor | None = None, additional: Dict | None = None)
Run the forward pass of the wrapped Asteroid model.
This method takes raw waveform input and processes it through the model to estimate the source waveforms and additional output data such as masks for each speaker.
- Parameters:
- input (torch.Tensor) – Raw waveforms of shape [B, T] where B is the batch size and T is the length of the waveform.
- ilens (torch.Tensor) – Input lengths of shape [B]. This is optional and can be None.
- additional (Dict or None) – Other data passed to the separator; not used by this model.
- Returns: A tuple containing:
- estimated waveforms as a list of tensors [(B, T), …]
- input lengths as a tensor of shape (B,)
- additional predicted data (e.g., masks) as an OrderedDict with keys:
- 'mask_spk1': torch.Tensor(Batch, T)
- 'mask_spk2': torch.Tensor(Batch, T)
- …
- 'mask_spkn': torch.Tensor(Batch, T)
- Return type: Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]
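The relationship between the model output and the returned list and dictionary can be sketched with nested Python lists standing in for tensors. This assumes the wrapped Asteroid model emits a batch of shape [B, num_spk, T]; regrouping it by speaker plays the role of `est_source.transpose(0, 1)` on a tensor. The helper `split_speakers` is hypothetical, and filling each mask entry with the matching speaker's output is an assumption of the sketch rather than a documented guarantee.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple

Signal = List[float]


def split_speakers(
    est_source: List[List[Signal]], num_spk: int
) -> Tuple[List[List[Signal]], Dict[str, List[Signal]]]:
    """Regroup a [B, num_spk, T] batch into num_spk lists of [B, T] plus a keyed dict."""
    # Each batch item must hold exactly num_spk separated signals.
    assert all(len(item) == num_spk for item in est_source), "speaker dimension mismatch"
    # Index by speaker first, like transposing axes 0 and 1 of a tensor.
    per_speaker = [[item[s] for item in est_source] for s in range(num_spk)]
    # Keys follow the documented naming scheme mask_spk1 ... mask_spkN.
    masks = OrderedDict(
        (f"mask_spk{s + 1}", per_speaker[s]) for s in range(num_spk)
    )
    return per_speaker, masks
```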
######### Examples
>>> import torch
>>> from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter
>>> input_waveforms = torch.randn(3, 16000)  # Batch of 3 waveforms
>>> model = AsteroidModel_Converter(
... model_name="ConvTasNet",
... encoder_output_dim=1,
... num_spk=2,
... loss_type="si_snr",
... pretrained_path="mpariente/ConvTasNet_WHAM!_sepclean",
... )
>>> estimated_waveforms, lengths, masks = model.forward(input_waveforms)
NOTE
The input must be a raw waveform tensor of shape [B, T], and the model must be fully constructed before this method is called.
- Raises:
- AssertionError – If the speaker dimension of the estimated sources does not match num_spk.
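The check behind this assertion can be sketched on shapes alone. The sketch assumes that a batched [B, T] input yields model output of shape [B, num_spk, T] and that an unbatched [T] input yields [num_spk, T], so the speaker count is read from axis 1 or axis 0 respectively. Both helper names are hypothetical.

```python
from typing import Sequence


def speaker_axis(input_shape: Sequence[int]) -> int:
    """Axis of the model output expected to hold the speakers."""
    # Unbatched [T] -> [num_spk, T]; batched [B, T] -> [B, num_spk, T].
    return 0 if len(input_shape) == 1 else 1


def check_num_spk(input_shape: Sequence[int], output_shape: Sequence[int], num_spk: int) -> int:
    """Raise AssertionError if the output does not carry num_spk sources; return the count."""
    found = output_shape[speaker_axis(input_shape)]
    assert found == num_spk, f"expected {num_spk} speakers, got {found}"
    return found
```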
forward_rawwav(input: Tensor, ilens: Tensor | None = None) → Tuple[Tensor, Tensor]
Return the estimated waveforms.
This method passes the raw input waveforms through the model and returns the estimated source waveforms together with the input lengths.
- Parameters:
- input (torch.Tensor) – Raw waveforms with shape [B, T], where B is the batch size and T is the length of the waveforms.
- ilens (torch.Tensor , optional) – Input lengths with shape [B]. Defaults to None.
- Returns: A tuple containing:
- estimated waveforms (List[torch.Tensor]): A list with one tensor per speaker, each holding that speaker's estimated waveforms for the whole batch.
- ilens (torch.Tensor): The input lengths for the batch.
- Return type: Tuple[torch.Tensor, torch.Tensor]
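When the mixtures in a batch have different lengths, they have to be padded into a single [B, T] tensor, and `ilens` records the original lengths. A minimal sketch on plain lists, assuming right-padding with zeros; the helper `pad_batch` is hypothetical. Wrapping its two results with `torch.tensor(...)` would give the `input` and `ilens` arguments described above.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    waves: Sequence[Sequence[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[int]]:
    """Right-pad variable-length waveforms to a common length and return their lengths."""
    if not waves:
        raise ValueError("need at least one waveform")
    ilens = [len(w) for w in waves]
    max_len = max(ilens)
    # Append pad_value until every waveform reaches the longest length.
    batch = [list(w) + [pad_value] * (max_len - len(w)) for w in waves]
    return batch, ilens
```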
######### Examples
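>>> import torch
>>> from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter
>>> mixture = torch.randn(3, 16000)  # Example input: 3 waveforms of 16000 samples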
>>> model = AsteroidModel_Converter(
... model_name="ConvTasNet",
... encoder_output_dim=1,
... num_spk=2,
... loss_type="si_snr",
... pretrained_path="mpariente/ConvTasNet_WHAM!_sepclean",
... )
>>> output, ilens = model.forward_rawwav(mixture, torch.tensor([16000] * 3))[:2]  # estimates and lengths
>>> print(output[0].shape) # Output shape for speaker 1
property num_spk
Number of speakers to separate, as given by num_spk at construction.