espnet2.tasks.enh.EnhancementTask
class espnet2.tasks.enh.EnhancementTask
Bases: AbsTask
EnhancementTask is a task class for audio enhancement in the ESPnet framework.
This class defines various configurations and methods for building and training models related to audio enhancement tasks. It manages the different components of the enhancement pipeline, including encoders, separators, decoders, and loss functions.
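The encoder → separator → decoder flow described above can be illustrated with a toy NumPy sketch. Everything in it is hypothetical: `toy_encoder`, `toy_separator`, and `toy_decoder` are not ESPnet classes. The encoder uses non-overlapping rectangular frames so that the decoder can invert it exactly, and the separator predicts random softmax masks that sum to one across speakers. Real ESPnet components are learned and configured through the task arguments.

```python
import numpy as np


def toy_encoder(wav: np.ndarray, frame: int = 8) -> np.ndarray:
    """Hypothetical encoder: split into non-overlapping frames and apply an rFFT."""
    n_frames = len(wav) // frame
    frames = wav[: n_frames * frame].reshape(n_frames, frame)
    return np.fft.rfft(frames, axis=-1)  # shape: (n_frames, frame // 2 + 1), complex


def toy_separator(spec: np.ndarray, num_spk: int = 2, seed: int = 0) -> list:
    """Hypothetical separator: random logits -> softmax masks summing to 1 over speakers."""
    rng = np.random.default_rng(seed)
    logits = rng.standard_normal((num_spk,) + spec.shape)
    masks = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)
    # One masked spectrum per speaker.
    return [m * spec for m in masks]


def toy_decoder(spec: np.ndarray, frame: int = 8) -> np.ndarray:
    """Hypothetical decoder: inverse rFFT per frame, then concatenate the frames."""
    return np.fft.irfft(spec, n=frame, axis=-1).reshape(-1)


mix = np.random.default_rng(1).standard_normal(64)
estimates = [toy_decoder(s) for s in toy_separator(toy_encoder(mix))]
```

Because the toy masks sum to one and the decoder is linear, the two estimates add back up to the mixture. This property comes from the toy masks, not from ESPnet.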
num_optimizers
Number of optimizers to be used for training.
- Type: int
class_choices_list
List of available class choices for various components such as encoders, separators, and loss wrappers.
- Type: list
trainer
The trainer class used for training the model.
- Type: Trainer
add_task_arguments(parser: argparse.ArgumentParser)
Adds command-line arguments specific to the enhancement task.
build_collate_fn(args: argparse.Namespace, train: bool) -> Callable
Builds a collate function for data processing during training and evaluation.
build_preprocess_fn(args: argparse.Namespace, train: bool) -> Optional[Callable]
Builds a preprocessing function for preparing the data.
required_data_names(train: bool = True, inference: bool = False) -> Tuple[str, …]
Returns the required data names based on the mode.
optional_data_names(train: bool = True, inference: bool = False) -> Tuple[str, …]
Returns the optional data names based on the mode.
build_model(args: argparse.Namespace) -> ESPnetEnhancementModel
Constructs the enhancement model based on the provided configuration.
build_iter_factory(args: argparse.Namespace, distributed_option: DistributedOption, mode: str, kwargs: dict = None) -> AbsIterFactory
Builds an iterator factory for data loading during training or evaluation.
Examples
To create a parser and add the enhancement task arguments:
>>> import argparse
>>> from espnet2.tasks.enh import EnhancementTask
>>> parser = argparse.ArgumentParser()
>>> EnhancementTask.add_task_arguments(parser)
>>> args = parser.parse_args()
To build a model for enhancement:
>>> model = EnhancementTask.build_model(args)
NOTE: The class is designed to work with the ESPnet framework and its components. Ensure that all necessary dependencies are installed and configured properly.
classmethod add_task_arguments(parser: ArgumentParser)
Adds task-specific arguments to the argument parser.
This method extends the provided argparse.ArgumentParser with options related to the enhancement task, including model configuration, criteria, preprocessing options, and more. It organizes the arguments into logical groups for better clarity.
- Parameters:
- cls – The class reference to access class-level attributes and methods.
- parser (argparse.ArgumentParser) – The parser to which arguments will be added.
Examples
>>> import argparse
>>> from espnet2.tasks.enh import EnhancementTask
>>> parser = argparse.ArgumentParser()
>>> EnhancementTask.add_task_arguments(parser)
>>> args = parser.parse_args()
NOTE
- required=True cannot be used for some arguments, because a required option would prevent the parser from running without them (for example, in --print_config mode).
- The method adds argument groups such as “Task related” and “Preprocess related”.
- Raises: ValueError – If the specified preprocessor type is not supported.
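The grouping pattern described above can be sketched with plain `argparse`. The function `add_toy_task_arguments` and its options are a hypothetical subset chosen for illustration: `--init`, `--preprocessor`, and `--preprocessor_conf` mirror attributes mentioned elsewhere on this page, but their exact definitions here (types, defaults, JSON parsing of the config) are assumptions. The point is that each option has a default instead of `required=True`, so parsing an empty command line still succeeds.

```python
import argparse
import json


def add_toy_task_arguments(parser: argparse.ArgumentParser) -> None:
    """Hypothetical stand-in for EnhancementTask.add_task_arguments."""
    # Options are grouped so that --help output is organized by topic.
    task = parser.add_argument_group(description="Task related")
    # A default (None) is used instead of required=True, so the parser
    # still works when no task options are given.
    task.add_argument("--init", type=str, default=None,
                      help="Initialization method (assumed option)")

    pre = parser.add_argument_group(description="Preprocess related")
    pre.add_argument("--preprocessor", type=str, default=None,
                     help="Preprocessor type (assumed option)")
    pre.add_argument("--preprocessor_conf", type=json.loads, default={},
                     help="Preprocessor config as a JSON dict (assumed format)")


parser = argparse.ArgumentParser()
add_toy_task_arguments(parser)
defaults = parser.parse_args([])
custom = parser.parse_args(
    ["--preprocessor", "enh", "--preprocessor_conf", '{"ref_num": 2}']
)
```

`defaults` holds only the default values, while `custom` shows a config dict parsed from the command line.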
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Builds a collate function for data loading.
This method constructs a collate function for use with data loaders. The function takes a collection of (ID, data dict) samples and pads the variable-length arrays of each entry to a common length so they can be stacked into one batch tensor.
- Parameters:
- args (argparse.Namespace) – Command line arguments containing configuration settings.
- train (bool) – Indicates whether the function is being called in training mode.
- Returns: Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]]]: A callable function that can be used to collate data samples.
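The padding step can be sketched in NumPy. `toy_collate` is hypothetical and only approximates the behavior described above. It pads each entry along its first axis with a pad value of 0 (integer arrays use the integer pad value, all others the float one). It also records each sample's original length under `<key>_lengths`, a convention assumed here.

```python
import numpy as np


def toy_collate(batch, float_pad_value=0.0, int_pad_value=0):
    """Hypothetical collate: pad variable-length arrays and record their lengths."""
    uttids = [uid for uid, _ in batch]
    samples = [data for _, data in batch]
    output = {}
    for key in samples[0]:
        arrays = [s[key] for s in samples]
        # Integer arrays use the integer pad value, everything else the float one.
        pad = int_pad_value if arrays[0].dtype.kind == "i" else float_pad_value
        max_len = max(a.shape[0] for a in arrays)
        padded = np.full(
            (len(arrays), max_len) + arrays[0].shape[1:], pad, dtype=arrays[0].dtype
        )
        for i, a in enumerate(arrays):
            padded[i, : a.shape[0]] = a  # copy the valid part, keep padding after it
        output[key] = padded
        output[key + "_lengths"] = np.array([a.shape[0] for a in arrays])
    return uttids, output


ids, batch = toy_collate([("sample1", {"feature": np.array([1, 2, 3])}),
                          ("sample2", {"feature": np.array([4, 5])})])
```

The lengths array lets downstream code ignore the padded positions.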
Examples
>>> import argparse
>>> import numpy as np
>>> from espnet2.tasks.enh import EnhancementTask
>>> args = argparse.Namespace()
>>> args.float_pad_value = 0.0
>>> args.int_pad_value = 0
>>> collate_fn = EnhancementTask.build_collate_fn(args, train=True)
>>> data = [("sample1", {"feature": np.array([1, 2, 3])}),
...         ("sample2", {"feature": np.array([4, 5])})]
>>> uttids, batch = collate_fn(data)
>>> uttids
['sample1', 'sample2']
>>> batch["feature"].tolist()
[[1, 2, 3], [4, 5, 0]]
>>> batch["feature_lengths"].tolist()
[3, 2]
classmethod build_iter_factory(args: Namespace, distributed_option: DistributedOption, mode: str, kwargs: dict | None = None) → AbsIterFactory
Builds an iterator factory for the enhancement task.
This method creates an instance of an iterator factory, which is responsible for managing the data loading process during training or evaluation. It takes into account whether dynamic mixing is enabled and modifies the arguments accordingly for training mode.
- Parameters:
- args (argparse.Namespace) – The command line arguments containing configurations for the task.
- distributed_option (DistributedOption) – Options related to distributed training.
- mode (str) – The mode of operation, typically “train” or “eval”.
- kwargs (dict , optional) – Additional keyword arguments for the iterator factory. Defaults to None.
- Returns: An instance of a class derived from AbsIterFactory that will be used to iterate over the dataset.
- Return type:AbsIterFactory
Examples
>>> import argparse
>>> from espnet2.tasks.enh import EnhancementTask
>>> from espnet2.train.distributed_utils import DistributedOption
>>> # args must also contain the remaining iterator options (elided here)
>>> args = argparse.Namespace(dynamic_mixing=True, fold_length=[512])
>>> distributed_option = DistributedOption(...)
>>> factory = EnhancementTask.build_iter_factory(args, distributed_option, "train")
NOTE: If dynamic mixing is enabled and the mode is “train”, the fold_length argument is modified so that only its first value is used.
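The fold_length adjustment in the note can be sketched as a small pure function. `adjust_fold_length` is hypothetical. Copying the namespace before modifying it is an assumption made here so that the caller's `args` is left untouched.

```python
import argparse
import copy


def adjust_fold_length(args: argparse.Namespace, mode: str) -> argparse.Namespace:
    """Hypothetical helper: keep only the first fold_length with dynamic mixing in training."""
    if getattr(args, "dynamic_mixing", False) and mode == "train":
        args = copy.deepcopy(args)  # work on a copy (assumption)
        args.fold_length = args.fold_length[:1]  # keep only the first value
    return args


train_args = argparse.Namespace(dynamic_mixing=True, fold_length=[80000, 80000])
adjusted = adjust_fold_length(train_args, "train")
```

Outside training, or without dynamic mixing, the namespace is returned unchanged.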
classmethod build_model(args: Namespace) → ESPnetEnhancementModel
Builds the enhancement model based on the provided configuration arguments.
This method constructs an enhancement model by selecting the appropriate encoder, separator, decoder, and optional mask module according to the specified arguments. It also initializes the model if a specific initialization method is provided.
- Parameters: args (argparse.Namespace) – The command line arguments containing model configuration options, including encoder, separator, decoder, and loss criteria.
- Returns: An instance of the enhancement model configured with the specified components.
- Return type: ESPnetEnhancementModel
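The component selection step can be pictured as a name-to-class lookup, followed by construction from the matching `*_conf` dict. The sketch below is hypothetical: `ToyEncoder`, `ToySeparator`, `REGISTRY`, `get_class`, and `build_toy_model` are illustrative stand-ins, not ESPnet classes. Passing the encoder's output size as the separator's first argument is an assumption, and the `ValueError` for an unknown name only mirrors the documented behavior, not ESPnet's exact message.

```python
import argparse
from dataclasses import dataclass


@dataclass
class ToyEncoder:
    n_fft: int = 512

    @property
    def output_dim(self) -> int:
        # Number of frequency bins of a one-sided spectrum.
        return self.n_fft // 2 + 1


@dataclass
class ToySeparator:
    input_dim: int
    num_spk: int = 2


REGISTRY = {
    "encoder": {"toy_stft": ToyEncoder},
    "separator": {"toy_rnn": ToySeparator},
}


def get_class(kind: str, name: str):
    """Look up a component class by name, failing loudly on unknown names."""
    choices = REGISTRY[kind]
    if name not in choices:
        raise ValueError(f"--{kind} must be one of {sorted(choices)}: got {name!r}")
    return choices[name]


def build_toy_model(args: argparse.Namespace) -> dict:
    encoder = get_class("encoder", args.encoder)(**args.encoder_conf)
    # The separator input size is taken from the encoder output size (assumption).
    separator = get_class("separator", args.separator)(
        encoder.output_dim, **args.separator_conf
    )
    return {"encoder": encoder, "separator": separator}


args = argparse.Namespace(encoder="toy_stft", encoder_conf={"n_fft": 256},
                          separator="toy_rnn", separator_conf={"num_spk": 2})
model = build_toy_model(args)
```

Keeping the choices in a registry means new components can be added by name without changing the build logic.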
Examples
>>> import argparse
>>> from espnet2.tasks.enh import EnhancementTask
>>> # the *_conf values below are illustrative
>>> args = argparse.Namespace(
...     encoder="stft",
...     encoder_conf={"n_fft": 512, "hop_length": 128},
...     separator="rnn",
...     separator_conf={"num_spk": 2},
...     decoder="stft",
...     decoder_conf={"n_fft": 512, "hop_length": 128},
...     mask_module="multi_mask",
...     mask_module_conf={},
...     criterions=[
...         {
...             "name": "si_snr",
...             "conf": {},
...             "wrapper": "fixed_order",
...             "wrapper_conf": {},
...         },
...     ],
...     model_conf={},
...     init="xavier_uniform",
... )
>>> model = EnhancementTask.build_model(args)
NOTE: Ensure that the args object contains valid configuration options for the encoder, separator, decoder, and any other necessary parameters to avoid runtime errors.
- Raises: ValueError – If a component name (e.g. the encoder, separator, or decoder type) is not a supported choice, or if there are other issues with the provided arguments.
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Builds a preprocessing function based on the provided arguments.
This method generates a callable function that preprocesses input audio data based on the specified preprocessor type. It checks if a preprocessor is defined in the arguments and initializes the corresponding preprocessor class with the provided configuration.
- Parameters:
- cls – The class reference to the current class.
- args (argparse.Namespace) – The parsed command line arguments.
- train (bool) – A flag indicating whether the function is being called for training or inference.
- Returns: A callable preprocessing function if a preprocessor is defined, otherwise None.
- Return type: Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]
- Raises: ValueError – If the specified preprocessor type is not supported.
Examples
>>> import argparse
>>> import numpy as np
>>> from espnet2.tasks.enh import EnhancementTask
>>> args = argparse.Namespace()
>>> args.preprocessor = "dynamic_mixing"
>>> args.preprocessor_conf = {"source_scp_name": "spk1.scp", "ref_num": 2}
>>> preprocess_fn = EnhancementTask.build_preprocess_fn(args, train=True)
>>> # the first argument is the utterance ID, not a file path
>>> output = preprocess_fn("utt1", {"key": np.array([1, 2, 3])})
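The "preprocessor if configured, otherwise None" logic described above can be sketched as follows. `ToyGainPreprocessor`, `PREPROCESSORS`, and `build_toy_preprocess_fn` are hypothetical. The `(utterance ID, data dict)` call convention is an assumption based on the documented return type.

```python
import argparse
from typing import Callable, Dict, Optional

import numpy as np


class ToyGainPreprocessor:
    """Hypothetical preprocessor: scales every array by a constant gain."""

    def __init__(self, train: bool, gain: float = 1.0):
        self.train = train
        self.gain = gain

    def __call__(self, uid: str, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return {k: v * self.gain for k, v in data.items()}


PREPROCESSORS = {"toy_gain": ToyGainPreprocessor}


def build_toy_preprocess_fn(args: argparse.Namespace, train: bool) -> Optional[Callable]:
    name = getattr(args, "preprocessor", None)
    if name is None:
        return None  # no preprocessing configured
    if name not in PREPROCESSORS:
        raise ValueError(f"Unsupported preprocessor type: {name}")
    return PREPROCESSORS[name](train=train, **getattr(args, "preprocessor_conf", {}))


fn = build_toy_preprocess_fn(
    argparse.Namespace(preprocessor="toy_gain", preprocessor_conf={"gain": 2.0}),
    train=True,
)
out = fn("utt1", {"speech_mix": np.array([1.0, -0.5])})
```

Returning None rather than an identity function lets the caller skip preprocessing entirely.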
class_choices_list
num_optimizers
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Returns the optional data names required for the enhancement task.
This method generates a tuple of optional data names that may be used during the training or inference phases of the enhancement task. The data names returned include references for dereverberation, speech, noise, as well as category and sampling frequency information.
- Parameters:
- train (bool) – If True, the method is invoked in training mode. Defaults to True.
- inference (bool) – If True, the method is invoked in inference mode. Defaults to False.
- Returns: A tuple containing the names of the optional data for the task.
- Return type: Tuple[str, …]
Examples
>>> optional_data = EnhancementTask.optional_data_names(train=True)
>>> print(optional_data)
('speech_mix', 'dereverb_ref1', 'dereverb_ref2', ..., 'category', 'fs')
>>> optional_data = EnhancementTask.optional_data_names(inference=True)
>>> print(optional_data)
('speech_mix',)
NOTE: The maximum number of reference signals is defined by the constant MAX_REFERENCE_NUM.
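A sketch of how such a tuple could be generated from MAX_REFERENCE_NUM, covering only the training-mode case. `toy_optional_data_names` is hypothetical. The small MAX_REFERENCE_NUM value, the ordering of the names, and starting the speech references at index 2 (because `speech_ref1` is already required) are assumptions for illustration.

```python
from typing import Tuple

MAX_REFERENCE_NUM = 2  # small value chosen for illustration (assumption)


def toy_optional_data_names() -> Tuple[str, ...]:
    """Hypothetical training-mode list of optional data names."""
    names = ["speech_mix"]
    names += [f"dereverb_ref{n}" for n in range(1, MAX_REFERENCE_NUM + 1)]
    # speech_ref1 is required, so only the additional speakers are optional (assumption).
    names += [f"speech_ref{n}" for n in range(2, MAX_REFERENCE_NUM + 1)]
    names += [f"noise_ref{n}" for n in range(1, MAX_REFERENCE_NUM + 1)]
    names += ["category", "fs"]
    return tuple(names)
```

Raising MAX_REFERENCE_NUM adds one name per reference type without any other change.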
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Returns the required data names for the enhancement task.
The method returns a tuple of strings representing the names of the required data based on the task mode (training or inference).
- Parameters:
- train (bool) – Indicates whether the task is in training mode. Defaults to True.
- inference (bool) – Indicates whether the task is in inference mode. Defaults to False.
- Returns: A tuple containing the required data names: (“speech_ref1”,) if not in inference mode, or (“speech_mix”,) in inference mode.
- Return type: Tuple[str, …]
Examples
>>> EnhancementTask.required_data_names(train=True, inference=False)
('speech_ref1',)
>>> EnhancementTask.required_data_names(train=True, inference=True)
('speech_mix',)
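The documented rule reduces to a single branch on `inference`. The `train` flag does not change the result. A hypothetical restatement (`toy_required_data_names` is not an ESPnet function):

```python
from typing import Tuple


def toy_required_data_names(train: bool = True, inference: bool = False) -> Tuple[str, ...]:
    """Hypothetical restatement of the documented rule."""
    # Outside inference a clean reference is needed; at inference only the mixture.
    if not inference:
        return ("speech_ref1",)
    return ("speech_mix",)
```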
trainer
alias of Trainer