espnet2.tasks.st.STTask
class espnet2.tasks.st.STTask
Bases: AbsTask
STTask is a class for speech translation tasks that extends the AbsTask
class. It manages the configuration and initialization of various components needed for speech translation, including frontends, encoders, decoders, and preprocessors.
num_optimizers
The number of optimizers to use. Default is 1.
- Type: int
class_choices_list
A list of class choices for various components (frontend, encoder, decoder, etc.).
- Type: list
trainer
The Trainer class used for training and evaluation.
- Type: Trainer
- Parameters: parser (argparse.ArgumentParser) – The argument parser to add task-specific arguments.
- Returns: A function that collates training data into batches.
- Return type: Callable
- Yields: Optional[Callable] – A function that preprocesses the input data.
- Raises: RuntimeError – If the token_list is not a string or list.
################# Examples
To add task-specific arguments to a parser
>>> parser = argparse.ArgumentParser()
>>> STTask.add_task_arguments(parser)
To build a model based on provided arguments
>>> args = parser.parse_args()
>>> model = STTask.build_model(args)
########## NOTE If you need to modify the training or evaluation procedures, change the Trainer class in this task.
classmethod add_task_arguments(parser: ArgumentParser)
Adds task-related arguments to the argument parser for the STTask.
This method defines command-line arguments that are specific to the speech translation task, including options for token lists, model initialization, preprocessing, and more. The arguments added by this method allow users to customize their training and evaluation runs.
- Parameters: parser (argparse.ArgumentParser) – The argument parser to which the task-related arguments will be added.
################# Examples
To use this method, you can set up an argument parser as follows:
```python
import argparse

from espnet2.tasks.st import STTask

parser = argparse.ArgumentParser(description="Speech Translation Task")
STTask.add_task_arguments(parser)
args = parser.parse_args()
```
########## NOTE This method does not enforce required arguments when the --print_config mode is used. Instead, it appends the names of the required arguments to a default list.
- Raises: Any exceptions raised during argument parsing are propagated.
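The note about --print_config can be illustrated with plain argparse. The following is a minimal sketch, not the ESPnet implementation: the `required` default list, the `missing_required` helper, and the `--output_dir` option are assumptions introduced for illustration. It shows how a task can append names to a shared list instead of passing `required=True`, so that a config-printing run still parses.

```python
import argparse


def add_task_arguments(parser: argparse.ArgumentParser) -> None:
    """Register a task option without marking it required=True."""
    group = parser.add_argument_group(description="Task related")
    # Extend the shared default list in place instead of using required=True,
    # so a --print_config run can parse successfully without the option.
    required = parser.get_default("required")
    required += ["token_list"]
    group.add_argument("--token_list", default=None, help="Path to a token list")


def missing_required(args: argparse.Namespace) -> list:
    """Return the names of required options that were left unset (hypothetical helper)."""
    if args.print_config:
        return []  # nothing is enforced when only printing the config
    return [name for name in args.required if getattr(args, name) is None]


parser = argparse.ArgumentParser()
parser.set_defaults(required=["output_dir"])  # assumed shared default list
parser.add_argument("--output_dir", default=None)
parser.add_argument("--print_config", action="store_true")
add_task_arguments(parser)

print(missing_required(parser.parse_args(["--print_config"])))
print(missing_required(parser.parse_args(["--output_dir", "exp"])))
```

In this sketch the first call reports no missing options, while the second reports that `token_list` was not supplied.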
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Builds a collate function for processing batches of data during training or
evaluation.
This method constructs a collate function that can handle the data structure provided by the data loader. The collate function pads sequences to the same length and prepares them for model input. It is particularly useful when dealing with variable-length sequences.
- Parameters:
- args (argparse.Namespace) – The command-line arguments containing configurations for the task.
- train (bool) – A flag indicating whether the function is being used for training (True) or evaluation (False).
- Returns: Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]]]: A callable collate function that takes a collection of data and returns a tuple containing a list of identifiers and a dictionary of tensors ready for input into the model.
################# Examples
>>> import argparse
>>> import numpy as np
>>> from espnet2.tasks.st import STTask
>>> args = argparse.Namespace()
>>> args.some_config = "value"
>>> collate_fn = STTask.build_collate_fn(args, train=True)
>>> batch_data = [("id1", {"feature": np.array([1, 2])}),
...               ("id2", {"feature": np.array([3, 4, 5])})]
>>> identifiers, tensor_data = collate_fn(batch_data)
########## NOTE The integer value 0 is reserved for the CTC-blank symbol.
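To make the padding behavior concrete, here is a pure-Python sketch of a collate function. It is a stand-in under stated assumptions, not the collate function that `build_collate_fn` returns: it works on lists rather than tensors, the name `pad_collate` is hypothetical, and the `-1` padding value and the `<key>_lengths` entries are assumptions. A non-zero padding value is chosen so that padding cannot be mistaken for 0, which the note above reserves for the CTC-blank symbol.

```python
from typing import Dict, List, Sequence, Tuple


def pad_collate(
    batch: Sequence[Tuple[str, Dict[str, List[int]]]],
    pad_value: int = -1,
) -> Tuple[List[str], Dict[str, list]]:
    """Pad every field to the longest sequence in the batch."""
    ids = [uid for uid, _ in batch]
    out: Dict[str, list] = {}
    for key in batch[0][1]:
        seqs = [sample[key] for _, sample in batch]
        max_len = max(len(s) for s in seqs)
        # Right-pad each sequence so all rows share the same length.
        out[key] = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
        # Keep the original lengths so padding can be masked out later.
        out[key + "_lengths"] = [len(s) for s in seqs]
    return ids, out


batch_data = [("id1", {"feature": [1, 2]}), ("id2", {"feature": [3, 4, 5]})]
identifiers, padded = pad_collate(batch_data)
```

Keeping the lengths next to the padded data lets downstream code ignore the padded positions.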
classmethod build_model(args: Namespace) → ESPnetSTModel
Build the speech translation model based on the provided arguments.
This method constructs the model by assembling various components, including frontends, encoders, decoders, and normalization layers, as specified in the arguments. It reads token lists from files if provided as strings and logs the vocabulary sizes. The model configuration is flexible, allowing for the inclusion of optional components such as extra decoders and encoders.
- Parameters: args (argparse.Namespace) – The command-line arguments containing model configuration options.
- Returns: An instance of the speech translation model.
- Return type: ESPnetSTModel
- Raises: RuntimeError – If the token list is not provided as a string or list.
################# Examples
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> STTask.add_task_arguments(parser)
>>> args = parser.parse_args(["--token_list", "tokens.txt"])
>>> model = STTask.build_model(args)
########## NOTE The method assumes that the input arguments are properly configured and validated. It is recommended to use the provided argument parser to ensure the correct setup of the model.
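The token-list handling described above (read from a file when given a string, accept a list directly, raise `RuntimeError` otherwise) can be sketched as follows. The helper name `load_token_list` and the one-token-per-line file layout are assumptions for illustration and are not taken from the ESPnet source.

```python
import os
import tempfile
from typing import List, Sequence, Union


def load_token_list(token_list: Union[str, Sequence[str]]) -> List[str]:
    """Return tokens from a file path (str) or an in-memory list/tuple."""
    if isinstance(token_list, str):
        # A string is treated as a path to a file with one token per line.
        with open(token_list, encoding="utf-8") as f:
            return [line.rstrip("\n") for line in f]
    if isinstance(token_list, (list, tuple)):
        return list(token_list)
    raise RuntimeError("token_list must be str or list")


# Demo: write a tiny vocabulary to a temporary file, then load it back.
with tempfile.NamedTemporaryFile(
    "w", suffix=".txt", delete=False, encoding="utf-8"
) as tmp:
    tmp.write("<blank>\n<unk>\nhello\n")
tokens_from_file = load_token_list(tmp.name)
os.remove(tmp.name)

tokens_from_list = load_token_list(["<blank>", "<unk>", "hello"])
vocab_size = len(tokens_from_file)
```

The string check comes first because a `str` is itself a sequence. The vocabulary size is then just the length of the resulting list.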
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Builds a preprocessing function for the input data.
This method constructs a preprocessing function that will be used to preprocess the input data based on the provided arguments. If the use_preprocessor argument is set to True, it initializes the MutliTokenizerCommonPreprocessor with the relevant settings.
- Parameters:
- args (argparse.Namespace) – The arguments namespace containing configuration options for preprocessing.
- train (bool) – A flag indicating whether the function is being built for training or evaluation.
- Returns: A callable preprocessing function that processes the input data or None if preprocessing is not to be used.
- Return type: Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]
- Raises:
- AttributeError – If the preprocessor attribute is not found in the args.
- Exception – Any other exceptions that may arise during the preprocessing function creation.
################# Examples
>>> import argparse
>>> import numpy as np
>>> from espnet2.tasks.st import STTask
>>> args = argparse.Namespace()
>>> args.use_preprocessor = True
>>> args.token_type = "bpe"
>>> preprocess_fn = STTask.build_preprocess_fn(args, train=True)
>>> result = preprocess_fn("input_text", {"key": np.array([1, 2, 3])})
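The optional return value is the part of this method that callers most often need to handle. The sketch below mirrors only that contract: a builder that returns `None` when `use_preprocessor` is false and a callable otherwise. The whitespace tokenizer is a hypothetical stand-in for the real preprocessor, and this free-standing `build_preprocess_fn` is not the ESPnet method.

```python
import argparse
from typing import Callable, Dict, List, Optional


def build_preprocess_fn(
    args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, str]], Dict[str, List[str]]]]:
    """Return a preprocessing callable, or None when preprocessing is disabled."""
    if not args.use_preprocessor:
        return None

    def preprocess(uid: str, data: Dict[str, str]) -> Dict[str, List[str]]:
        # Hypothetical whitespace tokenizer standing in for the real preprocessor.
        # `train` is unused here; it is kept only to match the documented signature.
        return {key: text.split() for key, text in data.items()}

    return preprocess


enabled = build_preprocess_fn(argparse.Namespace(use_preprocessor=True), train=False)
tokens = enabled("utt1", {"text": "hello world"})

disabled = build_preprocess_fn(argparse.Namespace(use_preprocessor=False), train=False)
```

Callers should check for `None` before invoking the result, since a disabled preprocessor yields no callable at all.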
class_choices_list
num_optimizers
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
A class representing the Speech Translation Task, inheriting from AbsTask.
This class is responsible for setting up the speech translation task, including the configuration of models, preprocessing, and collate functions. It provides methods to handle command-line arguments specific to the task, build the necessary models, and manage data input/output.
num_optimizers
Number of optimizers to be used in training.
- Type: int
class_choices_list
List of class choices for various components such as frontend, encoder, decoder, etc.
- Type: list
trainer
The trainer class used for training the model.
- Type: Trainer
- Parameters: parser (argparse.ArgumentParser) – Argument parser to add task-related command-line arguments.
- Returns: None
- Yields: None
- Raises: RuntimeError – If token_list is not a string or list during model building.
################# Examples
To add task arguments to an ArgumentParser
>>> parser = argparse.ArgumentParser()
>>> STTask.add_task_arguments(parser)
To build a collate function for data processing
collate_fn = STTask.build_collate_fn(args, train=True)
To build a preprocess function for input data
preprocess_fn = STTask.build_preprocess_fn(args, train=True)
To get the required data names for training
required_names = STTask.required_data_names(train=True)
To get the optional data names for inference
optional_names = STTask.optional_data_names(inference=True)
To build the model using provided arguments
model = STTask.build_model(args)
########## NOTE The add_task_arguments method allows you to configure various parameters related to the task, including token lists, model configurations, and preprocessing options.
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Get the required data names for the STTask.
This method returns a tuple of required data names based on the task mode (training or inference). The required data typically includes the input speech and corresponding text data for training, while only the input speech is required for inference.
- Parameters:
- train (bool) – Indicates whether the task is in training mode. Defaults to True.
- inference (bool) – Indicates whether the task is in inference mode. Defaults to False.
- Returns: A tuple containing the required data names. For training, it returns ("speech", "text"); for inference, it returns ("speech",).
- Return type: Tuple[str, ...]
################# Examples
>>> STTask.required_data_names(train=True, inference=False)
('speech', 'text')
>>> STTask.required_data_names(train=False, inference=True)
('speech',)
########## NOTE This method is a class method and should be called on the STTask class itself.
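One way to use these names is to validate a data sample before it reaches the model. The sketch below reimplements the documented return values as a free function and adds a hypothetical `missing_names` helper. Branching on `inference` alone is an assumption consistent with the two examples above, not a statement about the ESPnet source.

```python
from typing import Dict, Tuple


def required_data_names(train: bool = True, inference: bool = False) -> Tuple[str, ...]:
    """Mirror the documented behavior: text is only needed outside inference."""
    if inference:
        return ("speech",)
    return ("speech", "text")


def missing_names(sample: Dict[str, object], inference: bool = False) -> Tuple[str, ...]:
    """Return required names that are absent from a sample (hypothetical helper)."""
    names = required_data_names(train=not inference, inference=inference)
    return tuple(name for name in names if name not in sample)


sample = {"speech": [0.1, 0.2, 0.3]}
missing_for_training = missing_names(sample, inference=False)
missing_for_inference = missing_names(sample, inference=True)
```

A sample that carries only `speech` is complete for inference but lacks `text` for training.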
trainer
alias of Trainer