espnet2.tasks.enh_tse.TargetSpeakerExtractionTask
class espnet2.tasks.enh_tse.TargetSpeakerExtractionTask
Bases: AbsTask
TargetSpeakerExtractionTask defines the task for target speaker extraction in a multi-speaker environment. It extends AbsTask and provides methods for adding task-specific arguments, building the collate and preprocess functions, and constructing the model.
num_optimizers
Number of optimizers to use. Default is 1.
- Type: int
class_choices_list
List of class choices for encoder, extractor, decoder, and preprocessor.
- Type: List[ClassChoices]
trainer
Trainer class; replace it to modify the train() or eval() procedures.
- Type: Trainer
Summary of the class methods that take parameters or return callables:
- add_task_arguments: parser (argparse.ArgumentParser) – Argument parser instance to which task-related arguments will be added.
- build_collate_fn: Returns a Callable that collates data during training or evaluation.
- build_preprocess_fn: Returns an Optional[Callable] that preprocesses input data.
#### Examples
To add task arguments to an argument parser:
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> TargetSpeakerExtractionTask.add_task_arguments(parser)
>>> args = parser.parse_args()
To build a collate function:
>>> collate_fn = TargetSpeakerExtractionTask.build_collate_fn(args, train=True)
To build a preprocess function:
>>> preprocess_fn = TargetSpeakerExtractionTask.build_preprocess_fn(args, train=True)
To build a model:
>>> model = TargetSpeakerExtractionTask.build_model(args)
#### NOTE
Each of the main components (encoder, extractor, decoder, and preprocessor) is selected from a fixed set of registered classes, as listed in class_choices_list.
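Conceptually, each component is picked by name from a registry and built from a matching configuration dictionary (assumed here to correspond to option pairs such as `--encoder` and `--encoder_conf`). The sketch below illustrates that lookup-then-instantiate pattern under this assumption; `SimpleClassChoices`, `ToyConvEncoder`, and `ToySTFTEncoder` are hypothetical names, not ESPnet's ClassChoices or its encoder classes.

```python
from typing import Any, Dict, Optional, Type


class ToyConvEncoder:
    """Hypothetical stand-in for a learnable convolutional encoder."""

    def __init__(self, channel: int = 256, kernel_size: int = 16):
        self.channel = channel
        self.kernel_size = kernel_size


class ToySTFTEncoder:
    """Hypothetical stand-in for an STFT-based encoder."""

    def __init__(self, n_fft: int = 512):
        self.n_fft = n_fft


class SimpleClassChoices:
    """Simplified name -> class registry (illustration only, not ESPnet's ClassChoices)."""

    def __init__(self, name: str, classes: Dict[str, Type], default: str):
        self.name = name
        self.classes = classes
        self.default = default

    def get_class(self, choice: str) -> Type:
        # Reject names that were never registered, listing the valid ones.
        if choice not in self.classes:
            raise ValueError(
                f"--{self.name} must be one of {sorted(self.classes)}, got {choice!r}"
            )
        return self.classes[choice]

    def build(self, choice: Optional[str] = None, conf: Optional[Dict[str, Any]] = None) -> Any:
        # Fall back to the default choice; the conf dict becomes constructor kwargs.
        cls = self.get_class(choice if choice is not None else self.default)
        return cls(**(conf or {}))


encoder_choices = SimpleClassChoices(
    "encoder", classes={"conv": ToyConvEncoder, "stft": ToySTFTEncoder}, default="stft"
)
encoder = encoder_choices.build("conv", {"channel": 128})
print(type(encoder).__name__, encoder.channel, encoder.kernel_size)  # ToyConvEncoder 128 16
```

Keeping the registry separate from the model code means a new component only needs to be registered under a name to become selectable from the command line.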
classmethod add_task_arguments(parser: ArgumentParser)
Adds task-specific arguments to the provided argument parser.
This method is intended to extend the argument parser with options related to the Target Speaker Extraction Task, including model configuration, preprocessing parameters, and training options.
- Parameters:
- cls – The class reference.
- parser (argparse.ArgumentParser) – The argument parser instance to which the arguments will be added.
#### Examples
Create an argument parser and register the task options:
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> TargetSpeakerExtractionTask.add_task_arguments(parser)
>>> args = parser.parse_args()
#### NOTE
The method uses NestedDictAction so that arguments can carry nested configuration dictionaries.
- Raises: Invalid or missing required arguments are reported when the parser is used rather than by this method: parser.parse_args() prints an error message and exits with SystemExit.
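To make the nested-configuration idea concrete, the following sketch defines a custom argparse action that folds repeated `KEY=VALUE` arguments (with dotted keys for nesting) into one dictionary. It is a simplified, hypothetical illustration: `DottedDictAction` and `--demo_conf` are invented names, and ESPnet's NestedDictAction may accept a different syntax.

```python
import argparse
import ast
import copy
from typing import Any, Dict


class DottedDictAction(argparse.Action):
    """Hypothetical action: collects 'a.b=value' arguments into a nested dict."""

    def __call__(self, parser, namespace, values, option_string=None):
        # Copy what has been collected so far so repeated flags accumulate
        # without mutating the default object.
        current = getattr(namespace, self.dest, None)
        conf: Dict[str, Any] = copy.deepcopy(current) if current else {}
        key, sep, raw = values.partition("=")
        if not sep:
            parser.error(f"{option_string} expects KEY=VALUE, got {values!r}")
        try:
            value = ast.literal_eval(raw)  # numbers, booleans, lists, dicts
        except (ValueError, SyntaxError):
            value = raw  # anything else stays a plain string
        *parents, leaf = key.split(".")
        node = conf
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
        setattr(namespace, self.dest, conf)


parser = argparse.ArgumentParser()
parser.add_argument("--demo_conf", action=DottedDictAction, default={})
args = parser.parse_args(
    ["--demo_conf", "num_spk=1", "--demo_conf", "adapt.layer=2", "--demo_conf", "mode=mul"]
)
print(args.demo_conf)  # {'num_spk': 1, 'adapt': {'layer': 2}, 'mode': 'mul'}
```

Note that parser.error() raises SystemExit, which matches how argparse reports malformed arguments in general.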
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Build a collate function for batching input data during training or evaluation.
This method returns a callable that collates a list of data samples into a batch, preparing the input data for the model during training or evaluation.
- Parameters:
- args (argparse.Namespace) – Command-line arguments that may contain configuration settings for the task.
- train (bool) – A flag indicating whether the function is being used for training or evaluation.
- Returns: A collate function that takes a collection of (key, data dictionary) tuples and returns a tuple consisting of a list of keys and a dictionary of batched data as PyTorch tensors.
- Return type: Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]]]
#### Examples
>>> import numpy as np
>>> collate_fn = TargetSpeakerExtractionTask.build_collate_fn(args, True)
>>> batch = collate_fn([
...     ("sample1", {"data": np.array([1, 2, 3])}),
...     ("sample2", {"data": np.array([4, 5])}),
... ])
>>> print(batch)
(['sample1', 'sample2'], {'data': tensor([[1, 2, 3],
        [4, 5, 0]]), 'data_lengths': tensor([3, 2])})
#### NOTE
The collate function pads the input data so that all samples in the batch have the same shape, and records each sample's original length under a `<key>_lengths` entry.
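The padding step can be sketched without PyTorch. In the hypothetical `pad_collate` below, samples are padded along their first axis with a dtype-dependent pad value, and a `<key>_lengths` entry records the unpadded lengths. It returns NumPy arrays, whereas the function built by build_collate_fn returns `torch.Tensor`s; treat it as an illustration of the idea, not a copy of ESPnet's collate code.

```python
from typing import Collection, Dict, List, Tuple

import numpy as np


def pad_collate(
    batch: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: float = 0.0,
    int_pad_value: int = 0,
) -> Tuple[List[str], Dict[str, np.ndarray]]:
    """Hypothetical NumPy-only sketch of a padding collate function."""
    keys = [uid for uid, _ in batch]
    samples = [data for _, data in batch]
    output: Dict[str, np.ndarray] = {}
    for name in samples[0]:
        arrays = [sample[name] for sample in samples]
        # Integer arrays and float arrays get different pad values.
        pad = int_pad_value if arrays[0].dtype.kind == "i" else float_pad_value
        max_len = max(a.shape[0] for a in arrays)
        shape = (len(arrays), max_len) + arrays[0].shape[1:]
        padded = np.full(shape, pad, dtype=arrays[0].dtype)
        for i, a in enumerate(arrays):
            padded[i, : a.shape[0]] = a  # copy the valid part, leave the tail padded
        output[name] = padded
        # Original lengths let downstream code mask out the padding.
        output[name + "_lengths"] = np.array([a.shape[0] for a in arrays], dtype=np.int64)
    return keys, output


keys, batch = pad_collate(
    [
        ("sample1", {"speech_mix": np.array([0.1, 0.2, 0.3])}),
        ("sample2", {"speech_mix": np.array([0.4, 0.5])}),
    ]
)
print(keys)  # ['sample1', 'sample2']
print(batch["speech_mix_lengths"])  # [3 2]
```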
classmethod build_model(args: Namespace) → ESPnetExtractionModel
Builds and initializes the ESPnetExtractionModel for target speaker extraction.
This method creates the components of the target speaker extraction model, including the encoder, extractor, decoder, and loss wrappers, based on the provided configuration arguments. It also initializes the model using the specified initialization method.
- Parameters: args (argparse.Namespace) – Command-line arguments containing model configurations such as encoder, extractor, decoder, and criterions.
- Returns: An instance of ESPnetExtractionModel built and initialized according to the provided configurations.
- Return type: ESPnetExtractionModel
#### Examples
>>> args = parser.parse_args()
>>> model = TargetSpeakerExtractionTask.build_model(args)
#### NOTE
Ensure that the provided arguments contain valid configurations for the encoder, extractor, and decoder classes, as well as any required criterion and wrapper configurations.
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Build a preprocessing function for the TargetSpeakerExtractionTask.
This method constructs a preprocessing function that can be used to prepare the input data for the target speaker extraction task. It configures the preprocessor based on the provided arguments and returns a callable that processes the input data accordingly.
- Parameters:
- cls – The class itself (TargetSpeakerExtractionTask).
- args (argparse.Namespace) – The command-line arguments containing the configuration for the preprocessing.
- train (bool) – A flag indicating whether the function is being built for training or evaluation.
- Returns: A callable that takes an utterance ID and a dictionary of data and returns a processed dictionary of NumPy arrays, or None if no preprocessing function is created.
- Return type: Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]
#### Examples
>>> args = argparse.Namespace(
... train_spk2enroll='path/to/spk2enroll.scp',
... enroll_segment=None,
... load_spk_embedding=False,
... load_all_speakers=False,
... rir_scp=None,
... rir_apply_prob=1.0,
... noise_scp=None,
... noise_apply_prob=1.0,
... noise_db_range='13_15',
... short_noise_thres=0.5,
... speech_volume_normalize=None,
... use_reverberant_ref=False,
... num_spk=1,
... num_noise_type=1,
... sample_rate=8000,
... force_single_channel=False,
... channel_reordering=False,
... categories=None,
... speech_segment=None,
... avoid_allzero_segment=True,
... flexible_numspk=False,
... preprocessor_conf={}
... )
>>> preprocess_fn = TargetSpeakerExtractionTask.build_preprocess_fn(args, True)
>>> processed_data = preprocess_fn('utt1', {'key': np.array([1, 2, 3])})
#### NOTE
This method relies on the TSEPreprocessor class for the actual preprocessing.
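Several of the arguments above control noise augmentation. Assuming `noise_db_range='13_15'` means that an SNR is drawn uniformly between 13 dB and 15 dB for each utterance (and that a single number fixes the SNR), the mixing step can be sketched as follows. `parse_db_range` and `add_noise_at_snr` are hypothetical helpers written for illustration; they do not reproduce TSEPreprocessor and leave out the other arguments, such as short_noise_thres.

```python
import numpy as np


def parse_db_range(db_range: str) -> tuple:
    """Parse 'low_high' (or a single value) into a (low, high) pair in dB."""
    parts = db_range.split("_")
    if len(parts) == 1:
        low = high = float(parts[0])
    elif len(parts) == 2:
        low, high = float(parts[0]), float(parts[1])
    else:
        raise ValueError(f"expected 'low_high' or 'value', got {db_range!r}")
    return low, high


def add_noise_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float):
    """Scale `noise` so that 10*log10(P_speech / P_noise) equals `snr_db`, then add it."""
    speech_power = np.mean(speech ** 2)
    noise_power = max(np.mean(noise ** 2), 1e-10)  # guard against silent noise
    scale = np.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10)))
    scaled_noise = scale * noise
    return speech + scaled_noise, scaled_noise


rng = np.random.default_rng(0)
low, high = parse_db_range("13_15")
snr_db = rng.uniform(low, high)  # one SNR draw per utterance
speech = rng.standard_normal(8000)  # 1 s of placeholder "speech" at 8 kHz
noise = rng.standard_normal(8000)
noisy, scaled_noise = add_noise_at_snr(speech, noise, snr_db)
achieved = 10 * np.log10(np.mean(speech ** 2) / np.mean(scaled_noise ** 2))
print(f"target {snr_db:.2f} dB, achieved {achieved:.2f} dB")
```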
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Retrieves the optional data names used in the Target Speaker Extraction task.
This method generates a tuple of optional data names that may be used during training or inference. They include the additional enrollment and reference speech entries, up to the maximum number of references allowed, as well as "category".
- Parameters:
- train (bool) – Indicates if the method is called during training. Defaults to True.
- inference (bool) – Indicates if the method is called during inference. Defaults to False.
- Returns: A tuple containing the optional data names.
- Return type: Tuple[str, …]
#### Examples
>>> optional_data = TargetSpeakerExtractionTask.optional_data_names()
>>> print(optional_data)
('enroll_ref2', 'enroll_ref3', ..., 'category')
#### NOTE
The number of enrollment and reference speech data names is determined by the constant MAX_REFERENCE_NUM. The method checks whether the first reference (speech_ref1) is already a required data name, so that it is not listed again among the optional names.
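The naming rule in the note can be re-created in a few lines of plain Python. This is a sketch under stated assumptions: MAX_REFERENCE_NUM is set to 4 only to keep the output short, the ordering of the names is assumed, and `optional_names` and `required_names` are hypothetical helpers rather than the classmethods themselves.

```python
from typing import Tuple

MAX_REFERENCE_NUM = 4  # hypothetical small value chosen for readability


def required_names(inference: bool) -> Tuple[str, ...]:
    # Outside inference mode the first reference signal is also required.
    if inference:
        return ("speech_mix", "enroll_ref1")
    return ("speech_mix", "enroll_ref1", "speech_ref1")


def optional_names(inference: bool) -> Tuple[str, ...]:
    names = [f"enroll_ref{n}" for n in range(2, MAX_REFERENCE_NUM + 1)]
    # Start at speech_ref2 when speech_ref1 is already required, else at speech_ref1.
    first = 2 if "speech_ref1" in required_names(inference) else 1
    names += [f"speech_ref{n}" for n in range(first, MAX_REFERENCE_NUM + 1)]
    names.append("category")
    return tuple(names)


print(optional_names(inference=False))
# ('enroll_ref2', 'enroll_ref3', 'enroll_ref4', 'speech_ref2', 'speech_ref3', 'speech_ref4', 'category')
```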
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Returns the required data names for the target speaker extraction task.
This method provides the names of the required data inputs for training or inference. The output varies depending on whether the task is in inference mode or not.
- Parameters:
- train (bool) – Indicates whether the data is for training. Default is True.
- inference (bool) – Indicates whether the data is for inference. Default is False.
- Returns: A tuple containing the names of the required data.
- Return type: Tuple[str, …]
#### Examples
>>> TargetSpeakerExtractionTask.required_data_names(train=True, inference=False)
('speech_mix', 'enroll_ref1', 'speech_ref1')
>>> TargetSpeakerExtractionTask.required_data_names(train=True, inference=True)
('speech_mix', 'enroll_ref1')
#### NOTE
When inference is False, "speech_mix", "enroll_ref1", and "speech_ref1" are all required; in inference mode, only "speech_mix" and "enroll_ref1" are required.
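Because the required names differ between training and inference, it can help to check which data entries are available before launching a run. The helpers below are a hypothetical sketch: `expected_required_names` mirrors the behaviour described in the note, and `missing_required` reports which required names are absent.

```python
from typing import Iterable, List, Tuple


def expected_required_names(inference: bool = False) -> Tuple[str, ...]:
    # Mirrors the documented behaviour: speech_ref1 is only needed outside inference.
    if inference:
        return ("speech_mix", "enroll_ref1")
    return ("speech_mix", "enroll_ref1", "speech_ref1")


def missing_required(available: Iterable[str], inference: bool = False) -> List[str]:
    """Return the required data names that do not appear in `available`."""
    present = set(available)
    return [name for name in expected_required_names(inference) if name not in present]


entries = ["speech_mix", "enroll_ref1"]  # names from a hypothetical data configuration
print(missing_required(entries))  # ['speech_ref1']
print(missing_required(entries, inference=True))  # []
```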
trainer
alias of Trainer