espnet2.tasks.s2t.S2TTask
class espnet2.tasks.s2t.S2TTask
Bases: AbsTask
S2TTask defines the speech-to-text (S2T) task for training and evaluating models in the ESPnet framework. It inherits from AbsTask and provides methods for managing task-specific configuration, data processing, and model building.
num_optimizers
The number of optimizers used in training. Defaults to 1.
- Type: int
class_choices_list
A list of class choices for components such as the frontend, encoder, and decoder.
- Type: List[ClassChoices]
trainer
The trainer class used for training and evaluation.
- Type: Trainer
add_task_arguments(parser: argparse.ArgumentParser): Adds task-related arguments to the provided argument parser.
build_collate_fn(args: argparse.Namespace, train: bool) -> Callable: Builds a collate function for batching data.
build_preprocess_fn(args: argparse.Namespace, train: bool) -> Optional[Callable]: Builds a preprocessing function for input data.
required_data_names(train: bool = True, inference: bool = False) -> Tuple[str, ...]: Returns the names of the required data for the task.
optional_data_names(train: bool = True, inference: bool = False) -> Tuple[str, ...]: Returns the names of the optional data for the task.
build_model(args: argparse.Namespace) -> ESPnetS2TModel: Constructs and returns the ESPnet S2T model based on the given arguments.
Examples
Example of adding task arguments:
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> S2TTask.add_task_arguments(parser)
Example of building a model:
>>> args = parser.parse_args()
>>> model = S2TTask.build_model(args)
NOTE: This class is intended to be used as part of the ESPnet framework for speech-to-text tasks.
classmethod add_task_arguments(parser: ArgumentParser)
Adds task-related arguments to the provided argument parser.
This method is responsible for defining and adding various command-line arguments that are specific to the S2T task. These arguments can include options for preprocessing, model configuration, and other task-specific parameters.
- Parameters: parser (argparse.ArgumentParser) – The argument parser to which the task-related arguments will be added.
Examples
To add task arguments to a parser, you can use:
```python
import argparse

from espnet2.tasks.s2t import S2TTask

parser = argparse.ArgumentParser()
S2TTask.add_task_arguments(parser)
args = parser.parse_args()
```
NOTE: This method modifies the parser in place, adding the arguments needed to configure the S2T task, such as --token_list and --input_size. Default values for some arguments are taken from class-level attributes.
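As a rough illustration of how a method like this extends a parser, the sketch below registers the two options named in the note on a plain `argparse` parser. `add_task_arguments_sketch` is a hypothetical stand-in, not ESPnet's implementation; the argument types, defaults, and help strings are assumptions.

```python
import argparse


def add_task_arguments_sketch(parser: argparse.ArgumentParser) -> None:
    """Hypothetical stand-in: register a few task options on ``parser``."""
    group = parser.add_argument_group(description="Task related")
    # Option names come from the note above; types and defaults are assumed.
    group.add_argument(
        "--token_list", type=str, default=None, help="Path to a token list file"
    )
    group.add_argument(
        "--input_size", type=int, default=None, help="Number of input dimensions"
    )


parser = argparse.ArgumentParser()
add_task_arguments_sketch(parser)  # mutates the parser in place
args = parser.parse_args(["--token_list", "tokens.txt", "--input_size", "80"])
print(args.token_list, args.input_size)
```

The function returns nothing because the parser itself is modified, which matches the in-place behavior the note describes.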
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Build a collate function for batching input data.
This function creates a collate function that will be used to prepare batches of input data for training or evaluation. It ensures that the input sequences are padded correctly to maintain consistent input shapes.
- Parameters:
- args (argparse.Namespace) – The parsed command-line arguments.
- train (bool) – A flag indicating whether the function is being used for training or evaluation.
- Returns: A collate function that processes batches of input data.
- Return type: Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]]]
NOTE: The integer value 0 is reserved for the CTC blank symbol.
Examples
>>> collate_fn = S2TTask.build_collate_fn(args, train=True)
>>> batch = collate_fn(data)
>>> print(batch) # Output will depend on the input data structure
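To make the padding behavior concrete, here is a minimal pure-Python sketch of a collate function. It is not ESPnet's collate implementation: `pad_collate_sketch`, the `_lengths` keys, and the pad values (0.0 for floats, -1 for integers) are assumptions chosen for illustration, and plain lists stand in for arrays and tensors. The one constraint taken from the note above is that integer sequences are not padded with 0, since that value is reserved for the CTC blank symbol.

```python
from typing import Any, Dict, List, Sequence, Tuple

Sample = Tuple[str, Dict[str, List[Any]]]


def pad_collate_sketch(
    data: Sequence[Sample], float_pad: float = 0.0, int_pad: int = -1
) -> Tuple[List[str], Dict[str, List[Any]]]:
    """Hypothetical collate function: pad each feature to the batch maximum."""
    utt_ids = [utt_id for utt_id, _ in data]
    batch: Dict[str, List[Any]] = {}
    for key in data[0][1]:
        seqs = [features[key] for _, features in data]
        max_len = max(len(seq) for seq in seqs)
        # Integer sequences (token IDs) get int_pad, which must not be 0
        # because 0 is reserved for the CTC blank symbol.
        is_int = all(isinstance(x, int) for seq in seqs for x in seq)
        pad = int_pad if is_int else float_pad
        batch[key] = [list(seq) + [pad] * (max_len - len(seq)) for seq in seqs]
        # Keep the original lengths so padded positions can be masked later.
        batch[key + "_lengths"] = [len(seq) for seq in seqs]
    return utt_ids, batch


data = [
    ("utt1", {"speech": [0.1, 0.2, 0.3], "text": [5, 7]}),
    ("utt2", {"speech": [0.4], "text": [9, 2, 4]}),
]
utt_ids, batch = pad_collate_sketch(data)
print(utt_ids)
print(batch["text"], batch["text_lengths"])
```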
classmethod build_model(args: Namespace) → ESPnetS2TModel
Builds the S2T model based on the provided arguments.
This method constructs an instance of the ESPnetS2TModel by assembling various components such as frontend, encoder, decoder, and CTC based on the configuration specified in the args argument. It also handles token list loading and initialization of model parameters.
- Parameters: args (argparse.Namespace) – The arguments containing configuration parameters for model construction. This includes options for the frontend, encoder, decoder, and other components.
- Returns: An instance of the constructed S2T model.
- Return type: ESPnetS2TModel
- Raises: RuntimeError – If token_list is neither a string nor a list.
Examples
Example of building a model with specific arguments:
>>> import argparse
>>> args = argparse.Namespace(
...     token_list='path/to/token_list.txt',
...     input_size=None,
...     frontend='default',
...     encoder='transformer',
...     decoder='lightweight_conv',
...     ctc_conf={'some_param': 'value'},
...     model_conf={},
... )
>>> model = S2TTask.build_model(args)
NOTE: The method expects token_list to be either a path to a text file or a list of tokens, and raises a RuntimeError if it is in neither format.
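The token_list handling described above can be sketched as a small standalone helper. `load_token_list_sketch` is a hypothetical name, and the one-token-per-line file layout is an assumption; only the accepted types and the RuntimeError come from the description.

```python
import os
import tempfile
from typing import List, Sequence, Union


def load_token_list_sketch(token_list: Union[str, Sequence[str]]) -> List[str]:
    """Hypothetical helper: accept a file path or an in-memory token list."""
    if isinstance(token_list, str):
        # A string is treated as a path; one token per line is assumed.
        with open(token_list, encoding="utf-8") as f:
            return [line.rstrip("\n") for line in f]
    if isinstance(token_list, (list, tuple)):
        return list(token_list)
    raise RuntimeError("token_list must be str or list")


# Write a small token file to a temporary directory and load it back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tokens.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("<blank>\n<unk>\na\nb\n")
    tokens_from_file = load_token_list_sketch(path)

tokens_from_list = load_token_list_sketch(["<blank>", "<unk>", "a", "b"])
vocab_size = len(tokens_from_file)
print(vocab_size, tokens_from_file == tokens_from_list)
```

Both input forms yield the same list, so later steps that need the vocabulary size do not have to care where the tokens came from.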
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Builds a preprocessing function based on the provided arguments.
This method checks if preprocessing should be applied based on the use_preprocessor argument. If true, it initializes a preprocessor class using the specified configuration and returns a callable function that processes the input data.
- Parameters:
- cls – The class that this method belongs to.
- args (argparse.Namespace) – The arguments parsed from the command line.
- train (bool) – A flag indicating whether the function is being built for training or evaluation.
- Returns: A function that takes an input string and a dictionary of features, returning a processed dictionary of features, or None if preprocessing is not enabled.
- Return type: Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]
- Raises:
- AttributeError – If the specified preprocessor class is not found.
- Exception – If any other exception occurs during initialization of the preprocessor.
Examples
To use this function in a training scenario, you might do the following:
>>> preprocess_fn = S2TTask.build_preprocess_fn(args, train=True)
>>> processed_data = preprocess_fn(input_string, features_dict)
NOTE: The method expects the args object to have attributes for the preprocessing options, including token_type, token_list, and the various noise and RIR parameters.
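The return-a-callable-or-None pattern described for this method can be sketched without ESPnet as follows. `build_preprocess_fn_sketch` and its toy character-level tokenization are hypothetical; the real preprocessor class and its options are not reproduced here, and only the use_preprocessor switch and the token_list attribute are taken from the description.

```python
import argparse
from typing import Any, Callable, Dict, Optional

Features = Dict[str, Any]


def build_preprocess_fn_sketch(
    args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Features], Features]]:
    """Hypothetical stand-in: return a preprocessing callable, or None."""
    if not args.use_preprocessor:
        return None
    token_to_id = {tok: i for i, tok in enumerate(args.token_list)}
    # `train` is unused in this toy version; a real preprocessor might apply
    # augmentation such as noise or RIR only when train is True (assumption).

    def preprocess(uid: str, data: Features) -> Features:
        out = dict(data)  # do not mutate the caller's dictionary
        if isinstance(out.get("text"), str):
            # Toy character-level tokenization; unknown characters map to <unk>.
            unk = token_to_id["<unk>"]
            out["text"] = [token_to_id.get(ch, unk) for ch in out["text"]]
        return out

    return preprocess


args = argparse.Namespace(
    use_preprocessor=True, token_list=["<blank>", "<unk>", "a", "b"]
)
preprocess_fn = build_preprocess_fn_sketch(args, train=True)
processed = preprocess_fn("utt1", {"speech": [0.1, 0.2], "text": "abz"})
print(processed["text"])
```

Because the builder may return None, callers should check the result before invoking it.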
class_choices_list
num_optimizers
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Returns the names of the optional data for the task.
- Parameters:
- train (bool) – A flag indicating whether the method is called during training. Defaults to True.
- inference (bool) – A flag indicating whether the method is called during inference. Defaults to False.
- Returns: A tuple containing the names of the optional data.
- Return type: Tuple[str, ...]
Examples
To add task arguments:
>>> parser = argparse.ArgumentParser()
>>> S2TTask.add_task_arguments(parser)
To build a model:
>>> args = parser.parse_args()
>>> model = S2TTask.build_model(args)
To retrieve required data names:
>>> data_names = S2TTask.required_data_names(train=True)
To retrieve optional data names:
>>> optional_data_names = S2TTask.optional_data_names(train=True)
NOTE: The optional data names are logged for reference.
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Returns the required data names for training or inference.
This method determines the data names needed for the task based on whether it is in training or inference mode. In training mode, both 'speech' and 'text' data are required. In inference mode, only 'speech' data is required.
- Parameters:
- train (bool) – A flag indicating whether the method is called during training. Defaults to True.
- inference (bool) – A flag indicating whether the method is called during inference. Defaults to False.
- Returns: A tuple containing the names of the required data: ('speech', 'text') for training and ('speech',) for inference.
- Return type: Tuple[str, ...]
Examples
>>> S2TTask.required_data_names(train=True, inference=False)
('speech', 'text')
>>> S2TTask.required_data_names(train=False, inference=True)
('speech',)
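The behavior described above can be mirrored in a few lines and used to validate a sample before it reaches the model. Both functions below are hypothetical sketches rather than ESPnet code, and they assume the `inference` flag alone decides the result, since the description does not cover other flag combinations.

```python
from typing import Dict, Tuple


def required_data_names_sketch(
    train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    """Hypothetical re-implementation of the behavior described above."""
    if inference:
        return ("speech",)
    return ("speech", "text")


def missing_keys(sample: Dict[str, object], inference: bool = False) -> Tuple[str, ...]:
    """Return the required names that are absent from ``sample``."""
    names = required_data_names_sketch(train=not inference, inference=inference)
    return tuple(name for name in names if name not in sample)


sample = {"speech": [0.1, 0.2, 0.3]}
print(missing_keys(sample, inference=False))
print(missing_keys(sample, inference=True))
```

A speech-only sample is incomplete for training but sufficient for inference, which is the distinction the method encodes.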
trainer
alias of Trainer