espnet2.enh.espnet_model.ESPnetEnhancementModel
class espnet2.enh.espnet_model.ESPnetEnhancementModel(encoder: AbsEncoder, separator: AbsSeparator | None, decoder: AbsDecoder, mask_module: AbsMask | None, loss_wrappers: List[AbsLossWrapper] | None, stft_consistency: bool = False, loss_type: str = 'mask_mse', mask_type: str | None = None, flexible_numspk: bool = False, extract_feats_in_collect_stats: bool = False, normalize_variance: bool = False, normalize_variance_per_ch: bool = False, categories: list = [], category_weights: list = [], always_forward_in_48k: bool = False)
Bases: AbsESPnetModel
Speech enhancement or separation Frontend model
Main entry of speech enhancement/separation model training.
- Parameters:
encoder – waveform encoder that converts waveforms to feature representations
separator – separator that enhances or separates the feature representations
decoder – waveform decoder that converts the features back to waveforms
mask_module – mask module that converts the features to masks NOTE: Only used for compatibility with joint speaker diarization. See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
loss_wrappers – list of loss wrappers. Each loss wrapper contains a criterion for loss calculation and the corresponding loss weight. The losses will be calculated in the order of the list and summed up.
------------------------------------------------------------------
stft_consistency – (deprecated, kept for compatibility) whether to compute the TF-domain loss while enforcing STFT consistency NOTE: STFT consistency is now always used for frequency-domain spectrum losses.
loss_type – (deprecated, kept for compatibility) loss type
mask_type – (deprecated, kept for compatibility) mask type in TF-domain model
------------------------------------------------------------------
flexible_numspk – whether to allow the model to predict a variable number of speakers in its output. NOTE: This should be used when training a speech separation model for an unknown number of speakers.
------------------------------------------------------------------
extract_feats_in_collect_stats – used in espnet2/tasks/abs_task.py for determining whether or not to skip model building in the collect_stats stage (stage 5 in egs2/*/enh1/enh.sh).
normalize_variance – whether to normalize the signal variance before the model forward, and revert it back afterwards.
normalize_variance_per_ch – whether to normalize the signal variance for each channel instead of the whole signal. NOTE: normalize_variance and normalize_variance_per_ch cannot be True at the same time.
------------------------------------------------------------------
categories – list of all possible categories of minibatches (order matters!) (e.g. ["1ch_8k_reverb", "1ch_8k_both"] for multi-condition training) NOTE: this will be used to convert a category index to the corresponding name for logging in forward_loss. Different categories will have different loss name suffixes.
category_weights – list of weights for each category. Used to set loss weights for batches of different categories.
------------------------------------------------------------------
always_forward_in_48k – whether to always upsample the input speech to 48 kHz for forward, and then downsample to the original sample rate for loss calculation. NOTE: this can be useful to train a model capable of handling various sampling rates while unifying bandwidth extension + speech enhancement.
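The normalize_variance option above can be illustrated with a minimal, self-contained sketch (plain Python, not ESPnet's actual implementation; function names are hypothetical): the mixture is scaled to unit standard deviation before the forward pass, and the scale factor is kept so the outputs can be rescaled back for loss calculation.

```python
import math

def normalize_variance(signal, eps=1e-8):
    # Scale the signal to (approximately) unit standard deviation and
    # return the std so the scaling can be reverted after the forward pass.
    mean = sum(signal) / len(signal)
    var = sum((x - mean) ** 2 for x in signal) / len(signal)
    std = math.sqrt(var) + eps  # eps guards against silent (all-zero) input
    return [x / std for x in signal], std

def revert_variance(signal, std):
    # Undo the normalization so the loss sees the original signal scale.
    return [x * std for x in signal]

mix = [0.5, -1.0, 2.0, -0.5]
normalized, std = normalize_variance(mix)
restored = revert_variance(normalized, std)
```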
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
forward(speech_mix: Tensor, speech_mix_lengths: Tensor = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calc loss
- Parameters:
- speech_mix – (Batch, samples) or (Batch, samples, channels)
- speech_ref – (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
- speech_mix_lengths – (Batch,), default None for the chunk iterator, because the chunk iterator does not return speech_lengths. See espnet2/iterators/chunk_iter_factory.py.
- kwargs – "utt_id" is among the input.
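The forward pipeline described above (Frontend + Encoder + Decoder + Calc loss) can be sketched schematically. This is a toy illustration, not ESPnet's actual code: the encoder, separator, decoder, and criteria are stand-in callables, and loss_wrappers is modeled as (weight, criterion) pairs summed in list order, as described for the loss_wrappers parameter.

```python
def toy_forward(speech_mix, speech_ref, encoder, separator, decoder, loss_wrappers):
    # Encode the mixture, separate it into per-speaker features,
    # decode each feature back to a waveform, then sum the weighted losses.
    feature_mix = encoder(speech_mix)
    feature_pre = separator(feature_mix)           # one feature per speaker
    speech_pre = [decoder(f) for f in feature_pre]
    loss = 0.0
    for weight, criterion in loss_wrappers:        # applied in list order
        loss += weight * criterion(speech_pre, speech_ref)
    return speech_pre, loss

# Toy single-speaker "model": identity encoder/decoder, pass-through
# separator, and an L1 criterion on the only estimate.
encoder = lambda x: x
separator = lambda f: [f]
decoder = lambda f: f
l1 = lambda pre, ref: sum(abs(a - b) for a, b in zip(pre[0], ref))

speech_pre, loss = toy_forward(
    [1.0, 2.0], [1.0, 1.0], encoder, separator, decoder, [(0.5, l1)]
)
```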
forward_enhance(speech_mix: Tensor, speech_lengths: Tensor, additional: Dict | None = None, fs: int | None = None) → Tuple[Tensor, Tensor, Tensor]
forward_loss(speech_pre: Tensor, speech_lengths: Tensor, feature_mix: Tensor, feature_pre: List[Tensor], others: OrderedDict, speech_ref: List[Tensor], noise_ref: List[Tensor] | None = None, dereverb_speech_ref: List[Tensor] | None = None, category: Tensor | None = None, num_spk: int | None = None, fs: int | None = None) → Tuple[Tensor, Dict[str, Tensor], Tensor]
static sort_by_perm(nn_output, perm)
Sort the input list of tensors by the specified permutation.
- Parameters:
- nn_output – List[torch.Tensor(Batch, …)], len(nn_output) == num_spk
- perm – (Batch, num_spk) or List[torch.Tensor(num_spk)]
- Returns: List[torch.Tensor(Batch, …)]
- Return type: nn_output_new
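A minimal sketch of the sort_by_perm logic using plain Python lists in place of tensors (the indexing convention is assumed for illustration): nn_output holds one per-speaker batch each, and perm[b] gives, for each batch item b, which source output should fill each target speaker slot.

```python
def sort_by_perm(nn_output, perm):
    # nn_output: [num_spk][batch] list of per-speaker batches.
    # perm: [batch][num_spk]; perm[b][s] is the source index for slot s.
    num_spk = len(nn_output)
    batch = len(perm)
    # For each target slot s, pick output perm[b][s] of batch item b.
    return [
        [nn_output[perm[b][s]][b] for b in range(batch)]
        for s in range(num_spk)
    ]

# Two speakers, batch of 2; batch item 0 keeps its order, item 1 swaps.
nn_output = [["a0", "a1"], ["b0", "b1"]]   # [speaker][batch]
perm = [[0, 1], [1, 0]]                    # [batch][speaker]
sorted_out = sort_by_perm(nn_output, perm)
# sorted_out == [["a0", "b1"], ["b0", "a1"]]
```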
