espnet2.tasks.abs_task.AbsTask
class espnet2.tasks.abs_task.AbsTask
Bases: ABC
Abstract task module.
The AbsTask class serves as an abstract base class for defining various tasks in the ESPnet framework. It outlines the necessary methods and attributes that must be implemented by any specific task subclass.
num_optimizers
The number of optimizers used in the task.
- Type: int
trainer
The trainer class associated with the task.
- Type: Trainer
class_choices_list
List of class choices for configuration.
- Type: List[ClassChoices]
add_task_arguments(parser: argparse.ArgumentParser): Abstract method to add task-specific arguments to the parser.
build_collate_fn(args: argparse.Namespace, train: bool) -> Callable: Returns a collate function for use with DataLoader.
build_preprocess_fn(args: argparse.Namespace, train: bool) -> Optional[Callable]: Returns a preprocessing function for input data.
required_data_names(train: bool = True, inference: bool = False) -> Tuple[str, …]: Defines required data names for the task.
optional_data_names(train: bool = True, inference: bool = False) -> Tuple[str, …]: Defines optional data names for the task.
build_model(args: argparse.Namespace) -> AbsESPnetModel: Builds and returns a model instance based on provided arguments.
get_parser() → config_argparse.ArgumentParser
Returns a parser for command line arguments.
build_optimizers(args: argparse.Namespace, model: torch.nn.Module) -> List[torch.optim.Optimizer]: Builds and returns a list of optimizers based on the provided model.
exclude_opts() → Tuple[str, ...]
Returns options not to be shown by the --print_config argument.
get_default_config() → Dict[str, Any]
Returns the default configuration as a dictionary.
check_required_command_args(args: argparse.Namespace): Checks if required command-line arguments are provided.
check_task_requirements(dataset: Union[AbsDataset, IterableESPnetDataset], allow_variable_data_keys: bool, train: bool, inference: bool = False):
Validates if the dataset meets the task’s requirements.
print_config(file=sys.stdout) → None
Prints the default configuration to the specified output file.
main(args: Optional[argparse.Namespace] = None, cmd: Optional[Sequence[str]] = None): Main entry point for executing the task.
build_iter_factory(args: argparse.Namespace, distributed_option: DistributedOption, mode: str, kwargs: Optional[dict] = None) -> AbsIterFactory:
Builds a factory for creating mini-batch iterators.
build_model_from_file(config_file: Optional[Union[Path, str]] = None, model_file: Optional[Union[Path, str]] = None, device: str = “cpu”) -> Tuple[AbsESPnetModel, argparse.Namespace]:
Builds a model from configuration and model files for inference or fine-tuning.
############################################
Example
>>> class MyTask(AbsTask):
... @classmethod
... def add_task_arguments(cls, parser: argparse.ArgumentParser):
... parser.add_argument("--my_arg", type=int, help="My task argument")
...
... @classmethod
... def build_collate_fn(cls, args: argparse.Namespace, train: bool):
... # Implement collate function
... pass
>>> parser = MyTask.get_parser()
>>> MyTask.add_task_arguments(parser)
>>> args = parser.parse_args()
abstract classmethod add_task_arguments(parser: ArgumentParser)
Add task-specific arguments to the argument parser.
This method should be overridden in subclasses to add task-specific command-line arguments to the provided argument parser.
- Parameters:parser (argparse.ArgumentParser) – The argument parser instance to which task-specific arguments should be added.
- Raises:NotImplementedError – If not overridden in a subclass.
############################################
Example
>>> class YourTask(AbsTask):
... @classmethod
... def add_task_arguments(cls, parser: argparse.ArgumentParser):
... parser.add_argument("--your_task_param", type=int, default=0,
... help="An example parameter for your task.")
classmethod build_category_chunk_iter_factory(args: Namespace, iter_options: IteratorOptions, mode: str) → AbsIterFactory
Build a factory object for category chunk mini-batch iterator.
This factory is responsible for creating mini-batches that are grouped by categories, with each batch containing a specified number of samples from each category. It ensures that the batches are balanced in terms of category representation.
- Parameters:
- args (argparse.Namespace) – Parsed command line arguments containing configurations for the iterator.
- iter_options (IteratorOptions) – Options containing various settings for iterator behavior, such as batch size and preprocessing functions.
- mode (str) – The operational mode of the iterator, which could be “train”, “valid”, or “test”.
- Returns: An instance of the iterator factory that produces category chunk mini-batches.
- Return type:AbsIterFactory
- Raises:RuntimeError – If the dataset does not satisfy the requirements defined for the task.
############################################
Example
>>> iter_factory = YourTask.build_category_chunk_iter_factory(
... args, iter_options, mode='train'
... )
>>> for keys, batch in iter_factory.build_iter(epoch):
... model(**batch)
################## NOTE The batches created by this factory will ensure that each mini-batch has a balanced representation of each category defined in the dataset.
classmethod build_category_iter_factory(args: Namespace, iter_options: IteratorOptions, mode: str) → AbsIterFactory
Build a factory object for category-based mini-batch iteration.
This method creates a mini-batch iterator specifically for categories, utilizing a dataset that is expected to have a mapping from categories to utterances. The function checks for required data names and ensures that the necessary category mappings are in place.
- Parameters:
- args (argparse.Namespace) – The parsed command-line arguments.
- iter_options (IteratorOptions) – Options for iterator construction.
- mode (str) – The mode of operation (e.g., “train”, “valid”).
- Returns: An instance of a factory that produces category-based mini-batches.
- Return type:AbsIterFactory
- Raises:
- RuntimeError – If the dataset does not meet task requirements.
- ValueError – If the category2utt file is not found.
Example
>>> factory = YourTask.build_category_iter_factory(args, iter_options, mode)
>>> for keys, batch in factory.build_iter(epoch):
... model(**batch)
################## NOTE The category2utt file must exist in the same directory as the data files. This file maps categories to utterances, and is mandatory for the category iterator to function correctly.
classmethod build_chunk_iter_factory(args: Namespace, iter_options: IteratorOptions, mode: str) → AbsIterFactory
Build a chunk iterator factory for creating mini-batches.
This method creates a factory for producing mini-batches from the dataset based on the specified chunk length and other parameters. It is primarily used for tasks where data needs to be processed in chunks, such as audio processing.
- Parameters:
- args – The command line arguments namespace containing configuration parameters.
- iter_options – An instance of IteratorOptions that contains various options related to data iteration.
- mode – A string indicating the mode of operation, such as “train” or “valid”.
- Returns: An instance of AbsIterFactory, which can be used to iterate over the dataset in chunks.
- Raises:RuntimeError – If the number of samples is smaller than the world size or if the batch size is not compatible with the world size.
############################################
Example
>>> factory = AbsTask.build_chunk_iter_factory(args, iter_options, mode)
>>> for batch in factory.build_iter(epoch):
... process_batch(batch)
################## NOTE This factory assumes that the input data has been preprocessed according to the specified preprocess_fn and is ready for chunk-based iteration.
abstract classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Sequence[Dict[str, ndarray]]], Dict[str, Tensor]]
Return a callable collate function for DataLoader.
The collate function is responsible for merging a list of samples into a mini-batch. It is typically used in conjunction with a DataLoader to create batches of data.
- Parameters:
- cls – The class type.
- args – The arguments namespace containing configuration options.
- train – A boolean indicating if the collate function is for training or validation.
- Returns: A callable that takes a sequence of data samples (each a dictionary) and returns a single dictionary that represents the batched data. Each value in the dictionary is expected to be a tensor.
############################################
Example
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(
... dataset,
... collate_fn=cls.build_collate_fn(args, train=True),
... ...
... )
In many cases, you can use our common collate_fn.
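For illustration, a hedged sketch of a subclass that simply returns the common collate function (espnet2.train.collate_fn.CommonCollateFn); the padding values below are assumptions to be adjusted per task:
>>> import argparse
>>> from espnet2.train.collate_fn import CommonCollateFn
>>> class MyTask(AbsTask):
...     @classmethod
...     def build_collate_fn(cls, args: argparse.Namespace, train: bool):
...         # Pad float arrays with 0.0 and integer arrays with -1
...         # (illustrative values; pick what the loss / ignore_index expects).
...         return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)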
classmethod build_iter_factory(args: Namespace, distributed_option: DistributedOption, mode: str, kwargs: dict | None = None) → AbsIterFactory
Build a factory object of mini-batch iterator.
This object is invoked at every epoch to build the iterator for each epoch as follows:
>>> iter_factory = cls.build_iter_factory(...)
>>> for epoch in range(1, max_epoch):
... for keys, batch in iter_factory.build_iter(epoch):
... model(**batch)
The mini-batches for each epoch are fully controlled by this class. Note that the random seed used for shuffling is decided as “seed + epoch” and the generated mini-batches can be reproduced when resuming.
Note that “epoch” does not always mean a full pass over the training corpus. The “--num_iters_per_epoch” option restricts the number of iterations per epoch, and the remaining samples of the original epoch are carried over to the next epoch. For example, if the total number of mini-batches equals 4, the following two scenarios are equivalent:
- 1 epoch without “--num_iters_per_epoch”
- 4 epochs with “--num_iters_per_epoch” == 1
- Parameters:
- args – The arguments namespace containing configurations for the iterator factory.
- distributed_option – Options related to distributed training.
- mode – The mode for which to build the iterator (e.g., “train”, “valid”, “plot_att”).
- kwargs – Optional additional keyword arguments to overwrite default iterator options.
- Returns: An instance of AbsIterFactory for creating mini-batch iterators.
- Raises:RuntimeError – If the specified iterator type is not supported.
classmethod build_iter_options(args: Namespace, distributed_option: DistributedOption, mode: str)
Build iterator options for training, validation, or plotting attention.
This method constructs an IteratorOptions object that encapsulates the necessary parameters to create an iterator for a specified mode (train, valid, or plot_att). It checks the mode and sets the appropriate options for preprocessing, collating, and batching data.
- Parameters:
- cls – The class itself (used as a reference to call class methods).
- args (argparse.Namespace) – The command-line arguments parsed into a namespace object.
- distributed_option (DistributedOption) – The distributed training options to determine if the training is distributed.
- mode (str) – The mode for which to build the iterator options. This can be one of “train”, “valid”, or “plot_att”.
- Returns: An instance of IteratorOptions containing the configuration needed for the iterator.
- Return type:IteratorOptions
- Raises:NotImplementedError – If an unsupported mode is specified.
############################################
Example
>>> options = AbsTask.build_iter_options(args, distributed_option, mode='train')
>>> print(options.batch_size)
32
>>> options = AbsTask.build_iter_options(args, distributed_option, mode='valid')
>>> print(options.max_cache_size)
10485760 # 10MB if set in args
abstract classmethod build_model(args: Namespace) → AbsESPnetModel
Build the model based on the provided arguments.
This method should create and return an instance of a model that inherits from AbsESPnetModel. The specifics of the model to be built are determined by the configuration provided in args.
- Parameters:args – An instance of argparse.Namespace that contains the configuration options needed to build the model. This may include parameters such as model architecture, layer sizes, and other hyperparameters.
- Returns: An instance of AbsESPnetModel or its subclass that has been initialized according to the provided arguments.
- Raises:NotImplementedError – If the method is not overridden in a subclass.
############################################
Example
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class MyModel(AbsESPnetModel):
... def __init__(self, param1, param2):
... # Model initialization logic
...
>>> class MyTask(AbsTask):
... @classmethod
... def build_model(cls, args):
... return MyModel(args.param1, args.param2)
################## NOTE This method is intended to be implemented in subclasses of AbsTask to define how the model should be constructed based on task-specific requirements.
classmethod build_model_from_file(config_file: Path | str | None = None, model_file: Path | str | None = None, device: str = 'cpu') → Tuple[AbsESPnetModel, Namespace]
Build model from the files.
This method is used for inference or fine-tuning.
- Parameters:
- config_file – The yaml file saved when training.
- model_file – The model file saved when training.
- device – Device type, “cpu”, “cuda”, or “cuda:N”.
- Returns: A tuple containing the constructed model and the arguments as a Namespace object.
- Raises:
- AssertionError – If model_file is None and config_file is also None.
############################################
Example
>>> model, args = YourTask.build_model_from_file(
... config_file='config.yaml',
... model_file='model.pth',
... device='cuda'
... )
>>> model, args = YourTask.build_model_from_file(
... model_file='model.pth'
... )
classmethod build_multiple_iter_factory(args: Namespace, distributed_option: DistributedOption, mode: str)
Build a factory for creating multiple mini-batch iterators.
This method is responsible for constructing multiple iterator factories based on the provided arguments and distributed options. It checks the directories for splits and prepares functions to build iterators for each split.
- Parameters:
- cls – The class reference.
- args (argparse.Namespace) – Command line arguments containing the necessary parameters for iterator creation.
- distributed_option (DistributedOption) – Options related to distributed training.
- mode (str) – The mode for which the iterator is being built, such as “train” or “valid”.
- Returns: An instance of MultipleIterFactory that manages multiple iterators for different splits.
- Return type:MultipleIterFactory
- Raises:
- RuntimeError – If any specified path is not a directory or if the number of splits does not match across different paths.
- FileNotFoundError – If a required split file or directory does not exist.
############################################
Example
>>> factory = MyTask.build_multiple_iter_factory(args, distributed_option, mode)
>>> for keys, batch in factory.build_iter(epoch):
...     model(**batch)
################## NOTE The function expects that the specified directories contain a “num_splits” file which indicates the number of data splits available for training or validation.
classmethod build_optimizers(args: Namespace, model: Module) → List[Optimizer]
Build optimizers for the model based on the provided arguments.
This method is responsible for creating and configuring the optimizer for the model. If the number of optimizers defined in the class is greater than one, this method must be overridden in the subclass.
- Parameters:
- args (argparse.Namespace) – Command-line arguments containing optimizer configurations.
- model (torch.nn.Module) – The model for which the optimizer will be created.
- Returns: A list of optimizers for the model.
- Return type: List[torch.optim.Optimizer]
- Raises:
- RuntimeError – If num_optimizers is not equal to 1.
- ValueError – If the specified optimizer is not in the list of supported optimizers.
- RuntimeError – If fairscale is required but not installed when using sharded DDP.
############################################
Example
>>> args = argparse.Namespace()
>>> args.optim = 'adam'
>>> args.optim_conf = {'lr': 0.001}
>>> model = torch.nn.Linear(10, 2)
>>> optimizers = AbsTask.build_optimizers(args, model)
>>> print(type(optimizers[0])) # Output: <class 'torch.optim.adam.Adam'>
################## NOTE The method expects args.optim to match one of the keys in optim_classes. The optim_conf should contain any additional parameters required by the chosen optimizer.
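If a task defines more than one optimizer, it must set num_optimizers accordingly and override this method. A hedged sketch of what that might look like; the generator/discriminator attribute names and the optim2_conf argument are purely illustrative, not part of the AbsTask API:
>>> import torch
>>> class MyGANTask(AbsTask):
...     num_optimizers = 2
...
...     @classmethod
...     def build_optimizers(cls, args, model):
...         # One optimizer per sub-module; the attribute names are hypothetical.
...         optim_g = torch.optim.Adam(model.generator.parameters(), **args.optim_conf)
...         optim_d = torch.optim.Adam(model.discriminator.parameters(), **args.optim2_conf)
...         return [optim_g, optim_d]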
abstract classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Build a preprocessing function for input data.
This method is expected to return a callable that processes the input data and returns a structured output, which is typically a dictionary containing the required tensors.
- Parameters:
- cls – The class that is invoking this method.
- args – Command line arguments parsed into a namespace.
- train – A boolean indicating whether the function is being built for training or validation/testing.
- Returns: A callable that takes a string (input key) and a dictionary of input data, returning a processed dictionary of numpy arrays, or None if no preprocessing is needed.
############################################
Example
>>> preprocess_fn = cls.build_preprocess_fn(args, train=True)
>>> processed_data = preprocess_fn("input_key", {"input_key": np.array(...)})
- Raises:NotImplementedError – If the method is not implemented in the derived class.
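A common pattern is to return a preprocessor object when preprocessing is enabled and None otherwise. A minimal sketch, assuming espnet2.train.preprocessor.CommonPreprocessor and a use_preprocessor flag registered by add_task_arguments (the flag name is an assumption, not part of AbsTask itself):
>>> import argparse
>>> from espnet2.train.preprocessor import CommonPreprocessor
>>> class MyTask(AbsTask):
...     @classmethod
...     def build_preprocess_fn(cls, args: argparse.Namespace, train: bool):
...         if getattr(args, "use_preprocessor", False):
...             return CommonPreprocessor(train=train)
...         # Returning None passes the data through without preprocessing.
...         return None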
classmethod build_sequence_iter_factory(args: Namespace, iter_options: IteratorOptions, mode: str) → AbsIterFactory
Build a sequence iterator factory for data loading.
This method constructs a sequence iterator that is responsible for generating mini-batches from the dataset for training or validation.
It checks the requirements of the task and initializes the dataset and batch sampler based on the provided arguments. It also handles the shuffling and splitting of batches for distributed training if necessary.
- Parameters:
- args (argparse.Namespace) – Command line arguments containing configurations for the iterator.
- iter_options (IteratorOptions) – Configuration options for the iterator, including preprocessing, collate functions, and batch specifications.
- mode (str) – The mode in which the iterator is used (e.g., “train” or “valid”).
- Returns: An instance of a sequence iterator factory for data loading.
- Return type:AbsIterFactory
- Raises:RuntimeError – If the batch size is less than the world size in distributed mode or if required data names are not satisfied.
############################################
Example
>>> iter_factory = cls.build_sequence_iter_factory(args, iter_options, mode)
>>> for epoch in range(max_epoch):
... for keys, batch in iter_factory.build_iter(epoch):
... model(**batch)
################## NOTE The dataset is constructed from the ESPnetDataset class, and the batch sampler is built using the specified batch type and other configurations.
classmethod build_streaming_iterator(data_path_and_name_and_type, preprocess_fn, collate_fn, key_file: str | None = None, batch_size: int = 1, dtype: str = <class 'numpy.float32'>, num_workers: int = 1, allow_variable_data_keys: bool = False, ngpu: int = 0, inference: bool = False, mode: str | None = None, multi_task_dataset: bool = False) → DataLoader
Build DataLoader using iterable dataset.
This method creates a DataLoader that allows streaming of data from an iterable dataset, which can be particularly useful for large datasets that do not fit into memory.
- Parameters:
- data_path_and_name_and_type – A list containing tuples of (file path, key name, data type) for the dataset.
- preprocess_fn – A callable function that processes the data before feeding it to the model.
- collate_fn – A callable function that collates a list of samples into a mini-batch.
- key_file – Optional path to a key file that maps keys to their respective data samples.
- batch_size – Number of samples per batch (default is 1).
- dtype – Data type of the dataset (default is np.float32).
- num_workers – Number of subprocesses to use for data loading (default is 1).
- allow_variable_data_keys – Whether to allow arbitrary data keys in the mini-batch (default is False).
- ngpu – Number of GPUs to use (default is 0 for CPU).
- inference – Flag indicating whether the DataLoader is used for inference (default is False).
- mode – Optional mode for the DataLoader (e.g., ‘train’, ‘valid’).
- multi_task_dataset – Whether the dataset is organized for multi-task learning (default is False).
- Returns: A DataLoader instance configured with the specified options.
- Return type: DataLoader
- Raises:RuntimeError – If the dataset does not meet the required data specifications.
############################################
Example
>>> data_loader = cls.build_streaming_iterator(
... data_path_and_name_and_type=[("data/train.scp", "train", "sound")],
... preprocess_fn=my_preprocess_function,
... collate_fn=my_collate_function,
... batch_size=32,
... num_workers=4
... )
################## NOTE The DataLoader will use the provided preprocess_fn to process each sample and collate_fn to collate samples into batches. If multi_task_dataset is set to True, the DataLoader will handle multi-task data formats.
classmethod build_task_iter_factory(args: Namespace, iter_options: IteratorOptions, mode: str) → AbsIterFactory
Build task specific iterator factory.
This method is intended to be overridden in subclasses to create an iterator factory specific to the task at hand. It is invoked to generate an iterator for each epoch during training or validation.
- Parameters:
- args (argparse.Namespace) – The parsed command-line arguments.
- iter_options (IteratorOptions) – Options related to data loading and batching.
- mode (str) – The mode of operation (e.g., “train”, “valid”).
- Returns: An instance of an iterator factory for the specific task.
- Return type:AbsIterFactory
Example
>>> class YourTask(AbsTask):
... @classmethod
... def add_task_arguments(cls, parser: argparse.ArgumentParser):
... parser.set_defaults(iterator_type="task")
...
... @classmethod
... def build_task_iter_factory(
... cls,
... args: argparse.Namespace,
... iter_options: IteratorOptions,
... mode: str,
... ):
... return FooIterFactory(...)
...
... @classmethod
... def build_iter_options(
... cls,
... args: argparse.Namespace,
... distributed_option: DistributedOption,
... mode: str
... ):
... # if you need to customize options object
classmethod check_required_command_args(args: Namespace)
Checks the required command-line arguments for the task.
This method verifies that all required arguments specified in the parser are provided by the user. If any required arguments are missing, it raises a runtime error and displays the help message for the parser.
- Parameters:args (argparse.Namespace) – The namespace containing the command-line arguments.
- Raises:RuntimeError – If any required arguments are missing.
############################################
Example
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument('--output_dir', required=True)
>>> args = parser.parse_args(['--output_dir', 'path/to/output'])
>>> check_required_command_args(args) # This will pass
>>> args = parser.parse_args([]) # Missing required argument
>>> check_required_command_args(args) # This will raise RuntimeError
classmethod check_task_requirements(dataset: AbsDataset | IterableESPnetDataset, allow_variable_data_keys: bool, train: bool, inference: bool = False) → None
Check if the dataset satisfies the requirement of the current Task.
This method verifies that the dataset has the required data names specified by the task and checks if variable data keys are allowed.
- Parameters:
- cls – The class method reference.
- dataset (Union[AbsDataset, IterableESPnetDataset]) – The dataset to check against the task requirements.
- allow_variable_data_keys (bool) – Flag to indicate if variable data keys are permitted.
- train (bool) – Flag to indicate if the check is for training.
- inference (bool , optional) – Flag to indicate if the check is for inference. Defaults to False.
- Raises:RuntimeError – If the dataset does not contain the required data names or if variable data keys are not allowed and the dataset contains additional keys.
################## NOTE If you intend to use an additional input, modify “{cls.__name__}.required_data_names()” or “{cls.__name__}.optional_data_names()”. Otherwise, you need to set --allow_variable_data_keys true.
############################################
Example
>>> class MyTask(AbsTask):
... @classmethod
... def required_data_names(cls, train=True, inference=False):
... return ("input", "output")
...
>>> dataset = ... # some dataset instance
>>> MyTask.check_task_requirements(dataset, False, True)
class_choices_list
classmethod exclude_opts() → Tuple[str, ...]
The options not to be shown by --print_config.
This method specifies a tuple of command-line options that should be excluded from the printed configuration output when the –print_config flag is used. This is useful for hiding options that are not relevant to the user or for internal configurations that should not be displayed.
- Returns: A tuple containing the names of options to be excluded from the configuration output.
- Return type: Tuple[str, …]
############################################
Example
>>> print(cls.exclude_opts())
('required', 'print_config', 'config', 'ngpu')
classmethod get_default_config() → Dict[str, Any]
Return the configuration as a dictionary.
This method retrieves the default configuration options used by the task, which can be helpful for printing the configuration or validating arguments. It populates the configuration dictionary based on the command line arguments parsed from the task’s argument parser, while excluding certain options.
- Returns: A dictionary containing the default configuration.
- Return type: dict
- Raises:ValueError – If an invalid optimizer or scheduler name is provided in the configuration.
############################################
Example
>>> config = YourTask.get_default_config()
>>> print(config)
{'optim': 'adam', 'learning_rate': 0.001, ...}
################## NOTE This method is used by the print_config() method to display the current configuration.
classmethod get_parser() → ArgumentParser
Returns a parser for command-line arguments.
This method constructs an argument parser for the task’s command-line interface. It includes common configuration options, distributed training options, and task-specific arguments. The resulting parser can be used to parse command-line arguments when running a training script.
- Parameters:cls – The class that is calling this method. It is expected to be a subclass of AbsTask.
- Returns: An instance of config_argparse.ArgumentParser configured with the appropriate arguments for this task.
############################################
Example
>>> parser = YourTask.get_parser()
>>> args = parser.parse_args()
>>> print(args.output_dir)
################## NOTE The parser is built to include default values and help messages for each argument. It also includes a specific formatting class for displaying argument defaults in the help message.
classmethod main(args: Namespace | None = None, cmd: Sequence[str] | None = None)
Main entry point for the AbsTask class, responsible for parsing command-line arguments and initiating the training or evaluation process.
- Parameters:
- args (Optional[argparse.Namespace]) – Parsed command line arguments. If None, a new parser is created and arguments are parsed from the command line.
- cmd (Optional[Sequence[str]]) – Command line arguments as a sequence. If None, the function uses sys.argv to parse arguments.
- Raises:
- RuntimeError – If the deprecated --pretrain_path argument is provided.
- RuntimeError – If required command arguments are missing.
- RuntimeError – If the dataset does not satisfy the requirements of the current task.
############################################
Example
>>> from espnet2.tasks.abs_task import AbsTask
>>> AbsTask.main()
################## NOTE This function also manages distributed training if specified in the command line arguments.
classmethod main_worker(args: Namespace)
Main worker function for executing the training or inference process.
This method is responsible for initializing the distributed process, setting up logging, building the model, and starting the training or inference loop. It can handle both single and multi-GPU setups and supports various configurations for training.
- Parameters:args (argparse.Namespace) – The parsed command line arguments.
- Raises:RuntimeError – If the deprecated pretrain_path option is specified, or if the built model does not inherit from AbsESPnetModel.
############################################
Example
>>> from espnet2.tasks.abs_task import AbsTask
>>> AbsTask.main()
################## NOTE
- If the print_config argument is set to True, the function will print the configuration and exit.
- The function can handle distributed training and will set up the necessary parameters accordingly.
num_optimizers
abstract classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Define the optional names by Task.
This function is used by cls.check_task_requirements().
If your model is defined as follows:
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class Model(AbsESPnetModel):
... def forward(self, input, output, opt=None): pass
then “optional_data_names” should be as follows:
>>> optional_data_names = ('opt',)
- Parameters:
- train (bool) – Indicates whether the task is in training mode.
- inference (bool) – Indicates whether the task is in inference mode.
- Returns: A tuple containing the names of optional data.
- Return type: Tuple[str, …]
############################################
Example
>>> optional_data = Model.optional_data_names(train=True)
>>> print(optional_data)
('opt',)
classmethod print_config(file=sys.stdout) → None
Print the default configuration in YAML format.
This method retrieves the default configuration as a dictionary and prints it to the specified file or standard output in YAML format. It is particularly useful for verifying the configuration settings before starting a training session.
- Parameters:file – The file-like object where the configuration will be printed. Defaults to standard output (sys.stdout).
############################################
Example
>>> AbsTask.print_config()
# Prints the default configuration to standard output.
>>> with open('config.yaml', 'w') as f:
... AbsTask.print_config(file=f)
# Writes the default configuration to 'config.yaml'.
################## NOTE This method relies on the get_default_config() method to obtain the configuration settings.
abstract classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Define the required names by Task.
This function is used by cls.check_task_requirements().
If your model is defined as follows:
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class Model(AbsESPnetModel):
... def forward(self, input, output, opt=None): pass
then “required_data_names” should be as follows:
>>> required_data_names = ('input', 'output')
- Parameters:
- train (bool) – A flag indicating if the task is for training.
- inference (bool) – A flag indicating if the task is for inference.
- Returns: A tuple containing the required data names.
- Return type: Tuple[str, …]
- Raises:NotImplementedError – If the method is not implemented in the subclass.
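############################################
Example
By analogy with the optional_data_names example above, an illustrative call, assuming a task subclass (here named MyTask, hypothetical) whose required_data_names returns ('input', 'output'):
>>> required_data = MyTask.required_data_names(train=True)
>>> print(required_data)
('input', 'output')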
trainer
alias of Trainer
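Putting the pieces together, a hedged sketch of a minimal concrete task; MyModel stands for any user-defined AbsESPnetModel subclass (not shown) and the --hidden_size option is illustrative, not part of the AbsTask API:
>>> import argparse
>>> from espnet2.train.collate_fn import CommonCollateFn
>>> from espnet2.train.trainer import Trainer
>>> class MinimalTask(AbsTask):
...     num_optimizers = 1
...     trainer = Trainer
...
...     @classmethod
...     def add_task_arguments(cls, parser: argparse.ArgumentParser):
...         parser.add_argument("--hidden_size", type=int, default=256)
...
...     @classmethod
...     def build_collate_fn(cls, args, train):
...         return CommonCollateFn()
...
...     @classmethod
...     def build_preprocess_fn(cls, args, train):
...         return None
...
...     @classmethod
...     def required_data_names(cls, train=True, inference=False):
...         return ("input", "output")
...
...     @classmethod
...     def optional_data_names(cls, train=True, inference=False):
...         return ()
...
...     @classmethod
...     def build_model(cls, args):
...         # MyModel is a user-defined AbsESPnetModel subclass (not shown here).
...         return MyModel(hidden_size=args.hidden_size)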