espnet2.tasks.gan_tts.GANTTSTask
class espnet2.tasks.gan_tts.GANTTSTask
Bases: AbsTask
GAN-based text-to-speech task.
This class implements a GAN-based approach for text-to-speech (TTS) synthesis. It includes functionalities for building models, managing data, and processing input features, all tailored for generative adversarial networks.
num_optimizers
Number of optimizers required for GAN training: 2, one for the generator and one for the discriminator.
- Type: int
class_choices_list
A list of class choices for various components such as feature extraction, normalization, and TTS models.
- Type: list
trainer
Specifies the trainer class to be used (GANTrainer).
- Type: type
add_task_arguments(parser: argparse.ArgumentParser): Adds task-specific arguments to the argument parser.
build_collate_fn(args: argparse.Namespace, train: bool) -> Callable: Builds a collate function for batching data.
build_preprocess_fn(args: argparse.Namespace, train: bool) -> Optional[Callable]: Constructs a preprocessing function based on the input arguments.
required_data_names(train: bool = True, inference: bool = False) -> Tuple[str, …]: Returns a tuple of required data names for training or inference.
optional_data_names(train: bool = True, inference: bool = False) -> Tuple[str, …]: Returns a tuple of optional data names for training or inference.
build_model(args: argparse.Namespace) -> ESPnetGANTTSModel: Builds the ESPnet GAN TTS model based on the provided arguments.
build_optimizers(args: argparse.Namespace, model: ESPnetGANTTSModel) -> List[torch.optim.Optimizer]: Constructs the optimizers for the model.
################### Examples
To create a parser that includes the task arguments (get_parser() builds the common parser and then calls add_task_arguments()):
>>> from espnet2.tasks.gan_tts import GANTTSTask
>>> parser = GANTTSTask.get_parser()
To build the model:
>>> args = parser.parse_args()
>>> model = GANTTSTask.build_model(args)
To build the optimizers:
>>> optimizers = GANTTSTask.build_optimizers(args, model)
######### NOTE
The class combines feature extractors, normalization layers, and a TTS model into a complete TTS pipeline.
classmethod add_task_arguments(parser: ArgumentParser)
Adds task-related arguments to the argument parser for the GANTTSTask.
This method is responsible for defining the command line arguments that are specific to the GANTTSTask, including configurations for feature extraction, normalization, and text-to-speech models. It organizes the arguments into groups for better clarity and structure.
- Parameters:
- cls – The class on which the method is called.
- parser (argparse.ArgumentParser) – The argument parser to which the task-related arguments will be added.
######### NOTE
The function modifies the parser in place, adding the required and optional arguments that GANTTSTask needs.
################### Examples
To use the arguments added by this method, you might do:
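```python
import argparse

from espnet2.tasks.gan_tts import GANTTSTask

parser = argparse.ArgumentParser()
# add_task_arguments() appends "token_list" to the parser's "required"
# default, so that default must exist (GANTTSTask.get_parser() sets it).
parser.set_defaults(required=[])
GANTTSTask.add_task_arguments(parser)
args = parser.parse_args()
```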
- Raises: ValueError – If the argument parsing fails or if required arguments are not provided.
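The grouping pattern can be sketched with plain argparse. This is a simplified illustration, not ESPnet's code: `add_task_arguments_sketch` and `parse_conf` are hypothetical names, `parse_conf` uses `ast.literal_eval` in place of ESPnet's own type converters, and the `--tts`/`--tts_conf` pair only imitates how a class choice is paired with a configuration dictionary.

```python
import argparse
import ast
from typing import Any, Dict


def parse_conf(value: str) -> Dict[str, Any]:
    """Hypothetical converter: parse a Python-literal dict such as "{'lr': 0.001}"."""
    conf = ast.literal_eval(value)
    if not isinstance(conf, dict):
        raise argparse.ArgumentTypeError(f"expected a dict, got {value!r}")
    return conf


def add_task_arguments_sketch(parser: argparse.ArgumentParser) -> None:
    """Hypothetical, simplified version of a task's argument registration."""
    # Group the task options so that --help lists them together.
    group = parser.add_argument_group(description="Task related")
    group.add_argument("--token_list", type=str, default=None,
                       help="Path to a token list file (one token per line)")
    group.add_argument("--odim", type=int, default=None,
                       help="Output dimension; None means infer it from the features")

    # A class choice: the name selects an implementation, *_conf configures it.
    choices = parser.add_argument_group(description="Model choices")
    choices.add_argument("--tts", type=str, default="vits", choices=["vits"])
    choices.add_argument("--tts_conf", type=parse_conf, default={})


parser = argparse.ArgumentParser()
add_task_arguments_sketch(parser)
args = parser.parse_args(
    ["--token_list", "tokens.txt", "--tts_conf", "{'sampling_rate': 22050}"]
)
print(args.tts, args.tts_conf)  # vits {'sampling_rate': 22050}
```

Keeping each choice next to its `*_conf` dictionary lets a single flag swap the implementation while its constructor keywords travel alongside it.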
classmethod build_collate_fn(args: Namespace, train: bool) → Callable[[Collection[Tuple[str, Dict[str, ndarray]]]], Tuple[List[str], Dict[str, Tensor]]]
Builds a collate function for batching input data.
This method creates a callable that can be used to collate input data into batches for training or evaluation. It handles padding for different types of input data.
- Parameters:
- args (argparse.Namespace) – The arguments parsed from the command line, including configuration options.
- train (bool) – A flag indicating whether the function is being built for training or evaluation.
- Returns: A callable that takes a collection of (utterance ID, data dictionary) tuples and returns a tuple of the utterance IDs and a dictionary of batched tensors.
- Return type: Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]]]
################### Examples
>>> collate_fn = GANTTSTask.build_collate_fn(args, train=True)
>>> batch = collate_fn(data)
######### NOTE
The collate function pads the sequences to the maximum length within the batch. Fields such as “spembs”, “sids”, and “lids” are treated as fixed-size (non-sequence) data, so no sequence lengths are computed for them.
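The padding behaviour can be sketched with NumPy alone. This is an illustrative re-implementation, not ESPnet's collate function (which returns torch tensors); the pad values (0 for integer arrays, 0.0 otherwise), the `_lengths` key suffix, and the `NOT_SEQUENCE` tuple are assumptions made for this sketch.

```python
from typing import Dict, List, Sequence, Tuple

import numpy as np

# Assumed set of fixed-size fields that get no "<key>_lengths" entry.
NOT_SEQUENCE = ("spembs", "sids", "lids")


def collate_sketch(
    batch: Sequence[Tuple[str, Dict[str, np.ndarray]]],
    float_pad: float = 0.0,
    int_pad: int = 0,
) -> Tuple[List[str], Dict[str, np.ndarray]]:
    """Pad every field along its first axis and add lengths for sequence fields."""
    uttids = [uid for uid, _ in batch]
    output: Dict[str, np.ndarray] = {}
    for key in batch[0][1]:
        arrays = [data[key] for _, data in batch]
        # Integer arrays (e.g. token IDs) and float arrays use different pad values.
        pad_value = int_pad if arrays[0].dtype.kind == "i" else float_pad
        max_len = max(a.shape[0] for a in arrays)
        padded = np.full(
            (len(arrays), max_len) + arrays[0].shape[1:], pad_value, dtype=arrays[0].dtype
        )
        for i, a in enumerate(arrays):
            padded[i, : a.shape[0]] = a  # real values first, padding in the tail
        output[key] = padded
        if key not in NOT_SEQUENCE:
            # Sequence fields also report their unpadded lengths.
            output[key + "_lengths"] = np.array([a.shape[0] for a in arrays], dtype=np.int64)
    return uttids, output


batch = [
    ("utt1", {"text": np.array([3, 5, 7]), "speech": np.array([0.1, 0.2]), "sids": np.array([0])}),
    ("utt2", {"text": np.array([4, 2]), "speech": np.array([0.3]), "sids": np.array([1])}),
]
uttids, out = collate_sketch(batch)
print(out["text"].tolist(), out["text_lengths"].tolist())  # [[3, 5, 7], [4, 2, 0]] [3, 2]
```

Returning the unpadded lengths alongside each padded sequence lets downstream code build masks that ignore the padded tail.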
classmethod build_model(args: Namespace) → ESPnetGANTTSModel
Build the model for the GAN-based text-to-speech task.
This method constructs the ESPnetGANTTSModel using the provided arguments. It initializes various components of the model such as feature extractors, normalization layers, and the TTS module based on the input arguments.
- Parameters: args (argparse.Namespace) – The parsed command-line arguments containing configuration options for building the model.
- Returns: An instance of ESPnetGANTTSModel constructed with the specified configuration.
- Return type: ESPnetGANTTSModel
- Raises: RuntimeError – If token_list is neither a str (path to a token file) nor a list or tuple of tokens.
################### Examples
>>> from argparse import Namespace
>>> from espnet2.tasks.gan_tts import GANTTSTask
>>> args = Namespace(
... token_list="path/to/token_list.txt",
... odim=None,
... feats_extract="log_spectrogram",
... feats_extract_conf={},
... normalize=None,  # "global_mvn" additionally requires a stats_file in normalize_conf
... normalize_conf={},
... tts="vits",
... tts_conf={},
... pitch_extract=None,
... energy_extract=None,
... pitch_normalize=None,
... energy_normalize=None,
... model_conf={}
... )
>>> model = GANTTSTask.build_model(args)
>>> print(model)
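The token_list handling described under Raises above can be sketched as follows. `resolve_token_list` is a hypothetical helper, not part of ESPnet; it assumes that a string is a path to a file with one token per line, that an in-memory token list is a list or tuple, and that the vocabulary size is simply the number of tokens.

```python
import os
import tempfile
from typing import List, Sequence, Union


def resolve_token_list(token_list: Union[str, Sequence[str]]) -> List[str]:
    """Hypothetical helper: normalise token_list into a list of tokens."""
    if isinstance(token_list, str):
        # A string is treated as a path to a file with one token per line.
        with open(token_list, encoding="utf-8") as f:
            return [line.rstrip() for line in f]
    if isinstance(token_list, (list, tuple)):
        return list(token_list)
    raise RuntimeError("token_list must be a path (str) or a list/tuple of tokens")


# Write a tiny token file so the example is self-contained.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write("<blank>\na\nb\n<sos/eos>\n")
tokens = resolve_token_list(tmp.name)
os.remove(tmp.name)

vocab_size = len(tokens)  # assumed to be the input dimension of the TTS module
print(tokens, vocab_size)  # ['<blank>', 'a', 'b', '<sos/eos>'] 4
```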
classmethod build_optimizers(args: Namespace, model: ESPnetGANTTSModel) → List[Optimizer]
Builds the optimizers for the GAN-based text-to-speech model.
This method initializes two optimizers: one for the generator and one for the discriminator. The optimizers are configured based on the arguments provided in args and the model structure.
- Parameters:
- args (argparse.Namespace) – The arguments namespace containing configuration settings, such as optimizer types and their respective configurations.
- model (ESPnetGANTTSModel) – The GAN-based text-to-speech model which contains both the generator and discriminator.
- Returns: A list containing the initialized optimizers for the generator and the discriminator, in that order.
- Return type: List[torch.optim.Optimizer]
- Raises:
- ValueError – If the specified optimizer class is not found in the available optimizer classes.
- RuntimeError – If fairscale is required but not installed.
################### Examples
>>> import argparse
>>> from espnet2.tasks.gan_tts import GANTTSTask
>>> args = argparse.Namespace(
... optim='adam',
... optim_conf={'lr': 0.001},
... optim2='sgd',
... optim2_conf={'lr': 0.01},
... sharded_ddp=False,
... )
>>> model = ESPnetGANTTSModel(...)
>>> optimizers = GANTTSTask.build_optimizers(args, model)
>>> assert len(optimizers) == 2 # One for generator and one for discriminator
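The selection logic can be sketched without PyTorch by using stand-in classes. `StubOptimizer`, `OPTIM_CLASSES`, and `build_optimizers_sketch` are hypothetical names; the sketch only illustrates the name lookup, the ValueError for unknown names, and the [generator, discriminator] ordering, and it does not construct real torch optimizers.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence


@dataclass
class StubOptimizer:
    """Hypothetical stand-in for a torch optimizer: it only records its inputs."""
    params: Sequence[str]
    conf: Dict[str, Any] = field(default_factory=dict)


class StubAdam(StubOptimizer):
    pass


class StubSGD(StubOptimizer):
    pass


# Hypothetical name-to-class registry; real code would map to torch.optim classes.
OPTIM_CLASSES = {"adam": StubAdam, "sgd": StubSGD}


def build_optimizers_sketch(
    optim: str,
    optim_conf: Dict[str, Any],
    optim2: str,
    optim2_conf: Dict[str, Any],
    generator_params: Sequence[str],
    discriminator_params: Sequence[str],
) -> List[StubOptimizer]:
    optimizers = []
    # The first optimizer updates the generator, the second the discriminator.
    for name, conf, params in (
        (optim, optim_conf, generator_params),
        (optim2, optim2_conf, discriminator_params),
    ):
        optim_class = OPTIM_CLASSES.get(name)
        if optim_class is None:
            raise ValueError(f"optimizer must be one of {list(OPTIM_CLASSES)}: {name}")
        optimizers.append(optim_class(params, conf))
    return optimizers


optimizers = build_optimizers_sketch(
    "adam", {"lr": 0.001}, "sgd", {"lr": 0.01}, ["generator.weight"], ["discriminator.weight"]
)
print([type(o).__name__ for o in optimizers])  # ['StubAdam', 'StubSGD']
```

Giving each network its own optimizer keeps the two parameter sets and their learning-rate settings independent, which is what the separate optim/optim2 options express.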
classmethod build_preprocess_fn(args: Namespace, train: bool) → Callable[[str, Dict[str, array]], Dict[str, ndarray]] | None
Builds a preprocessing function based on the provided arguments.
This function returns a callable that can be used to preprocess input data before it is fed into the model. If preprocessing is disabled, it returns None.
- Parameters:
- args (argparse.Namespace) – The parsed arguments containing configurations for preprocessing.
- train (bool) – A flag indicating whether the function is for training or inference.
- Returns: Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: A preprocessing function if args.use_preprocessor is True, otherwise None.
################### Examples
>>> from argparse import Namespace
>>> from espnet2.tasks.gan_tts import GANTTSTask
>>> args = Namespace(use_preprocessor=True, token_type='phn',
...     token_list='path/to/token_list', bpemodel=None,
...     non_linguistic_symbols=None, cleaner=None, g2p=None)
>>> preprocess_fn = GANTTSTask.build_preprocess_fn(args, train=True)
>>> # The first argument is the utterance ID; "text" is converted to token IDs.
>>> processed_data = preprocess_fn("utt1", {"text": "HH AH0 L OW1"})
######### NOTE
This method relies on CommonPreprocessor for the actual preprocessing logic. Ensure that the necessary configurations are provided in args.
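The core text step of such a preprocessor (text to tokens to integer IDs) can be sketched as below. This is not CommonPreprocessor; `make_text_preprocessor` is hypothetical, and it assumes phoneme text that is already whitespace-separated and an `<unk>` entry in the token list for unknown tokens.

```python
from typing import Callable, Dict, List

import numpy as np


def make_text_preprocessor(
    token_list: List[str], unk_symbol: str = "<unk>"
) -> Callable[[str, Dict[str, object]], Dict[str, object]]:
    """Hypothetical factory: map whitespace-separated tokens in "text" to IDs."""
    token2id = {tok: i for i, tok in enumerate(token_list)}
    unk_id = token2id[unk_symbol]  # assumes the token list contains <unk>

    def preprocess(uid: str, data: Dict[str, object]) -> Dict[str, object]:
        # uid is accepted for interface compatibility but unused in this sketch.
        out = dict(data)
        if "text" in out:
            tokens = str(out["text"]).split()
            out["text"] = np.array([token2id.get(t, unk_id) for t in tokens], dtype=np.int64)
        return out

    return preprocess


token_list = ["<blank>", "<unk>", "HH", "AH0", "L", "OW1", "<sos/eos>"]
preprocess_fn = make_text_preprocessor(token_list)
result = preprocess_fn("utt1", {"text": "HH AH0 L OW1 ZZ"})
print(result["text"])  # [2 3 4 5 1]  ("ZZ" is not in the list, so it maps to <unk>)
```

Fields other than "text" pass through unchanged, so the same callable can be applied to every sample regardless of which optional data it carries.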
class_choices_list
num_optimizers
classmethod optional_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Get the optional data names for the GANTTS task.
This method returns a tuple of data names that may be supplied in addition to the required ones, for example speaker embeddings (“spembs”), speaker IDs (“sids”), and language IDs (“lids”). The returned names may differ between training and inference mode.
- Parameters:
- train (bool, optional) – Indicates if the task is in training mode. Defaults to True.
- inference (bool, optional) – Indicates if the task is in inference mode. Defaults to False.
- Returns: A tuple containing the optional data names.
- Return type: Tuple[str, …]
classmethod required_data_names(train: bool = True, inference: bool = False) → Tuple[str, ...]
Get the required data names for the GANTTS task.
This method returns a tuple of required data names depending on whether the task is in inference mode. Outside inference (training or validation), both “text” and “speech” are required. In inference mode, only “text” is required.
- Parameters:
- train (bool, optional) – Indicates if the task is in training mode. Defaults to True.
- inference (bool, optional) – Indicates if the task is in inference mode. Defaults to False.
- Returns: A tuple containing the required data names.
- Return type: Tuple[str, …]
################### Examples
>>> GANTTSTask.required_data_names(train=True, inference=False)
('text', 'speech')
>>> GANTTSTask.required_data_names(train=False, inference=True)
('text',)
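Consistent with both examples above, the lookup can be sketched as depending only on the `inference` flag. `required_data_names_sketch` and `check_required` are hypothetical helpers for illustration, not ESPnet functions.

```python
from typing import Dict, Tuple


def required_data_names_sketch(train: bool = True, inference: bool = False) -> Tuple[str, ...]:
    """Hypothetical mirror of the documented behaviour; train does not change the result."""
    # Outside inference, speech is the target; at inference only text is input.
    return ("text",) if inference else ("text", "speech")


def check_required(data: Dict[str, object], inference: bool = False) -> None:
    """Hypothetical check that a sample provides every required field."""
    missing = [name for name in required_data_names_sketch(inference=inference) if name not in data]
    if missing:
        raise ValueError(f"missing required data: {missing}")


check_required({"text": "HH AH0 L OW1"}, inference=True)  # passes: speech not needed
try:
    check_required({"text": "HH AH0 L OW1"})  # training sample without speech
except ValueError as err:
    print(err)  # missing required data: ['speech']
```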
trainer
alias of GANTrainer