espnet2.enh.espnet_model_tse.ESPnetExtractionModel
class espnet2.enh.espnet_model_tse.ESPnetExtractionModel(encoder: AbsEncoder, extractor: AbsExtractor, decoder: AbsDecoder, loss_wrappers: List[AbsLossWrapper], num_spk: int = 1, flexible_numspk: bool = False, share_encoder: bool = True, extract_feats_in_collect_stats: bool = False)
Bases: AbsESPnetModel
ESPnetExtractionModel is a target speaker extraction frontend model.
This model integrates an encoder, extractor, and decoder to perform target speaker extraction from mixed audio signals. It supports multiple loss wrappers and can handle a flexible number of speakers.
encoder
The encoder used for processing input audio.
- Type: AbsEncoder
extractor
The extractor used to separate the target speakers from the mixture.
- Type: AbsExtractor
decoder
The decoder that reconstructs the separated audio.
- Type: AbsDecoder
loss_wrappers
A list of loss wrappers for loss computation.
- Type: List[AbsLossWrapper]
num_spk
The number of target speakers to extract (default: 1).
- Type: int
flexible_numspk
If True, num_spk is regarded as the maximum possible number of speakers (default: False).
- Type: bool
share_encoder
Whether to share the encoder for both mixture and enrollment (default: True).
- Type: bool
extract_feats_in_collect_stats
If True, features are extracted during statistics collection (default: False).
- Type: bool
ref_channel
The reference channel for multi-channel signals.
- Type: int
Parameters:
- encoder (AbsEncoder) – The encoder to use for the model.
- extractor (AbsExtractor) – The extractor to use for the model.
- decoder (AbsDecoder) – The decoder to use for the model.
- loss_wrappers (List[AbsLossWrapper]) – List of loss wrappers for training.
- num_spk (int , optional) – Number of target speakers (default: 1).
- flexible_numspk (bool , optional) – Allow flexible number of speakers (default: False).
- share_encoder (bool , optional) – Share encoder between mixture and enrollment (default: True).
- extract_feats_in_collect_stats (bool , optional) – Extract features during collect stats (default: False).
Raises: ValueError – If there are duplicated loss names or unsupported loss types.
############# Examples
Example usage:
>>> encoder = SomeEncoder()
>>> extractor = SomeExtractor()
>>> decoder = SomeDecoder()
>>> loss_wrapper = SomeLossWrapper()
>>> model = ESPnetExtractionModel(
...     encoder=encoder,
...     extractor=extractor,
...     decoder=decoder,
...     loss_wrappers=[loss_wrapper],
...     num_spk=2,
...     flexible_numspk=True,
... )

Forward pass:

>>> speech_mix = torch.randn(4, 16000)  # (Batch, Samples)
>>> speech_lengths = torch.tensor([16000] * 4)  # Lengths for each batch
>>> speech_ref1 = torch.randn(4, 16000)  # Reference for speaker 1
>>> speech_ref2 = torch.randn(4, 16000)  # Reference for speaker 2
>>> loss, stats, weight = model.forward(
...     speech_mix,
...     speech_mix_lengths=speech_lengths,
...     speech_ref1=speech_ref1,
...     speech_ref2=speech_ref2,
...     enroll_ref1=speech_ref1,
...     enroll_ref2=speech_ref2,
... )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features from the input speech mixture for data-parallel processing.
This method extracts features from the input speech mixture tensor and its lengths. It prepares the data for further processing, ensuring that the speech mixture is correctly sized for data-parallel operations.
- Parameters:
- speech_mix – A tensor of shape (Batch, samples) or (Batch, samples, channels) representing the input speech mixture.
- speech_mix_lengths – A tensor of shape (Batch,) representing the lengths of the input speech mixture for each batch item.
- **kwargs – Additional keyword arguments.
- Returns: A dictionary containing:
  - "feats": A tensor containing the extracted features.
  - "feats_lengths": A tensor containing the lengths of the features.
- Return type: Dict[str, Tensor]
############# Examples
>>> speech_mix = torch.randn(4, 16000) # Example with 4 batches of 1s audio
>>> speech_mix_lengths = torch.tensor([16000, 16000, 16000, 16000])
>>> model = ESPnetExtractionModel(...)
>>> features = model.collect_feats(speech_mix, speech_mix_lengths)
>>> print(features["feats"].shape) # Should be (4, 16000)
>>> print(features["feats_lengths"]) # Should be tensor of lengths
######## NOTE This method is particularly useful in scenarios where data-parallel processing is required, ensuring that the features are gathered and aligned correctly.
forward(speech_mix: Tensor, speech_mix_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calculate loss.
This method processes the input speech mixture through the frontend, encoder, and decoder, and computes the loss based on the reference speech signals and enrollment references provided in kwargs.
- Parameters:
- speech_mix – A tensor of shape (Batch, samples) or (Batch, samples, channels) representing the mixed speech signals.
- speech_mix_lengths – A tensor of shape (Batch,) representing the lengths of the mixed speech signals. Defaults to None, which is used for chunk iterators that do not return speech lengths (see espnet2/iterators/chunk_iter_factory.py).
- kwargs – Additional keyword arguments. It must include:
  - "speech_ref1": (Batch, samples) or (Batch, samples, channels) for the reference signal of speaker 1.
  - "enroll_ref1": (Batch, samples_aux) for the enrollment (raw audio or embedding) of speaker 1.
  - Additional references can be included as "speech_ref2", "enroll_ref2", etc. for further speakers.
- Returns:
  - loss: A tensor representing the computed loss.
  - stats: A dictionary containing various statistics.
  - weight: A tensor representing the weight for the loss.
- Return type: Tuple[Tensor, Dict[str, Tensor], Tensor]
- Raises: AssertionError – If the required reference signals are not provided in kwargs or if their shapes are inconsistent.
############# Examples
>>> model = ESPnetExtractionModel(...)
>>> speech_mix = torch.randn(4, 16000) # Example mixed speech
>>> speech_ref1 = torch.randn(4, 16000) # Example reference for speaker 1
>>> enroll_ref1 = torch.randn(4, 8000) # Example enrollment for speaker 1
>>> loss, stats, weight = model.forward(
... speech_mix,
... speech_ref1=speech_ref1,
... enroll_ref1=enroll_ref1
... )
######## NOTE The method expects at least one reference signal and one enrollment signal to be provided in kwargs. The number of reference signals must match the number of speakers defined in the model.
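The per-speaker references and enrollments follow the "speech_ref{n}" / "enroll_ref{n}" keyword naming scheme described above, so they can also be assembled programmatically. The helper below is a hypothetical convenience (not part of ESPnet) shown only to illustrate that naming convention, reusing the model and speech_mix from the class-level example:

>>> import torch
>>> def pack_tse_kwargs(speech_refs, enroll_refs):
...     # Hypothetical helper: map per-speaker tensors onto the keyword
...     # names ("speech_ref1", "enroll_ref1", ...) that forward() expects.
...     kwargs = {}
...     for n, (ref, enroll) in enumerate(zip(speech_refs, enroll_refs), start=1):
...         kwargs[f"speech_ref{n}"] = ref
...         kwargs[f"enroll_ref{n}"] = enroll
...     return kwargs
>>> refs = [torch.randn(4, 16000), torch.randn(4, 16000)]   # clean targets, 2 speakers
>>> enrolls = [torch.randn(4, 8000), torch.randn(4, 8000)]  # enrollment audio, 2 speakers
>>> loss, stats, weight = model.forward(speech_mix, **pack_tse_kwargs(refs, enrolls))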
forward_enhance(speech_mix: Tensor, speech_lengths: Tensor, enroll_ref: Tensor, enroll_ref_lengths: Tensor, additional: Dict | None = None) → Tuple[Tensor, Tensor, Tensor]
Enhances the input mixed speech signal using the encoder and extractor.
This method processes the mixed speech input and reference signals for enrollment to produce enhanced speech outputs. It uses the encoder to extract features from the mixed speech and reference signals, which are then processed by the extractor to generate the enhanced signals.
- Parameters:
- speech_mix – Tensor of shape (Batch, samples) or (Batch, samples, channels) representing the mixed speech.
- speech_lengths – Tensor of shape (Batch,) indicating the lengths of the mixed speech signals.
- enroll_ref – Tensor of shape (Batch, samples_aux) or (Batch, samples_aux, channels) representing the enrollment reference signals for each speaker.
- enroll_ref_lengths – Tensor of shape (Batch,) indicating the lengths of the enrollment reference signals.
- additional – Optional dictionary containing additional parameters for enhancement. Default is None.
- Returns:
  - speech_pre: Enhanced speech tensor of shape (Batch, samples) or (Batch, samples, channels).
  - feature_mix: Features extracted from the mixed speech.
  - feature_pre: Features extracted from the enhanced speech.
- Return type: Tuple[Tensor, Tensor, Tensor]
############# Examples
>>> model = ESPnetExtractionModel(...)
>>> enhanced_speech, features_mix, features_pre = model.forward_enhance(
... speech_mix, speech_lengths, enroll_ref, enroll_ref_lengths
... )
######## NOTE This method is designed to work with both single and multi-channel signals. The extraction of features and enhancement is based on the provided enrollment references for the target speakers.
- Raises: ValueError – If the input dimensions do not match the expected shapes or if the reference signals are not provided as required.
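A minimal inference-style sketch for a single target speaker, following the parameter shapes documented above (random tensors stand in for real 16 kHz audio). Note two assumptions: depending on the installed version, the enrollment may need to be wrapped in a per-speaker list (mirroring how forward() unpacks its "enroll_ref{n}" keyword arguments), and speech_pre may come back as a single tensor or a per-speaker list depending on the configured extractor:

>>> import torch
>>> batch, samples, samples_aux = 2, 16000, 8000
>>> speech_mix = torch.randn(batch, samples)           # mixture to be processed
>>> speech_lengths = torch.tensor([samples] * batch)   # per-utterance sample counts
>>> enroll_ref = torch.randn(batch, samples_aux)       # enrollment audio for the target speaker
>>> enroll_lengths = torch.tensor([samples_aux] * batch)
>>> model = model.eval()
>>> with torch.no_grad():
...     speech_pre, feature_mix, feature_pre = model.forward_enhance(
...         speech_mix, speech_lengths, enroll_ref, enroll_lengths
...     )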
forward_loss(speech_pre: Tensor, speech_lengths: Tensor, feature_mix: Tensor, feature_pre: Tensor, others: OrderedDict, speech_ref: Tensor) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Calculates the forward loss for the target speaker extraction model.
This method computes the loss between the predicted speech signals and the reference speech signals using specified loss criteria. It supports both time-domain and frequency-domain losses. The method aggregates loss values from multiple loss wrappers and returns the overall loss, along with additional statistics.
Parameters:
- speech_pre – (Batch, samples) or (Batch, samples, channels) - The predicted speech signals from the model.
- speech_lengths – (Batch,) - A tensor indicating the lengths of the predicted speech signals.
- feature_mix – (Batch, feature_dim, samples) - The mixed speech features extracted from the input mixed signals.
- feature_pre – (Batch, num_speakers, feature_dim, samples) - The features of the predicted speech signals.
- others – OrderedDict - Additional data required for loss computation, such as masks or other auxiliary information.
- speech_ref – (Batch, num_speakers, samples) - The reference speech signals for each target speaker.
Returns:
- loss: The computed loss value.
- stats: A dictionary containing additional statistics from the loss computation.
- weight: The weight used for the loss computation.
Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Raises:
- NotImplementedError – If an unsupported loss type is encountered in the loss wrappers.
- AttributeError – If all criteria have only_for_test=True during training.
############# Examples
>>> loss, stats, weight = model.forward_loss(
...     speech_pre, speech_lengths, feature_mix, feature_pre, others, speech_ref
... )
######## NOTE This method is designed to work with multiple loss wrappers that are defined during model initialization. Ensure that the criteria used are compatible with the input tensors provided.
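In practice, forward() essentially chains forward_enhance() and forward_loss(). The sketch below illustrates that two-stage flow for a single target speaker, continuing from the forward_enhance() example above; it assumes that speech_ref is stacked to the documented (Batch, num_speakers, samples) shape and that the configured criteria need no auxiliary extractor outputs, so an empty OrderedDict is passed for others:

>>> from collections import OrderedDict
>>> import torch
>>> speech_ref1 = torch.randn(2, 16000)    # clean reference for the target speaker
>>> speech_ref = speech_ref1.unsqueeze(1)  # (Batch, num_speakers=1, samples)
>>> speech_pre, feature_mix, feature_pre = model.forward_enhance(
...     speech_mix, speech_lengths, enroll_ref, enroll_lengths
... )
>>> loss, stats, weight = model.forward_loss(
...     speech_pre, speech_lengths, feature_mix, feature_pre,
...     OrderedDict(), speech_ref,
... )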