espnet2.train.uasr_trainer.UASRTrainer
class espnet2.train.uasr_trainer.UASRTrainer
Bases: Trainer
Trainer for GAN-based UASR training.
This class implements a trainer for unsupervised automatic speech recognition (UASR) models trained with generative adversarial networks (GANs). To use this trainer, the model must inherit from espnet2.train.abs_gan_espnet_model.AbsGANESPnetModel.
generator_first
Flag indicating whether to update the generator first.
- Type: bool
max_num_warning
Maximum number of warnings displayed during training.
- Type: int
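The generator_first flag only decides which network moves first. The sketch below assumes, as an illustration rather than UASRTrainer's actual code, that the generator and discriminator take turns on alternating steps; the helper turn_for_step is hypothetical.

```python
def turn_for_step(step: int, generator_first: bool) -> str:
    """Return which network updates at a given step (hypothetical helper).

    Assumes the generator and discriminator alternate on consecutive steps,
    with generator_first deciding which one takes step 0.
    """
    even_step = step % 2 == 0
    generator_turn = even_step == generator_first
    return "generator" if generator_turn else "discriminator"


# With generator_first=True the schedule starts with the generator.
schedule = [turn_for_step(i, generator_first=True) for i in range(4)]
print(schedule)  # ['generator', 'discriminator', 'generator', 'discriminator']
```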
Parameters: args (argparse.Namespace) – Command-line arguments for configuring the trainer.
Returns: A dataclass containing the training options.
Return type: TrainerOptions
Raises: NotImplementedError – If accum_grad > 1 or grad_noise is enabled, as these options are not supported in GAN-based training.
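The NotImplementedError above amounts to a guard on two options before training starts. A minimal sketch, assuming the options arrive as plain values (the helper name check_gan_options is hypothetical, not part of ESPnet):

```python
def check_gan_options(accum_grad: int, grad_noise: bool) -> None:
    """Reject options that GAN-based training does not support (hypothetical helper)."""
    if accum_grad > 1:
        raise NotImplementedError("accum_grad > 1 is not supported in GAN-based training")
    if grad_noise:
        raise NotImplementedError("grad_noise is not supported in GAN-based training")


# Supported configuration: the check passes silently and returns None.
check_gan_options(accum_grad=1, grad_noise=False)
```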
#### Examples
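To use this trainer, add its arguments to a parser and build the options:
```python
import argparse

from espnet2.train.uasr_trainer import UASRTrainer

parser = argparse.ArgumentParser()
UASRTrainer.add_arguments(parser)  # registers --generator_first and --max_num_warning
args = parser.parse_args()
# build_options() also reads the inherited TrainerOptions fields, normally provided by the task parser
options = UASRTrainer.build_options(args)
```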
You can then train the model with:
```python
all_steps_invalid = UASRTrainer.train_one_epoch(
    model, iterator, optimizers, schedulers, scaler,
    reporter, summary_writer, options, distributed_option,
)
```
Validation can be performed with:
```python
UASRTrainer.validate_one_epoch(model, iterator, reporter, options, distributed_option)
```
#### NOTE
Ensure that the model being trained follows the expected input/output formats; in particular, its output must be a tuple or list containing the loss, statistics, weight, and any other necessary values.
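One way to enforce the output contract described in the note is to validate and unpack the model's return value before the backward pass. This sketch assumes the first three elements are the loss, statistics, and weight, matching the note, and collects anything else in a tuple; the helper unpack_model_output is hypothetical.

```python
from typing import Any, Dict, Tuple


def unpack_model_output(retval: Any) -> Tuple[Any, Dict[str, Any], Any, Tuple[Any, ...]]:
    """Split a (loss, stats, weight, *extra) model output (hypothetical layout)."""
    if not isinstance(retval, (tuple, list)):
        raise RuntimeError(
            f"model output must be a tuple or list, got {type(retval).__name__}"
        )
    if len(retval) < 3:
        raise RuntimeError("model output must contain at least loss, stats and weight")
    loss, stats, weight, *extra = retval
    return loss, stats, weight, tuple(extra)


loss, stats, weight, extra = unpack_model_output((0.5, {"loss": 0.5}, 8, "updated"))
print(extra)  # ('updated',)
```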
classmethod add_arguments(parser: ArgumentParser)
Add additional arguments for GAN-trainer.
This method extends the argument parser with options specific to the GAN-based UASR training process. It allows users to specify whether the generator should be updated first and to set a maximum number of warnings to display during training.
- Parameters: parser (argparse.ArgumentParser) – The argument parser to which the additional arguments are added.
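Because --generator_first arrives on the command line as a string, it needs a string-to-bool converter (the build_options example registers it with type=str2bool) rather than type=bool, which treats any non-empty string, including "false", as True. The converter parse_bool below is a hypothetical stand-in for illustration, not ESPnet's implementation.

```python
import argparse


def parse_bool(value: str) -> bool:
    """Convert common true/false spellings to bool (hypothetical str2bool stand-in)."""
    lowered = value.strip().lower()
    if lowered in ("true", "yes", "y", "1"):
        return True
    if lowered in ("false", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--generator_first", type=parse_bool, default=False)
parser.add_argument("--max_num_warning", type=int, default=10)

args = parser.parse_args(["--generator_first", "false"])
print(args.generator_first, args.max_num_warning)  # False 10
print(bool("false"))  # True, which is why type=bool would be wrong here
```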
#### Examples
To use this method, create an argument parser and call add_arguments:
```python
import argparse

from espnet2.train.uasr_trainer import UASRTrainer

parser = argparse.ArgumentParser()
UASRTrainer.add_arguments(parser)  # adds --generator_first and --max_num_warning
args = parser.parse_args()
```
#### NOTE
The --generator_first argument takes a boolean value indicating whether to update the generator before the discriminator during training.
The --max_num_warning argument sets the maximum number of warnings displayed; once this limit is reached, further warnings are suppressed.
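The suppression behavior described for --max_num_warning can be sketched as a counter wrapped around a logger. This is an illustrative helper (CappedWarner is hypothetical), assuming warnings go through Python's standard logging module.

```python
import logging


class CappedWarner:
    """Log at most max_num_warning warnings, then stay silent (hypothetical helper)."""

    def __init__(self, max_num_warning: int, logger: logging.Logger) -> None:
        self.max_num_warning = max_num_warning
        self.logger = logger
        self.num_emitted = 0

    def warn(self, message: str) -> bool:
        """Log message unless the cap is reached; return whether it was logged."""
        if self.num_emitted >= self.max_num_warning:
            return False
        self.logger.warning(message)
        self.num_emitted += 1
        return True


warner = CappedWarner(max_num_warning=2, logger=logging.getLogger("uasr_sketch"))
logged = [warner.warn(f"warning {i}") for i in range(4)]
print(logged)  # [True, True, False, False]
```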
classmethod build_options(args: Namespace) → TrainerOptions
Build options consumed by train(), eval(), and plot_attention().
This method constructs the training options from the parsed arguments; the resulting options configure the training, evaluation, and attention-plotting processes.
- Parameters: args (argparse.Namespace) – The arguments parsed from the command line.
- Returns: A UASRTrainerOptions instance (a subclass of TrainerOptions) populated from args.
- Return type: TrainerOptions
#### Examples
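>>> import argparse
>>> from espnet2.train.uasr_trainer import UASRTrainer
>>> from espnet2.utils.types import str2bool
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument("--generator_first", type=str2bool, default=False)
>>> parser.add_argument("--max_num_warning", type=int, default=10)
>>> # args must also carry the inherited TrainerOptions fields
>>> args = parser.parse_args()
>>> options = UASRTrainer.build_options(args)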
#### NOTE
args must provide every UASRTrainerOptions field, including those inherited from TrainerOptions; a missing field raises a ValueError.
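Building the options boils down to copying each dataclass field out of the argparse.Namespace and failing on any missing one. The sketch below illustrates that pattern with a hypothetical two-field stand-in for UASRTrainerOptions (SketchOptions) and a hypothetical helper build_options_from_args; it is not ESPnet's code.

```python
import argparse
import dataclasses


@dataclasses.dataclass
class SketchOptions:
    """Hypothetical stand-in for UASRTrainerOptions with only the two UASR fields."""

    generator_first: bool
    max_num_warning: int


def build_options_from_args(cls, args: argparse.Namespace):
    """Copy every dataclass field of cls from args, failing loudly on a missing one."""
    kwargs = {}
    for field in dataclasses.fields(cls):
        if not hasattr(args, field.name):
            raise ValueError(f"args doesn't have {field.name}")
        kwargs[field.name] = getattr(args, field.name)
    return cls(**kwargs)


args = argparse.Namespace(generator_first=True, max_num_warning=10)
options = build_options_from_args(SketchOptions, args)
print(options)  # SketchOptions(generator_first=True, max_num_warning=10)
```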
classmethod train_one_epoch(model: Module, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], scaler: GradScaler | None, reporter: SubReporter, summary_writer, options: UASRTrainerOptions, distributed_option: DistributedOption) → bool
Train one epoch for UASR.
This method performs a single training epoch for the UASR model. It iterates through the provided data iterator, applies forward and backward passes, and updates the model parameters accordingly. The method supports gradient clipping and distributed training.
- Parameters:
- model (torch.nn.Module) – The UASR model to be trained.
- iterator (Iterable[Tuple[List[str], Dict[str, torch.Tensor]]]) – An iterable that provides batches of training data.
- optimizers (Sequence[torch.optim.Optimizer]) – A sequence of optimizers for the model’s parameters.
- schedulers (Sequence[Optional[AbsScheduler]]) – A sequence of learning rate schedulers for the optimizers.
- scaler (Optional[GradScaler]) – A GradScaler instance for mixed precision training.
- reporter (SubReporter) – An instance to report training metrics.
- summary_writer – A writer for logging summaries.
- options (UASRTrainerOptions) – Options for the training process.
- distributed_option (DistributedOption) – Options related to distributed training.
- Returns: True if every step in the epoch was invalid (e.g., because the gradient norm was NaN or inf); otherwise False.
- Return type: bool
- Raises:
- NotImplementedError – If accum_grad > 1 or grad_noise is set, as these options are not supported in GAN-based training.
- RuntimeError – If the model output is not a tuple or list.
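The return value and the gradient-clipping support described above interact: a step whose gradient norm is NaN or inf is skipped, and the epoch counts as invalid only if every step was skipped. The sketch below models gradients as plain floats instead of tensors (a real trainer would clip parameter gradients with something like torch.nn.utils.clip_grad_norm_); both helpers are hypothetical.

```python
import math
from typing import List, Optional, Sequence


def clip_by_global_norm(grads: List[float], max_norm: float) -> Optional[List[float]]:
    """Scale grads so their global L2 norm is at most max_norm.

    Returns None when the norm is NaN or inf, meaning the step should be skipped.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(total_norm):
        return None
    if total_norm > max_norm:
        scale = max_norm / total_norm
        return [g * scale for g in grads]
    return list(grads)


def all_steps_invalid(per_step_grads: Sequence[List[float]], max_norm: float) -> bool:
    """Return True only if every step had a non-finite gradient norm."""
    all_invalid = True
    for grads in per_step_grads:
        if clip_by_global_norm(grads, max_norm) is not None:
            all_invalid = False  # at least one step would have updated the model
    return all_invalid


print(all_steps_invalid([[float("nan")], [1.0, 2.0]], max_norm=1.0))  # False
```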
#### Examples
>>> from espnet2.train.distributed_utils import DistributedOption
>>> from espnet2.train.uasr_trainer import UASRTrainer, UASRTrainerOptions
>>> model = ...  # Initialize your model
>>> iterator = ...  # Create your data iterator
>>> optimizers = [...]  # List of optimizers
>>> schedulers = [...]  # List of schedulers
>>> scaler = ...  # GradScaler if using mixed precision, else None
>>> reporter = ...  # Create a SubReporter instance
>>> summary_writer = ...  # Summary writer for logging
>>> options = UASRTrainerOptions(...)  # Initialize training options
>>> distributed_option = DistributedOption(...)  # Distributed options
>>> # train_one_epoch is a classmethod, so call it on the class
>>> all_steps_invalid = UASRTrainer.train_one_epoch(
...     model, iterator, optimizers, schedulers, scaler, reporter,
...     summary_writer, options, distributed_option)
#### NOTE
Ensure that the model being trained inherits from espnet2.train.abs_gan_espnet_model.AbsGANESPnetModel.
classmethod validate_one_epoch(model: Module, iterator: Iterable[Dict[str, Tensor]], reporter: SubReporter, options: UASRTrainerOptions, distributed_option: DistributedOption) → None
Validate one epoch.