espnet2.enh.separator.asteroid_models.AsteroidModel_Converter
class espnet2.enh.separator.asteroid_models.AsteroidModel_Converter(encoder_output_dim: int, model_name: str, num_spk: int, pretrained_path: str = '', loss_type: str = 'si_snr', **model_related_kwargs)
Bases: AbsSeparator
Wraps models from Asteroid so that they can be used as an ESPnet AbsSeparator.
Parameters:
- encoder_output_dim: input feature dimension; set to 1 when used after the NullEncoder
- num_spk: number of speakers
- loss_type: loss type used for enhancement (default 'si_snr')
- model_name: Asteroid model name, e.g. ConvTasNet or DPTNet. See https://github.com/asteroid-team/asteroid/blob/master/asteroid/models/__init__.py
- pretrained_path: name of a pretrained Asteroid model on the Hugging Face Hub. See https://github.com/asteroid-team/asteroid/blob/master/docs/source/readmes/pretrained_models.md and https://huggingface.co/models?filter=asteroid
- model_related_kwargs: additional keyword arguments for the specific Asteroid model.
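The parameters above suggest a simple dispatch: look up an Asteroid model class by `model_name`, then either load a pretrained checkpoint named by `pretrained_path` or build a fresh model for `num_spk` sources with `model_related_kwargs`. The sketch below mimics that dispatch with plain Python stand-ins so it runs without Asteroid installed. `FakeConvTasNet`, `fake_models`, `build_separator_model`, and the Hub name `"some-user/some-model"` are all hypothetical; the use of `from_pretrained(...)` versus a constructor taking `n_src` follows Asteroid's public model interface and is an assumption about how this converter uses it, not a confirmed detail of the ESPnet code.

```python
from types import SimpleNamespace


class FakeConvTasNet:
    """Hypothetical stand-in for an Asteroid model class such as ConvTasNet."""

    def __init__(self, n_src, **kwargs):
        self.n_src = n_src  # number of sources the model separates
        self.kwargs = kwargs  # model-specific options (model_related_kwargs)
        self.source = "scratch"  # records where the weights came from

    @classmethod
    def from_pretrained(cls, pretrained_path):
        # A real Asteroid model would load weights from the Hub here.
        # In this sketch the (fake) checkpoint fixes the number of sources.
        model = cls(n_src=2)
        model.source = pretrained_path
        return model


# Hypothetical stand-in for the `asteroid.models` module.
fake_models = SimpleNamespace(ConvTasNet=FakeConvTasNet)


def build_separator_model(model_name, num_spk, pretrained_path="", **model_related_kwargs):
    """Pick the model class by name; load pretrained weights or build from scratch."""
    model_cls = getattr(fake_models, model_name)  # AttributeError for unknown names
    if pretrained_path:
        return model_cls.from_pretrained(pretrained_path)
    return model_cls(n_src=num_spk, **model_related_kwargs)


# `n_filters` is used here only as an example model-specific keyword.
scratch = build_separator_model("ConvTasNet", num_spk=3, n_filters=256)
print(scratch.n_src, scratch.kwargs)  # 3 {'n_filters': 256}

pretrained = build_separator_model(
    "ConvTasNet", num_spk=2, pretrained_path="some-user/some-model"
)
print(pretrained.source)  # some-user/some-model
```

Resolving the class by name keeps the converter generic: any model exported by Asteroid's model module can be selected from a config string without an ESPnet-side wrapper for each architecture.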
forward(input: Tensor, ilens: Tensor = None, additional: Dict | None = None)
Forward pass through the whole Asteroid model.
Parameters:
- input (torch.Tensor): raw waveforms [B, T]
- ilens (torch.Tensor): input lengths [B]
- additional (Dict or None): other data used by the model
Returns:
- estimated waveforms (List[torch.Tensor]): [(B, T), ...], one entry per speaker
- ilens (torch.Tensor): (B,)
- other predicted data, e.g. masks: OrderedDict[
  'mask_spk1': torch.Tensor(Batch, T),
  'mask_spk2': torch.Tensor(Batch, T),
  ...
  'mask_spkn': torch.Tensor(Batch, T),
  ]
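The return structure can be illustrated with a small sketch. It assumes the wrapped Asteroid model yields estimates shaped [B, num_spk, T], which are split along the speaker axis into a list of num_spk items of shape (B, T) and also exposed under the keys mask_spk1, mask_spk2, and so on. `forward_sketch` is a hypothetical helper; nested Python lists stand in for `torch.Tensor` so the example runs without PyTorch, and returning `ilens` unchanged and reusing the separated signals as the dictionary values are assumptions of this sketch.

```python
from collections import OrderedDict


def forward_sketch(est_source, ilens, num_spk):
    """Split a [B, num_spk, T] estimate into num_spk items of shape [B, T].

    Nested lists stand in for torch tensors; with tensors, the same split is
    a transpose of the batch and speaker axes followed by iteration.
    """
    for utt in est_source:
        assert len(utt) == num_spk, "second axis must hold one entry per speaker"
    # Regroup by speaker: per_spk[s][b] is speaker s's signal in utterance b.
    per_spk = [[utt[spk] for utt in est_source] for spk in range(num_spk)]
    # Key each speaker's output as mask_spk1, mask_spk2, ...
    masks = OrderedDict((f"mask_spk{spk + 1}", wav) for spk, wav in enumerate(per_spk))
    # Assumption: separation preserves length, so ilens passes through unchanged.
    return per_spk, ilens, masks


# Batch of B=2 mixtures, num_spk=2 speakers, T=3 samples each.
est = [
    [[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]],
    [[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]],
]
waves, out_ilens, masks = forward_sketch(est, [3, 3], num_spk=2)
print(len(waves), len(waves[0]), len(waves[0][0]))  # 2 2 3
print(list(masks))  # ['mask_spk1', 'mask_spk2']
```

Returning a per-speaker list rather than a single stacked array matches the documented [(B, T), ...] shape, so downstream code can treat each speaker's estimate independently.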
forward_rawwav(input: Tensor, ilens: Tensor = None) → Tuple[Tensor, Tensor]
Return the separated output as waveforms.
property num_spk
