espnet2.tasks.mt.MTTask
class espnet2.tasks.mt.MTTask
Bases: AbsTask
MTTask handles the configuration and execution of machine translation tasks. It inherits from AbsTask and provides methods for setting up task-specific arguments, building models, and processing data for training and inference.
num_optimizers
Number of optimizers to be used in the task. Default is 1.
- Type: int
class_choices_list
A list of class choices for various components (frontend, specaug, preencoder, encoder, postencoder, decoder, model).
- Type: list
trainer
The Trainer class used for modifying train or eval procedures.
- Type: Trainer
Parameters: parser (argparse.ArgumentParser) – Argument parser to add task-related arguments.
Returns: A function to collate input data for training.
Return type: Callable
Yields: Optional[Callable] – A function to preprocess input data based on the provided arguments.
Raises: RuntimeError – If token_list or src_token_list is not a string or list.
################# Examples
To add task arguments:
>>> import argparse
>>> from espnet2.tasks.mt import MTTask
>>> parser = argparse.ArgumentParser()
>>> MTTask.add_task_arguments(parser)
To build a model:
>>> args = parser.parse_args()
>>> model = MTTask.build_model(args)
To build a collate function:
>>> collate_fn = MTTask.build_collate_fn(args, train=True)
######### NOTE Provide the correct token lists and model configuration; successful model building depends on them.
classmethod add_task_arguments(parser: ArgumentParser)
Adds task-specific arguments to the provided argument parser.
This method defines command-line arguments that are related to the task. It creates groups for task-related and preprocessing-related arguments, and it also incorporates choices for various components of the model architecture, such as frontend, encoder, decoder, and more.
- Parameters: parser (argparse.ArgumentParser) – The argument parser to which task arguments will be added.
################# Examples
>>> import argparse
>>> from espnet2.tasks.mt import MTTask
>>> parser = argparse.ArgumentParser()
>>> MTTask.add_task_arguments(parser)
######### NOTE The arguments are organized into documented groups. Some arguments are required, and their presence can be checked at runtime.
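The grouping described above can be sketched with the standard library alone. This is a minimal illustration, not the ESPnet implementation: `add_demo_task_arguments` and `str2bool` are hypothetical names, and the option set, types, and defaults are assumptions chosen to mirror the `token_list`, `src_token_list`, and `use_preprocessor` arguments mentioned on this page.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical helper: parse 'true'/'false' style strings."""
    return value.lower() in ("true", "1", "yes")


def add_demo_task_arguments(parser: argparse.ArgumentParser) -> None:
    """Hypothetical stand-in for a task's add_task_arguments method."""
    # Group for task-related options (assumed subset).
    task_group = parser.add_argument_group(title="Task related")
    task_group.add_argument(
        "--token_list", type=str, default=None,
        help="Target-side token list (illustrative)",
    )
    task_group.add_argument(
        "--src_token_list", type=str, default=None,
        help="Source-side token list (illustrative)",
    )
    # Group for preprocessing-related options (assumed subset).
    prep_group = parser.add_argument_group(title="Preprocess related")
    prep_group.add_argument(
        "--use_preprocessor", type=str2bool, default=True,
        help="Whether to apply preprocessing (illustrative)",
    )


parser = argparse.ArgumentParser(description="demo")
add_demo_task_arguments(parser)
args = parser.parse_args(
    ["--token_list", "tokens.txt", "--use_preprocessor", "false"]
)
print(args.token_list, args.src_token_list, args.use_preprocessor)
# prints: tokens.txt None False
```

Argument groups only affect how `--help` output is organized; the parsed values still land in a single flat namespace.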
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Build a collate function for the task.
This method constructs a callable function that can be used to collate batches of data. It is specifically designed for use during training and evaluation of the MTTask model.
- Parameters:
- args (argparse.Namespace) – The command-line arguments containing the configuration for the task.
- train (bool) – A flag indicating whether the function is being built for training or evaluation.
- Returns: A callable that takes a collection of tuples, each containing a string identifier and a dictionary of numpy arrays, and returns a tuple containing a list of strings and a dictionary of PyTorch tensors.
- Return type: Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]]]
################# Examples
>>> import argparse
>>> import numpy as np
>>> from espnet2.tasks.mt import MTTask
>>> args = argparse.Namespace(...)
>>> collate_fn = MTTask.build_collate_fn(args, train=True)
>>> batch = [
... ("example_1", {"input": np.array([1, 2, 3])}),
... ("example_2", {"input": np.array([4, 5, 6])}),
... ]
>>> collated_data = collate_fn(batch)
>>> print(collated_data)
######### NOTE The integer value 0 is reserved for the CTC-blank symbol.
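To make the idea of collating concrete, here is a simplified sketch in pure Python. It is not the collate function ESPnet returns: `demo_collate` is a hypothetical name, plain lists stand in for numpy arrays and PyTorch tensors, the `_lengths` keys are an assumption of this sketch, and the padding value -1 is an assumed choice that avoids 0 because the note above reserves 0 for the CTC-blank symbol.

```python
from typing import Dict, List, Sequence, Tuple

# Assumption for this sketch: pad with a value other than 0,
# since 0 is reserved for the CTC-blank symbol.
INT_PAD_VALUE = -1


def demo_collate(
    batch: Sequence[Tuple[str, Dict[str, List[int]]]],
    pad_value: int = INT_PAD_VALUE,
) -> Tuple[List[str], Dict[str, list]]:
    """Hypothetical collate: pad each field to the longest sequence."""
    ids = [uid for uid, _ in batch]
    output: Dict[str, list] = {}
    for key in batch[0][1]:
        seqs = [sample[key] for _, sample in batch]
        max_len = max(len(seq) for seq in seqs)
        # Right-pad every sequence to the same length.
        output[key] = [
            list(seq) + [pad_value] * (max_len - len(seq)) for seq in seqs
        ]
        # Keep the original lengths so padding can be masked later.
        output[key + "_lengths"] = [len(seq) for seq in seqs]
    return ids, output


batch = [
    ("example_1", {"input": [1, 2, 3]}),
    ("example_2", {"input": [4, 5]}),
]
ids, data = demo_collate(batch)
print(ids)                    # ['example_1', 'example_2']
print(data["input"])          # [[1, 2, 3], [4, 5, -1]]
print(data["input_lengths"])  # [3, 2]
```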
classmethod build_model(args: Namespace) → ESPnetMTModel
Builds and initializes the model for the MTTask.
This method constructs the model based on the provided arguments, configures various components such as the frontend, encoder, decoder, and CTC, and initializes the model parameters.
- Parameters: args (argparse.Namespace) – The namespace containing all the arguments required to build the model. This includes paths to token lists, model configuration, and initialization settings.
- Returns: An instance of the ESPnetMTModel configured with the specified components.
- Return type: ESPnetMTModel
- Raises: RuntimeError – If token_list or src_token_list is not a string or list.
################# Examples
Example of creating a model with specified arguments:
>>> import argparse
>>> from espnet2.tasks.mt import MTTask
>>> parser = argparse.ArgumentParser()
>>> MTTask.add_task_arguments(parser)
>>> args = parser.parse_args()
>>> model = MTTask.build_model(args)
######### NOTE Ensure that the token lists are correctly specified and accessible before invoking this method.
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Builds a preprocessing function for the task.
This function constructs a callable that can be used to preprocess input data based on the specified arguments. If preprocessing is enabled, it returns a MutliTokenizerCommonPreprocessor instance configured with the provided arguments. Otherwise, it returns None.
- Parameters:
- args (argparse.Namespace) – The command-line arguments containing configuration options for preprocessing.
- train (bool) – A flag indicating whether the function is being built for training or evaluation.
- Returns: A preprocessing function if args.use_preprocessor is True, otherwise None.
- Return type: Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]
################# Examples
>>> from argparse import Namespace
>>> import numpy as np
>>> from espnet2.tasks.mt import MTTask
>>> args = Namespace(use_preprocessor=True, token_type='bpe',
... src_token_type='bpe', token_list='path/to/token_list.txt',
... src_token_list='path/to/src_token_list.txt',
... bpemodel='path/to/bpemodel',
... src_bpemodel='path/to/src_bpemodel',
... non_linguistic_symbols='path/to/symbols',
... cleaner='tacotron', g2p='g2p_method',
... tokenizer_encode_conf={},
... src_tokenizer_encode_conf={})
>>> preprocess_fn = MTTask.build_preprocess_fn(args, train=True)
>>> preprocessed_data = preprocess_fn("some text", {"key": np.array([1, 2, 3])})
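Because the builder may return None, callers need to handle both outcomes. The following sketch shows that pattern under stated assumptions: `build_demo_preprocess_fn` is a hypothetical stand-in, and whitespace splitting replaces the real tokenizer, cleaner, and token-ID lookup that the actual preprocessor would apply.

```python
from argparse import Namespace
from typing import Callable, Dict, List, Optional

PreprocessFn = Callable[[str, Dict[str, str]], Dict[str, List[str]]]


def build_demo_preprocess_fn(
    args: Namespace, train: bool
) -> Optional[PreprocessFn]:
    """Hypothetical builder: return a callable, or None when disabled."""
    if not args.use_preprocessor:
        return None

    def preprocess(uid: str, data: Dict[str, str]) -> Dict[str, List[str]]:
        # Stand-in for real tokenization: split each field on whitespace.
        return {name: text.split() for name, text in data.items()}

    return preprocess


sample = {"src_text": "hello world", "text": "hallo welt"}

enabled = build_demo_preprocess_fn(Namespace(use_preprocessor=True), train=True)
disabled = build_demo_preprocess_fn(Namespace(use_preprocessor=False), train=True)

# Apply the function only when one was built.
result = enabled("utt1", sample) if enabled is not None else sample
print(result)    # {'src_text': ['hello', 'world'], 'text': ['hallo', 'welt']}
print(disabled)  # None
```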
class_choices_list
num_optimizers
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
MTTask is a class that defines the task for Machine Translation (MT) within the ESPnet framework. It manages the configuration and setup of various components needed for the MT task, including the frontend, encoder, decoder, and model.
num_optimizers
The number of optimizers used for training. Default is 1.
- Type: int
class_choices_list
A list of class choices for various components (frontend, specaug, preencoder, encoder, postencoder, decoder, model).
- Type: List[ClassChoices]
trainer
The class responsible for training procedures.
- Type: Trainer
Parameters: parser (argparse.ArgumentParser) – The argument parser for command-line arguments.
Returns: A callable function for collating training data.
Return type: Callable
Yields: None
Raises: RuntimeError – If token_list or src_token_list is not a string or list.
################# Examples
Adding task arguments to the parser:
>>> MTTask.add_task_arguments(parser)
Building the collate function for data processing:
>>> collate_fn = MTTask.build_collate_fn(args, train=True)
Building the preprocessing function:
>>> preprocess_fn = MTTask.build_preprocess_fn(args, train=True)
Getting required data names for training:
>>> required_names = MTTask.required_data_names(train=True)
Getting optional data names for inference:
>>> optional_names = MTTask.optional_data_names(inference=True)
Building the model based on the configuration:
>>> model = MTTask.build_model(args)
######### NOTE This class is a subclass of AbsTask and provides additional functionality specific to machine translation tasks.
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Returns the required data names for training or inference.
This method determines the necessary data inputs for the task based on whether the operation is in training or inference mode. In training mode, both source and target texts are required, while in inference mode, only the source text is needed.
- Parameters:
- train (bool) – A flag indicating if the task is in training mode. Defaults to True.
- inference (bool) – A flag indicating if the task is in inference mode. Defaults to False.
- Returns: A tuple containing the names of the required data: ("src_text", "text") in training mode and ("src_text",) in inference mode.
- Return type: Tuple[str, ...]
################# Examples
>>> MTTask.required_data_names(train=True, inference=False)
('src_text', 'text')
>>> MTTask.required_data_names(train=False, inference=True)
('src_text',)
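The documented behavior can be mirrored in a few lines, which is handy for checking a dataset before a run. This is a sketch, not the library source: `demo_required_data_names` and `missing_names` are hypothetical names, and treating the `inference` flag alone as the switch is an assumption (the page does not say what happens when both flags are False).

```python
from typing import Iterable, Tuple


def demo_required_data_names(
    train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    """Hypothetical mirror of the documented return values."""
    if inference:
        # Inference only needs the source sentence.
        return ("src_text",)
    # Training needs both the source and the target sentence.
    return ("src_text", "text")


def missing_names(available: Iterable[str], inference: bool) -> Tuple[str, ...]:
    """Hypothetical helper: list required names absent from a dataset."""
    present = set(available)
    required = demo_required_data_names(train=not inference, inference=inference)
    return tuple(name for name in required if name not in present)


print(demo_required_data_names(train=True, inference=False))
# ('src_text', 'text')
print(missing_names(["src_text"], inference=False))
# ('text',)
print(missing_names(["src_text"], inference=True))
# ()
```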
trainer
alias of Trainer