espnet2.enh.diffusion_enh.ESPnetDiffusionModel
class espnet2.enh.diffusion_enh.ESPnetDiffusionModel(encoder: AbsEncoder, diffusion: AbsDiffusion, decoder: AbsDecoder, num_spk: int = 1, normalize: bool = False, **kwargs)
Bases: ESPnetEnhancementModel
Diffusion-based speech enhancement model
Main entry of speech enhancement/separation model training.
- Parameters:
encoder: waveform encoder that converts waveforms to feature representations
separator: separator that enhances or separates the feature representations
decoder: waveform decoder that converts the features back to waveforms
mask_module: mask module that converts the features to masks. NOTE: Only used for compatibility with joint speaker diarization. See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
loss_wrappers: list of loss wrappers. Each loss wrapper contains a criterion for loss calculation and the corresponding loss weight. The losses are calculated in the order of the list and summed up.
------------------------------------------------------------------
stft_consistency: (deprecated, kept for compatibility) whether to compute the TF-domain loss while enforcing STFT consistency. NOTE: STFT consistency is now always used for frequency-domain spectrum losses.
loss_type: (deprecated, kept for compatibility) loss type
mask_type: (deprecated, kept for compatibility) mask type in TF-domain model
------------------------------------------------------------------
flexible_numspk: whether to allow the model to predict a variable number of speakers in its output. NOTE: This should be used when training a speech separation model for an unknown number of speakers.
------------------------------------------------------------------
extract_feats_in_collect_stats: used in espnet2/tasks/abs_task.py to determine whether to skip model building in the collect_stats stage (stage 5 in egs2/*/enh1/enh.sh).
normalize_variance: whether to normalize the signal variance before the model forward pass and revert it afterwards.
normalize_variance_per_ch: whether to normalize the signal variance for each channel instead of the whole signal. NOTE: normalize_variance and normalize_variance_per_ch cannot be True at the same time.
------------------------------------------------------------------
categories: list of all possible categories of minibatches (order matters!), e.g. ["1ch_8k_reverb", "1ch_8k_both"] for multi-condition training. NOTE: this is used to convert a category index to the corresponding name for logging in forward_loss. Different categories have different loss-name suffixes.
category_weights: list of weights, one per category, used to set the loss weights for batches of different categories.
------------------------------------------------------------------
always_forward_in_48k: whether to always upsample the input speech to 48 kHz for forward, and then downsample to the original sample rate for loss calculation. NOTE: this can be useful for training a model that handles various sampling rates while unifying bandwidth extension and speech enhancement.
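The normalize_variance and normalize_variance_per_ch options describe a scale-then-revert step around the model forward pass. A minimal NumPy sketch, assuming the scale factor is the mixture's standard deviation, computed over the whole signal or separately for each channel of a (Batch, samples, channels) input. The helpers normalize, denormalize, and forward_with_normalization are hypothetical, not ESPnet functions, and the real implementation may compute its statistics differently.

```python
import numpy as np


def normalize(mix: np.ndarray, per_channel: bool = False):
    """Scale a (batch, samples) or (batch, samples, channels) mixture to unit std.

    Hypothetical helper: returns the scaled mixture and the std needed to revert it.
    """
    if per_channel:
        # One std per (utterance, channel): reduce over the sample axis only.
        axes = (1,)
    else:
        # One std per utterance: reduce over samples (and channels, if present).
        axes = tuple(range(1, mix.ndim))
    std = mix.std(axis=axes, keepdims=True)
    std = np.maximum(std, 1e-8)  # guard against division by zero for silent input
    return mix / std, std


def denormalize(est: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Undo the scaling on a model output that has the same layout as the input."""
    return est * std


def forward_with_normalization(
    mix: np.ndarray,
    model,
    normalize_variance: bool = False,
    normalize_variance_per_ch: bool = False,
) -> np.ndarray:
    if normalize_variance and normalize_variance_per_ch:
        # Mirrors the documented constraint: the two options are mutually exclusive.
        raise ValueError(
            "normalize_variance and normalize_variance_per_ch cannot both be True"
        )
    if not (normalize_variance or normalize_variance_per_ch):
        return model(mix)
    scaled, std = normalize(mix, per_channel=normalize_variance_per_ch)
    # Run the (hypothetical) model on the scaled input, then restore the original scale.
    return denormalize(model(scaled), std)
```

For example, forward_with_normalization(mix, lambda x: x, normalize_variance=True) returns mix up to floating-point rounding, since scaling and reverting cancel for an identity model.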
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
enhance(feature_mix)
forward(speech_mix: Tensor, speech_mix_lengths: Tensor = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + loss calculation
- Parameters:
- speech_mix: (Batch, samples) or (Batch, samples, channels)
- speech_ref1: (Batch, samples) or (Batch, samples, channels)
- speech_ref2: (Batch, samples) or (Batch, samples, channels)
- ...
- speech_mix_lengths: (Batch,), default None for the chunk iterator, because the chunk iterator does not return speech_lengths. See espnet2/iterators/chunk_iter_factory.py.
- enroll_ref1: (Batch, samples_aux) enrollment (raw audio or embedding) for speaker 1
- enroll_ref2: (Batch, samples_aux) enrollment (raw audio or embedding) for speaker 2
- ...
- kwargs: "utt_id" is among the inputs.
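The forward parameters follow a numbered-keyword convention (speech_ref1, speech_ref2, ..., enroll_ref1, ...), and speech_mix_lengths may be None when batches come from the chunk iterator. A minimal sketch of that bookkeeping, using NumPy arrays in place of torch Tensors. The helpers collect_refs and default_lengths are hypothetical, and treating a None length as "every utterance spans the full padded length" is an assumption about how the missing lengths are filled in.

```python
from typing import Dict, List, Optional

import numpy as np


def collect_refs(prefix: str, num_spk: int, kwargs: Dict[str, np.ndarray]) -> List[np.ndarray]:
    """Gather prefix1..prefixN (e.g. speech_ref1, speech_ref2) from keyword arguments.

    Hypothetical helper; raises KeyError if an expected reference is missing.
    """
    return [kwargs[f"{prefix}{i + 1}"] for i in range(num_spk)]


def default_lengths(
    speech_mix: np.ndarray, speech_mix_lengths: Optional[np.ndarray]
) -> np.ndarray:
    """Return the given lengths, or assume full-length utterances when they are None."""
    if speech_mix_lengths is None:
        # Chunk-iterator case: no lengths are provided, so every item spans axis 1.
        return np.full(speech_mix.shape[0], speech_mix.shape[1], dtype=np.int64)
    return speech_mix_lengths
```

Under the same assumptions, enrollment inputs would be gathered with collect_refs("enroll_ref", num_spk, kwargs), while extra keys such as "utt_id" are simply ignored.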
forward_loss(speech_ref, speech_mix, speech_lengths) → Tuple[Tensor, Dict[str, Tensor], Tensor]
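The loss_wrappers, categories, and category_weights parameter descriptions indicate how forward_loss combines its terms: each criterion is evaluated in list order, scaled by its weight, and summed, and logged loss names receive a category-specific suffix. The following pure-Python sketch only illustrates that arithmetic. LossWrapper, mse, l1, and total_loss are hypothetical names, not ESPnet classes, and treating the category weight as a multiplier on the summed batch loss is an assumption this page does not confirm.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, Tuple


@dataclass
class LossWrapper:
    """Hypothetical stand-in: a named criterion plus its loss weight."""
    name: str
    criterion: Callable[[Sequence[float], Sequence[float]], float]
    weight: float = 1.0


def mse(ref: Sequence[float], est: Sequence[float]) -> float:
    return sum((r - e) ** 2 for r, e in zip(ref, est)) / len(ref)


def l1(ref: Sequence[float], est: Sequence[float]) -> float:
    return sum(abs(r - e) for r, e in zip(ref, est)) / len(ref)


def total_loss(
    ref: Sequence[float],
    est: Sequence[float],
    wrappers: List[LossWrapper],
    categories: Optional[List[str]] = None,
    category_weights: Optional[List[float]] = None,
    category_idx: Optional[int] = None,
) -> Tuple[float, Dict[str, float]]:
    # Suffix logged names with the category so different conditions are tracked separately.
    suffix = ""
    if categories is not None and category_idx is not None:
        suffix = "_" + categories[category_idx]
    loss, stats = 0.0, {}
    for w in wrappers:  # evaluated in list order, weighted, then summed
        value = w.criterion(ref, est)
        stats[w.name + suffix] = value
        loss += w.weight * value
    if category_weights is not None and category_idx is not None:
        # Assumption: the category weight rescales the whole batch loss.
        loss *= category_weights[category_idx]
    stats["loss" + suffix] = loss
    return loss, stats
```

With the example categories ["1ch_8k_reverb", "1ch_8k_both"] and a batch of index 1, the logged keys in this sketch become "mse_1ch_8k_both", "l1_1ch_8k_both", and "loss_1ch_8k_both".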
