espnet2.enh.espnet_model.ESPnetEnhancementModel
class espnet2.enh.espnet_model.ESPnetEnhancementModel(encoder: AbsEncoder, separator: AbsSeparator | None, decoder: AbsDecoder, mask_module: AbsMask | None, loss_wrappers: List[AbsLossWrapper] | None, stft_consistency: bool = False, loss_type: str = 'mask_mse', mask_type: str | None = None, flexible_numspk: bool = False, extract_feats_in_collect_stats: bool = False, normalize_variance: bool = False, normalize_variance_per_ch: bool = False, categories: list = [], category_weights: list = [], always_forward_in_48k: bool = False)
Bases: AbsESPnetModel
Speech enhancement or separation frontend model.
This class implements a speech enhancement or separation model built from an encoder, a separator, a decoder, and loss wrappers. It is designed for training models that enhance or separate audio signals under various conditions.
encoder
The waveform encoder to convert waveforms to feature representations.
- Type: AbsEncoder
separator
The separator that enhances or separates the feature representations.
- Type: Optional[AbsSeparator]
decoder
The waveform decoder to convert features back to waveforms.
- Type: AbsDecoder
mask_module
The mask module for converting features to masks, used for compatibility with joint speaker diarization.
- Type: Optional[AbsMask]
loss_wrappers
A list of loss wrappers that contain criteria for loss calculation and corresponding weights.
- Type: Optional[List[AbsLossWrapper]]
flexible_numspk
If True, allows the model to predict a variable number of speakers in its output.
- Type: bool
Parameters:
- encoder – waveform encoder that converts waveforms to feature representations.
- separator – separator that enhances or separates the feature representations.
- decoder – waveform decoder that converts the features back to waveforms.
- mask_module – mask module that converts the features to masks. NOTE: Only used for compatibility with joint speaker diarization. See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
- loss_wrappers – list of loss wrappers, each containing a criterion for loss calculation and the corresponding loss weight. The losses are calculated in the order of the list and summed up.
- stft_consistency – (deprecated, kept for compatibility) whether to compute the TF-domain loss while enforcing STFT consistency. NOTE: STFT consistency is now always used for frequency-domain spectrum losses.
- loss_type – (deprecated, kept for compatibility) loss type.
- mask_type – (deprecated, kept for compatibility) mask type in TF-domain model.
- flexible_numspk – whether to allow the model to predict a variable number of speakers in its output. NOTE: This should be used when training a speech separation model for an unknown number of speakers.
- extract_feats_in_collect_stats – used in espnet2/tasks/abs_task.py to determine whether to skip model building in the collect_stats stage (stage 5 in egs2/*/enh1/enh.sh).
- normalize_variance – whether to normalize the signal variance before the model forward pass, and revert it afterwards.
- normalize_variance_per_ch – whether to normalize the signal variance for each channel instead of the whole signal. NOTE: normalize_variance and normalize_variance_per_ch cannot both be True at the same time.
- categories – list of all possible categories of minibatches (order matters!), e.g. ["1ch_8k_reverb", "1ch_8k_both"] for multi-condition training. NOTE: this is used to convert a category index to the corresponding name for logging in forward_loss; different categories get different loss name suffixes.
- category_weights – list of weights for each category, used to set loss weights for batches of different categories.
- always_forward_in_48k – whether to always upsample the input speech to 48 kHz for the forward pass, and then downsample to the original sample rate for loss calculation. NOTE: this can be useful for training a model capable of handling various sampling rates while unifying bandwidth extension and speech enhancement.
############### Examples
>>> model = ESPnetEnhancementModel(
...     encoder, separator, decoder, mask_module=None, loss_wrappers=loss_wrappers
... )
>>> loss, stats, weight = model.forward(
...     speech_mix, speech_mix_lengths, speech_ref1=speech_ref1
... )
######## NOTE Ensure that the loss_wrappers and other components are correctly instantiated to avoid runtime errors during training and inference.
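For reference, a fuller construction from stock ESPnet2 components might look like the sketch below. The chosen components (STFTEncoder, STFTDecoder, RNNSeparator, SISNRLoss, PITSolver) and their constructor arguments are assumptions and should be checked against your ESPnet version.
>>> from espnet2.enh.encoder.stft_encoder import STFTEncoder
>>> from espnet2.enh.decoder.stft_decoder import STFTDecoder
>>> from espnet2.enh.separator.rnn_separator import RNNSeparator
>>> from espnet2.enh.loss.criterions.time_domain import SISNRLoss
>>> from espnet2.enh.loss.wrappers.pit_solver import PITSolver
>>> encoder = STFTEncoder(n_fft=512, hop_length=128)
>>> decoder = STFTDecoder(n_fft=512, hop_length=128)
>>> separator = RNNSeparator(input_dim=encoder.output_dim, num_spk=2)
>>> # Permutation-invariant SI-SNR criterion wrapped with loss weight 1.0
>>> loss_wrappers = [PITSolver(criterion=SISNRLoss(), weight=1.0)]
>>> model = ESPnetEnhancementModel(
...     encoder,
...     separator,
...     decoder,
...     mask_module=None,  # only needed for joint speaker diarization
...     loss_wrappers=loss_wrappers,
... )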
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features from the mixed speech input.
This method extracts features from the input speech mixture and its corresponding lengths, preparing them for further processing in the model. It is typically used for gathering features when needed during data-parallel training.
- Parameters:
- speech_mix – A tensor representing the mixed speech input of shape (Batch, samples) or (Batch, samples, channels).
- speech_mix_lengths – A tensor of shape (Batch,) representing the lengths of the mixed speech input.
- Returns: A dictionary containing:
- feats: The collected features, same as speech_mix.
- feats_lengths: The lengths of the collected features, same as speech_mix_lengths.
- Return type: Dict[str, Tensor]
############### Examples
>>> model = ESPnetEnhancementModel(...)
>>> mixed_speech = torch.randn(10, 16000) # batch of 10, 16000 samples each
>>> lengths = torch.tensor([16000] * 10) # all lengths are 16000
>>> features = model.collect_feats(mixed_speech, lengths)
>>> print(features['feats'].shape)
torch.Size([10, 16000])
>>> print(features['feats_lengths'])
tensor([16000, 16000, 16000, 16000, 16000, 16000, 16000, 16000, 16000, 16000])
forward(speech_mix: Tensor, speech_mix_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calc loss.
This method processes the mixed speech signal through the encoder, separator, and decoder, and computes the loss based on the predicted output and reference signals.
- Parameters:
- speech_mix – Tensor of shape (Batch, samples) or (Batch, samples, channels) representing the mixed speech input.
- speech_mix_lengths – Tensor of shape (Batch,). Defaults to None when using the chunk iterator, which does not return speech lengths; see espnet2/iterators/chunk_iter_factory.py.
- kwargs – Additional keyword arguments; “utt_id” is among the inputs. They must include the reference speech signals as speech_ref1, speech_ref2, etc.
- Returns: A tuple containing:
- loss: Computed loss value as a tensor.
- stats: A dictionary of statistics from the loss computation.
- weight: A tensor representing the weight for the loss.
- Return type: Tuple[Tensor, Dict[str, Tensor], Tensor]
- Raises: AssertionError – If no reference speech signals are provided in kwargs.
############### Examples
>>> model = ESPnetEnhancementModel(...)
>>> speech_mix = torch.randn(2, 16000) # Example mixed speech
>>> speech_mix_lengths = torch.tensor([16000, 16000]) # Lengths
>>> speech_ref1 = torch.randn(2, 16000) # Reference signal for speaker 1
>>> loss, stats, weight = model.forward(
...     speech_mix,
...     speech_mix_lengths,
...     speech_ref1=speech_ref1
... )
######## NOTE Ensure that the input speech mix and reference signals are properly shaped and that the number of reference signals matches the expected number of speakers in the model.
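A complete training step for a two-speaker model could then look like the following sketch (the optimizer setup and tensor shapes are illustrative assumptions, not part of this class):
>>> import torch
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed setup
>>> speech_mix = torch.randn(2, 16000)
>>> speech_mix_lengths = torch.tensor([16000, 16000])
>>> speech_ref1 = torch.randn(2, 16000)  # reference for speaker 1
>>> speech_ref2 = torch.randn(2, 16000)  # reference for speaker 2
>>> loss, stats, weight = model.forward(
...     speech_mix,
...     speech_mix_lengths,
...     speech_ref1=speech_ref1,
...     speech_ref2=speech_ref2,
... )
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()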
forward_enhance(speech_mix: Tensor, speech_lengths: Tensor, additional: Dict | None = None, fs: int | None = None) → Tuple[Tensor, Tensor, Tensor]
Enhances the input speech mixture by applying the encoder, separator, and decoder, returning the enhanced speech along with the features.
- Parameters:
- speech_mix – A tensor of shape (Batch, samples) or (Batch, samples, channels) representing the input speech mixture.
- speech_lengths – A tensor of shape (Batch,) representing the lengths of the input speech mixtures.
- additional – An optional dictionary containing additional information required for processing, such as the number of speakers.
- fs – An optional integer representing the sampling frequency of the input signal.
- Returns: A tuple containing:
- speech_pre: A tensor representing the enhanced speech output.
- feature_mix: A tensor representing the features of the input speech mixture.
- feature_pre: A tensor representing the features of the enhanced speech.
- Return type: Tuple[Tensor, Tensor, Tensor]
############### Examples
>>> model = ESPnetEnhancementModel(...)
>>> speech_mix = torch.randn(8, 16000) # Example batch of 8
>>> speech_lengths = torch.tensor([16000] * 8) # All lengths are 16000
>>> enhanced_speech, feature_mix, feature_pre = model.forward_enhance(
... speech_mix, speech_lengths
... )
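When the training data mixes sampling rates, the fs argument can be forwarded so that rate-aware components (e.g. when always_forward_in_48k is enabled) can resample accordingly. A sketch assuming an 8 kHz batch; whether fs is actually used depends on the configured components:
>>> speech_mix_8k = torch.randn(8, 8000)  # 1 second at 8 kHz
>>> speech_lengths_8k = torch.tensor([8000] * 8)
>>> enhanced, feature_mix, feature_pre = model.forward_enhance(
...     speech_mix_8k, speech_lengths_8k, fs=8000
... )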
forward_loss(speech_pre: Tensor, speech_lengths: Tensor, feature_mix: Tensor, feature_pre: List[Tensor], others: OrderedDict, speech_ref: List[Tensor], noise_ref: List[Tensor] | None = None, dereverb_speech_ref: List[Tensor] | None = None, category: Tensor | None = None, num_spk: int | None = None, fs: int | None = None) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Calculates the loss for the enhanced speech outputs.
This method computes the loss between the predicted enhanced speech signals and the reference signals using various criteria defined in the loss wrappers. It supports calculating losses for noise and dereverberated signals if provided.
- Parameters:
- speech_pre – The predicted enhanced speech signals of shape (Batch, num_speaker, samples) or (Batch, samples).
- speech_lengths – A tensor containing the lengths of the input speech signals, shape (Batch,).
- feature_mix – The feature representation of the mixed speech signals.
- feature_pre – A list containing the features of the predicted enhanced signals.
- others – An ordered dictionary containing additional information required for loss computation.
- speech_ref – A list of reference speech signals for each speaker, shape (Batch, num_speaker, samples).
- noise_ref – An optional list of reference noise signals, shape (Batch, num_noise_type, samples).
- dereverb_speech_ref – An optional list of dereverberated reference speech signals, shape (Batch, num_speaker, samples).
- category – An optional tensor indicating the category of the input batch.
- num_spk – An optional integer specifying the number of speakers (if not using the model’s internal count).
- fs – An optional integer specifying the sampling frequency.
- Returns: A tuple containing:
- loss: A tensor representing the total computed loss.
- stats: A dictionary containing statistics for the computed loss.
- weight: A tensor representing the weight of the loss.
- Return type: Tuple[Tensor, Dict[str, Tensor], Tensor]
- Raises:
- ValueError – If there are missing reference signals for loss computation or if the noise or dereverberated references are required but not provided.
- AttributeError – If the loss tensor does not require gradients during training.
############### Examples
>>> loss, stats, weight = model.forward_loss(
... speech_pre, speech_lengths, feature_mix, feature_pre,
... others, speech_ref, noise_ref, dereverb_speech_ref,
... category, num_spk, fs
... )
######## NOTE Ensure that the provided references match the expected input dimensions for loss computation. The method also handles categorization and weights for the losses based on the input category.
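Conceptually, the total loss is the weighted sum over loss_wrappers, with each wrapper applying its criterion and contributing its statistics. A simplified sketch of that accumulation (not the actual implementation, which additionally handles reference selection, permutation, and per-category weighting):
>>> loss, stats = 0.0, {}
>>> for wrapper in model.loss_wrappers:
...     l, s, others = wrapper(speech_ref, speech_pre, others)
...     loss = loss + l * wrapper.weight  # each wrapper carries its own weight
...     stats.update(s)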
static sort_by_perm(nn_output, perm)
Sort the input list of tensors by the specified permutation.
This method takes a list of tensors (typically outputs from a neural network) and reorders them according to a provided permutation. This is useful in scenarios where the output tensors need to be matched to a specific order after processing, such as in speech enhancement or separation tasks.
- Parameters:
- nn_output – List[torch.Tensor(Batch, …)], length of nn_output must equal num_spk. Each tensor in the list corresponds to the output for a specific speaker.
- perm – (Batch, num_spk) or List[torch.Tensor(num_spk)] This specifies the permutation indices to reorder the outputs.
- Returns: nn_output_new – the reordered list of tensors according to the specified permutation.
- Return type: List[torch.Tensor(Batch, …)]
############### Examples
>>> import torch
>>> output = [torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 8]])]
>>> perm = torch.tensor([[1, 0], [0, 1]]) # swap speakers
>>> sorted_output = ESPnetEnhancementModel.sort_by_perm(output, perm)
>>> for tensor in sorted_output:
... print(tensor)
tensor([[5, 6],
[3, 4]])
tensor([[1, 2],
[7, 8]])
######## NOTE If nn_output contains only one tensor, it is returned as is, without any sorting.
- Raises: AssertionError – If the sizes of nn_output and perm are not compatible.
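The reordering is essentially a batched gather along the speaker dimension. A minimal self-contained sketch that reproduces the behavior above (not the library implementation):
>>> import torch
>>> def sort_by_perm_sketch(nn_output, perm):
...     stacked = torch.stack(nn_output, dim=1)  # (Batch, num_spk, ...)
...     if not isinstance(perm, torch.Tensor):
...         perm = torch.stack(list(perm), dim=0)  # (Batch, num_spk)
...     while perm.dim() < stacked.dim():  # broadcast over trailing dims
...         perm = perm.unsqueeze(-1)
...     perm = perm.expand_as(stacked)
...     return list(torch.gather(stacked, 1, perm).unbind(dim=1))
>>> sorted_again = sort_by_perm_sketch(output, perm)  # same result as above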