espnet2.enh.diffusion_enh.ESPnetDiffusionModel
class espnet2.enh.diffusion_enh.ESPnetDiffusionModel(encoder: AbsEncoder, diffusion: AbsDiffusion, decoder: AbsDecoder, num_spk: int = 1, normalize: bool = False, **kwargs)
Bases: ESPnetEnhancementModel
Target Speaker Extraction Frontend model.
This model implements a frontend that enhances audio signals by recovering the target speech from a mixture of sounds. It combines an encoder, a diffusion process, and a decoder; currently only single-speaker enhancement (num_spk = 1) is supported.
Attributes:
encoder
The encoder module for processing audio input.
- Type: AbsEncoder
diffusion
The diffusion module for signal enhancement.
- Type: AbsDiffusion
decoder
The decoder module for reconstructing the output.
- Type: AbsDecoder
num_spk
The number of speakers to enhance (default is 1).
- Type: int
normalize
A flag indicating whether to normalize the input signals.
- Type: bool
Parameters:
- encoder (AbsEncoder) – The encoder instance to be used.
- diffusion (AbsDiffusion) – The diffusion instance for the enhancement.
- decoder (AbsDecoder) – The decoder instance to be used.
- num_spk (int) – Number of speakers (default is 1).
- normalize (bool) – Flag to indicate normalization of input (default is False).
- **kwargs – Additional keyword arguments for the parent class.
Raises: AssertionError – If num_spk is not equal to 1, as only enhancement models are currently supported.
Examples:
>>> model = ESPnetDiffusionModel(encoder, diffusion, decoder, num_spk=1)
>>> loss, stats, weight = model.forward(speech_mix, speech_mix_lengths,
... speech_ref1=speech_ref1)
NOTE:
- This model is currently limited to enhancement tasks with a single target speaker.
- The forward method requires at least one reference speech signal (speech_ref1).
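To make the data flow concrete, here is a framework-free sketch of how an encoder, a diffusion module, and a decoder might be composed, including the num_spk == 1 check described above. It assumes each component is a plain Python callable working on lists of floats; ToyDiffusionPipeline and the lambda components are hypothetical stand-ins, not ESPnet classes, and the real ESPnetDiffusionModel operates on PyTorch tensors.

```python
from typing import Callable, List

Signal = List[float]


class ToyDiffusionPipeline:
    """Hypothetical stand-in for ESPnetDiffusionModel (not the ESPnet API)."""

    def __init__(
        self,
        encoder: Callable[[Signal], Signal],
        diffusion: Callable[[Signal], Signal],
        decoder: Callable[[Signal], Signal],
        num_spk: int = 1,
    ) -> None:
        # Mirrors the documented restriction: only single-speaker enhancement.
        assert num_spk == 1, "only enhancement models are currently supported"
        self.encoder = encoder
        self.diffusion = diffusion
        self.decoder = decoder
        self.num_spk = num_spk

    def __call__(self, speech_mix: Signal) -> Signal:
        feature_mix = self.encoder(speech_mix)     # waveform -> features
        feature_pre = self.diffusion(feature_mix)  # enhance the features
        return self.decoder(feature_pre)           # features -> waveform


# Toy components: the encoder doubles, the "diffusion" halves, the decoder halves.
pipeline = ToyDiffusionPipeline(
    encoder=lambda x: [2.0 * v for v in x],
    diffusion=lambda f: [0.5 * v for v in f],
    decoder=lambda f: [v / 2.0 for v in f],
)
print(pipeline([1.0, -2.0, 4.0]))  # [0.5, -1.0, 2.0]
```

Keeping the three stages as separate, swappable objects is what lets the same wrapper be reused with different encoders or diffusion samplers.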
Main entry of speech enhancement/separation model training. The parameters below are those documented by the parent class ESPnetEnhancementModel.
- Parameters:
encoder – waveform encoder that converts waveforms to feature representations
separator – separator that enhances or separates the feature representations
decoder – waveform decoder that converts the feature back to waveforms
mask_module – mask module that converts the feature to masks. NOTE: Only used for compatibility with joint speaker diarization. See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
loss_wrappers – list of loss wrappers. Each loss wrapper contains a criterion for loss calculation and the corresponding loss weight. The losses will be calculated in the order of the list and summed up.
------------------------------------------------------------------
stft_consistency – (deprecated, kept for compatibility) whether to compute the TF-domain loss while enforcing STFT consistency. NOTE: STFT consistency is now always used for frequency-domain spectrum losses.
loss_type – (deprecated, kept for compatibility) loss type
mask_type – (deprecated, kept for compatibility) mask type in TF-domain model
------------------------------------------------------------------
flexible_numspk – whether to allow the model to predict a variable number of speakers in its output. NOTE: This should be used when training a speech separation model for an unknown number of speakers.
------------------------------------------------------------------
extract_feats_in_collect_stats – used in espnet2/tasks/abs_task.py for determining whether or not to skip model building in collect_stats stage (stage 5 in egs2/*/enh1/enh.sh).
normalize_variance – whether to normalize the signal variance before model forward, and revert it back after.
normalize_variance_per_ch – whether to normalize the signal variance for each channel instead of the whole signal. NOTE: normalize_variance and normalize_variance_per_ch cannot be True at the same time.
------------------------------------------------------------------
categories – list of all possible categories of minibatches (order matters!) (e.g. [“1ch_8k_reverb”, “1ch_8k_both”] for multi-condition training) NOTE: this will be used to convert category index to the corresponding name for logging in forward_loss. Different categories will have different loss name suffixes.
category_weights – list of weights for each category. Used to set loss weights for batches of different categories.
------------------------------------------------------------------
always_forward_in_48k – whether to always upsample the input speech to 48kHz for forward, and then downsample to the original sample rate for loss calculation. NOTE: this can be useful to train a model capable of handling various sampling rates while unifying bandwidth extension + speech enhancement.
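The normalize_variance and normalize_variance_per_ch options above can be illustrated with a small pure-Python sketch that scales a multi-channel signal to unit standard deviation (over the whole signal or per channel), runs a model, and reverts the scaling on the output. It also enforces that the two flags are not both True, and shows how a category index might be mapped to its name and loss weight. The helper names, the population standard deviation, and the eps guard are assumptions for illustration; this is not the ESPnet implementation, which works on PyTorch tensors.

```python
import statistics
from typing import Callable, List, Optional, Sequence, Tuple

Channels = List[List[float]]  # [channel][sample]


def forward_with_variance_norm(
    speech_mix: Channels,
    model: Callable[[Channels], Channels],
    normalize_variance: bool = False,
    normalize_variance_per_ch: bool = False,
    eps: float = 1e-8,
) -> Channels:
    """Hypothetical helper: normalize variance, run model, revert scaling."""
    assert not (normalize_variance and normalize_variance_per_ch), (
        "normalize_variance and normalize_variance_per_ch cannot both be True"
    )
    if normalize_variance_per_ch:
        # One scale per channel.
        scales = [statistics.pstdev(ch) + eps for ch in speech_mix]
    elif normalize_variance:
        # One scale shared by all channels.
        flat = [v for ch in speech_mix for v in ch]
        scales = [statistics.pstdev(flat) + eps] * len(speech_mix)
    else:
        scales = [1.0] * len(speech_mix)

    normed = [[v / s for v in ch] for ch, s in zip(speech_mix, scales)]
    out = model(normed)
    # Revert the normalization on the model output.
    return [[v * s for v in ch] for ch, s in zip(out, scales)]


def category_loss_weight(
    category_index: int,
    categories: Sequence[str],
    category_weights: Optional[Sequence[float]] = None,
) -> Tuple[str, float]:
    """Hypothetical helper: map a minibatch category index to (name, weight)."""
    name = categories[category_index]  # the order of `categories` matters
    weight = 1.0 if category_weights is None else category_weights[category_index]
    return name, weight


signal = [[1.0, 3.0, 5.0], [2.0, 2.0, 8.0]]
restored = forward_with_variance_norm(signal, lambda x: x, normalize_variance_per_ch=True)
print(category_loss_weight(1, ["1ch_8k_reverb", "1ch_8k_both"], [1.0, 0.5]))
# ('1ch_8k_both', 0.5)
```

With an identity model, the restored signal matches the input up to floating-point rounding, which is the point of reverting the scale after the forward pass.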
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features from the mixed speech input.
This method returns the mixed speech waveforms and their lengths as the features used in the collect_stats stage. Before returning, it trims the batch to the longest length in speech_mix_lengths, which matters in data-parallel training, where each device may receive padding beyond its own longest utterance.
- Parameters:
- speech_mix – A tensor of shape (Batch, samples) or (Batch, samples, channels) representing the mixed speech signals.
- speech_mix_lengths – A tensor of shape (Batch,) containing the lengths of each mixed speech signal.
- kwargs – Additional keyword arguments.
- Returns: A dictionary containing:
- “feats”: The processed speech mix tensor.
- “feats_lengths”: The lengths of the processed speech mix.
- Return type: Dict[str, Tensor]
Examples:
>>> speech_mix = torch.randn(4, 16000) # batch of 4 signals, 16000 samples each
>>> speech_mix_lengths = torch.tensor([16000, 16000, 16000, 16000])
>>> model = ESPnetDiffusionModel(...)
>>> features = model.collect_feats(speech_mix, speech_mix_lengths)
>>> print(features["feats"].shape) # Output: torch.Size([4, 16000])
>>> print(features["feats_lengths"]) # Output: tensor([16000, 16000, 16000, 16000])
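Assuming collect_feats behaves like the parent ESPnetEnhancementModel, which trims the zero-padded batch to the longest entry in speech_mix_lengths and returns the waveform itself as the feature, its logic reduces to the following pure-Python sketch. collect_feats_sketch is a hypothetical name, and nested lists stand in for tensors.

```python
from typing import Dict, List


def collect_feats_sketch(
    speech_mix: List[List[float]], speech_mix_lengths: List[int]
) -> Dict[str, object]:
    """Hypothetical sketch: trim padding, then return waveforms as features."""
    max_len = max(speech_mix_lengths)
    # Under data parallelism each replica may receive a slice of the batch that
    # is padded beyond its own longest utterance; trim to the local maximum.
    trimmed = [row[:max_len] for row in speech_mix]
    return {"feats": trimmed, "feats_lengths": list(speech_mix_lengths)}


batch = [[0.1, 0.2, 0.0, 0.0], [0.3, 0.4, 0.5, 0.0]]  # padded to 4 samples
feats = collect_feats_sketch(batch, [2, 3])
print(feats["feats"])  # [[0.1, 0.2, 0.0], [0.3, 0.4, 0.5]]
```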
enhance(feature_mix)
Enhance encoded mixture features with the diffusion model.
Given feature_mix, the encoder output for the mixed speech, this method runs the diffusion module's reverse (enhancement) process and returns the enhanced features. If normalize is True, the input features are normalized before the reverse process.
Returns: The enhanced features produced by the diffusion module.
Return type: torch.Tensor
Examples:
>>> # Create model components (MyEncoder etc. are placeholders)
>>> encoder = MyEncoder(...)
>>> diffusion = MyDiffusion(...)
>>> decoder = MyDecoder(...)
>>> # Instantiate the model
>>> model = ESPnetDiffusionModel(encoder, diffusion, decoder, num_spk=1)
>>> # Forward pass with mixed speech and reference signals
>>> loss, stats, weight = model.forward(speech_mix, speech_mix_lengths,
...     speech_ref1=speech_ref1)
>>> # Enhance features
>>> enhanced_features = model.enhance(feature_mix)
NOTE: The model currently supports only single-speaker enhancement tasks.
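The optional normalization in enhance can be sketched as: scale the input, run the reverse process, and undo the scale. This sketch assumes peak-amplitude normalization and assumes the output is rescaled to the input's level; both the statistic and the rescaling step may differ in ESPnet. enhance_sketch is a hypothetical name, and the diffusion reverse process is represented by a plain toy callable.

```python
from typing import Callable, List


def enhance_sketch(
    feature_mix: List[float],
    reverse_process: Callable[[List[float]], List[float]],
    normalize: bool = False,
) -> List[float]:
    """Hypothetical sketch of enhance(): optional scaling around the sampler."""
    scale = 1.0
    if normalize:
        # Assumption: peak-amplitude normalization; fall back to 1.0 for
        # all-zero (or empty) input to avoid dividing by zero.
        scale = max((abs(v) for v in feature_mix), default=0.0) or 1.0
    feature_pre = reverse_process([v / scale for v in feature_mix])
    # Assumption: map the enhanced features back to the input's scale.
    return [v * scale for v in feature_pre]


def clip(features: List[float]) -> List[float]:
    """Toy stand-in for the diffusion reverse process: clip to [-0.5, 0.5]."""
    return [min(0.5, max(-0.5, v)) for v in features]


print(enhance_sketch([4.0, -1.0, 2.0], clip, normalize=True))  # [2.0, -1.0, 2.0]
print(enhance_sketch([4.0, -1.0, 2.0], clip))  # [0.5, -0.5, 0.5]
```

The two calls show why normalization matters for a sampler tuned to a fixed amplitude range: without it, the same reverse process sees very different input levels.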
forward(speech_mix: Tensor, speech_mix_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calculate loss.
This method encodes the mixed and reference speech and computes the diffusion-based training loss and related statistics. It also handles normalization of the input signals if specified.
- Parameters:
- speech_mix – A tensor of shape (Batch, samples) or (Batch, samples, channels) representing the mixed speech.
- speech_mix_lengths – A tensor of shape (Batch,) indicating the lengths of the mixed speech signals. Default is None, which is suitable for chunk iterators that do not return the speech lengths. Refer to espnet2/iterators/chunk_iter_factory.py for details.
- kwargs – Additional keyword arguments. It must include at least “speech_ref1” for the first reference signal. Other reference signals can be provided as “speech_ref2”, etc. Enrollment references can also be passed as “enroll_ref1”, “enroll_ref2”, etc.
- Raises: AssertionError – If “speech_ref1” is not provided in kwargs, or if the dimensions of the input tensors do not match.
- Returns: A tuple containing:
- loss: A tensor representing the computed loss.
- stats: A dictionary containing various statistics related to the forward pass.
- weight: A tensor representing the weight for the computed loss.
- Return type: Tuple[Tensor, Dict[str, Tensor], Tensor]
Examples:
>>> model = ESPnetDiffusionModel(...)
>>> speech_mix = torch.randn(2, 16000) # Batch of 2, 1 second audio
>>> speech_ref1 = torch.randn(2, 16000) # Reference for speaker 1
>>> loss, stats, weight = model.forward(
... speech_mix,
... speech_ref1=speech_ref1,
... speech_mix_lengths=torch.tensor([16000, 16000])
... )
NOTE: The method assumes that speech_mix and the speech_ref tensors have the same batch size.
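The keyword-argument handling described for forward can be sketched as follows: collect speech_ref1, speech_ref2, ... for num_spk speakers, require speech_ref1, fill in missing lengths, and check that batch sizes agree. Defaulting speech_mix_lengths to the full padded length when it is None is an assumption here, chosen to match the chunk-iterator case mentioned above; prepare_forward_inputs is a hypothetical helper, and nested lists stand in for tensors.

```python
from typing import List, Optional, Tuple

Batch = List[List[float]]  # [batch][sample]


def prepare_forward_inputs(
    speech_mix: Batch,
    speech_mix_lengths: Optional[List[int]] = None,
    num_spk: int = 1,
    **kwargs: Batch,
) -> Tuple[List[Batch], List[int]]:
    """Hypothetical sketch of forward()'s input checks (not the ESPnet code)."""
    assert "speech_ref1" in kwargs, "at least one reference signal is required"
    speech_ref = [kwargs[f"speech_ref{spk + 1}"] for spk in range(num_spk)]
    batch_size = len(speech_mix)
    # Assumption: without lengths (e.g. chunk iterators), treat every utterance
    # as spanning the full padded length.
    if speech_mix_lengths is None:
        speech_mix_lengths = [len(speech_mix[0])] * batch_size
    for ref in speech_ref:
        assert len(ref) == batch_size, "batch size mismatch"
    assert len(speech_mix_lengths) == batch_size, "batch size mismatch"
    return speech_ref, speech_mix_lengths


mix = [[0.0] * 4, [0.0] * 4]
refs, lengths = prepare_forward_inputs(mix, speech_ref1=[[0.1] * 4, [0.2] * 4])
print(lengths)  # [4, 4]
```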
forward_loss(speech_ref, speech_mix, speech_lengths) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Compute the forward loss for the speech enhancement model.
This method processes the mixed speech and reference signals through the encoder, computes the diffusion-based loss, and returns the loss along with relevant statistics and weight for the batch.
- Parameters:
- speech_ref (List[torch.Tensor]) – A list of reference speech signals, one per speaker. Each tensor should be of shape (Batch, samples).
- speech_mix (torch.Tensor) – The mixed speech signal of shape (Batch, samples) or (Batch, samples, channels).
- speech_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of the mixed speech signals.
- Returns: A tuple containing:
- loss (torch.Tensor): The computed loss for the batch.
- stats (Dict[str, torch.Tensor]): A dictionary containing statistics related to the loss, such as ‘loss’.
- weight (torch.Tensor): The weight tensor for the batch.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Examples:
>>> model = ESPnetDiffusionModel(...)
>>> speech_mix = torch.randn(8, 16000) # Example mixed speech
>>> speech_ref = [torch.randn(8, 16000)] # Example reference
>>> speech_lengths = torch.tensor([16000] * 8) # Lengths
>>> loss, stats, weight = model.forward_loss(speech_ref, speech_mix, speech_lengths)
NOTE: Ensure that the input reference signals are correctly shaped and that at least one reference signal is provided.
- Raises: AssertionError – If the dimensions of the input tensors do not match, or if no reference signal is provided.
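The steps of forward_loss can be summarized as: encode the reference and the mixture, let the diffusion module compute a training loss, record it in stats, and return a weight for the batch. The sketch below uses mean squared error purely as a stand-in for whatever objective the diffusion module implements, and assumes the weight is the batch size; forward_loss_sketch, encode, and diffusion_loss are hypothetical names, and plain floats replace tensors.

```python
from typing import Callable, Dict, List, Tuple

Batch = List[List[float]]  # [batch][sample]


def forward_loss_sketch(
    speech_ref: List[Batch],
    speech_mix: Batch,
    encode: Callable[[Batch], Batch],
    diffusion_loss: Callable[[Batch, Batch], float],
) -> Tuple[float, Dict[str, float], int]:
    """Hypothetical sketch of forward_loss (plain floats instead of tensors)."""
    assert len(speech_ref) >= 1, "no reference signal provided"
    assert len(speech_ref[0]) == len(speech_mix), "batch size mismatch"
    feature_ref = encode(speech_ref[0])  # single speaker: first reference only
    feature_mix = encode(speech_mix)
    loss = diffusion_loss(feature_ref, feature_mix)
    stats = {"loss": loss}
    weight = len(speech_mix)  # assumption: weight = batch size
    return loss, stats, weight


def mse(a: Batch, b: Batch) -> float:
    """Stand-in objective: mean squared error over all elements."""
    diffs = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(diffs) / len(diffs)


ref = [[1.0, 1.0], [0.0, 0.0]]
mix = [[1.0, 3.0], [0.0, 2.0]]
loss, stats, weight = forward_loss_sketch([ref], mix, lambda x: x, mse)
print(loss, weight)  # 2.0 2
```

Returning the batch size as the weight lets a trainer average losses correctly across batches (or devices) of different sizes.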