espnet2.train.gan_trainer.GANTrainer
class espnet2.train.gan_trainer.GANTrainer
Bases: Trainer
Trainer module for GAN-based training.
This class implements a trainer for Generative Adversarial Networks (GANs). It is intended for models that inherit from espnet2.train.abs_gan_espnet_model.AbsGANESPnetModel. GANTrainer manages the training process, including the forward and backward passes for both the generator and the discriminator.
generator_first
Indicates whether to update the generator first.
- Type: bool
skip_discriminator_prob
Probability of skipping the discriminator step.
- Type: float
- Parameters: args (argparse.Namespace) – Command-line arguments parsed by argparse.
- Returns: An instance of GANTrainerOptions containing the parsed options.
- Return type: TrainerOptions
- Raises: NotImplementedError – If unsupported options (accum_grad > 1 or grad_noise) are used.
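As a rough illustration of how the two attributes above could shape a single training step, the sketch below builds the update order for one batch. It is not ESPnet's code: `plan_turns` is a hypothetical helper, and deciding the skip once per batch by drawing a uniform random number is an assumption made for the example.

```python
import random
from typing import List, Optional


def plan_turns(
    generator_first: bool,
    skip_discriminator_prob: float,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Return the update order for one batch (hypothetical helper)."""
    rng = rng or random.Random()
    if generator_first:
        turns = ["generator", "discriminator"]
    else:
        turns = ["discriminator", "generator"]
    # Assumption: the discriminator turn is dropped with the given probability.
    if rng.random() < skip_discriminator_prob:
        turns.remove("discriminator")
    return turns


print(plan_turns(generator_first=True, skip_discriminator_prob=0.0))
```

With a probability of 0.0 both turns always run; with 1.0 only the generator turn remains.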
########### Examples
To train a GAN using this trainer, you can create a model instance and call the training methods as follows:
```python
trainer = GANTrainer()
trainer.train_one_epoch(model, iterator, optimizers, schedulers, scaler,
                        reporter, summary_writer, options, distributed_option)
```
You can also validate the model with:
```python
trainer.validate_one_epoch(model, iterator, reporter, options,
                           distributed_option)
```
NOTE
The GANTrainer requires a model that inherits from AbsGANESPnetModel so that the generator and discriminator passes are handled correctly.
classmethod add_arguments(parser: ArgumentParser)
Add additional arguments for GAN-trainer.
This method extends the command-line argument parser with specific options for the GANTrainer. It allows the user to specify whether to update the generator first and the probability of skipping the discriminator step.
- Parameters: parser (argparse.ArgumentParser) – The argument parser to which the arguments will be added.
########### Examples
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> GANTrainer.add_arguments(parser)
>>> args = parser.parse_args(["--generator_first", "True"])
>>> print(args.generator_first) # Output: True
>>> print(args.skip_discriminator_prob) # Output: 0.0 (default)
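For readers without ESPnet installed, the same pattern can be tried with the standard library alone. This is a minimal sketch, not the library's implementation: `SketchTrainer` and `str2bool_sketch` are hypothetical stand-ins, and only the two option names and their defaults are taken from the text above.

```python
import argparse


def str2bool_sketch(value: str) -> bool:
    """Hypothetical stand-in for a string-to-bool converter."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"not a boolean: {value!r}")


class SketchTrainer:
    """Toy trainer that only mirrors the classmethod pattern."""

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        # Register the two GAN-specific options on an existing parser.
        parser.add_argument("--generator_first", type=str2bool_sketch, default=False)
        parser.add_argument("--skip_discriminator_prob", type=float, default=0.0)


parser = argparse.ArgumentParser()
SketchTrainer.add_arguments(parser)
args = parser.parse_args(["--generator_first", "true", "--skip_discriminator_prob", "0.25"])
print(args.generator_first, args.skip_discriminator_prob)
```

Keeping the registration in a classmethod lets a trainer subclass extend a shared parser without needing an instance.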
classmethod build_options(args: Namespace) → TrainerOptions
Build options consumed by train(), eval(), and plot_attention().
This method constructs a set of options for the GANTrainer based on the provided command-line arguments. It creates an instance of GANTrainerOptions, which includes parameters specifically for GAN training.
- Parameters: args (argparse.Namespace) – The parsed command-line arguments.
- Returns: An instance of GANTrainerOptions populated with the specified arguments.
- Return type: TrainerOptions
########### Examples
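>>> import argparse
>>> from espnet2.utils.types import str2bool
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument('--generator_first', type=str2bool, default=False)
>>> parser.add_argument('--skip_discriminator_prob', type=float, default=0.0)
>>> # Parse an empty list so the defaults are used regardless of sys.argv.
>>> # Assumption: args also carries the base TrainerOptions fields (omitted here).
>>> args = parser.parse_args([])
>>> options = GANTrainer.build_options(args)
>>> print(options.generator_first)
False
>>> print(options.skip_discriminator_prob)
0.0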
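The general idea of turning a parsed namespace into an options object can be sketched with a plain dataclass. The names `SketchGANOptions` and `build_from_namespace` are hypothetical, and the strict check for missing fields is an assumption about how such a builder might behave, not a description of ESPnet's internals.

```python
import argparse
import dataclasses


@dataclasses.dataclass
class SketchGANOptions:
    """Hypothetical, trimmed-down options container."""

    generator_first: bool
    skip_discriminator_prob: float


def build_from_namespace(cls, args: argparse.Namespace):
    """Copy every dataclass field from the namespace; fail loudly on gaps."""
    kwargs = {}
    for field in dataclasses.fields(cls):
        if not hasattr(args, field.name):
            raise ValueError(f"args is missing required option: {field.name}")
        kwargs[field.name] = getattr(args, field.name)
    return cls(**kwargs)


# Extra attributes on the namespace are simply ignored.
args = argparse.Namespace(generator_first=True, skip_discriminator_prob=0.1, unrelated="ignored")
options = build_from_namespace(SketchGANOptions, args)
print(options)
```

Failing on a missing field, rather than silently defaulting, surfaces a misconfigured parser before training starts.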
classmethod train_one_epoch(model: Module, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], scaler: GradScaler | None, reporter: SubReporter, summary_writer, options: GANTrainerOptions, distributed_option: DistributedOption) → bool
Train one epoch.
This method runs a single training epoch for the GAN model. It handles the forward and backward passes for both the generator and the discriminator, applies the optimizer updates, and logs the training statistics.
- Parameters:
- model (torch.nn.Module) – The GAN model to be trained.
- iterator (Iterable[Tuple[List[str], Dict[str, torch.Tensor]]]) – An iterable that provides batches of data for training.
- optimizers (Sequence[torch.optim.Optimizer]) – A sequence of optimizers for the generator and discriminator.
- schedulers (Sequence[Optional[AbsScheduler]]) – A sequence of schedulers for adjusting the learning rate.
- scaler (Optional[GradScaler]) – A GradScaler for mixed precision training.
- reporter (SubReporter) – An object for reporting training statistics.
- summary_writer – A writer for logging summaries (e.g., TensorBoard).
- options (GANTrainerOptions) – The options for the GAN training process.
- distributed_option (DistributedOption) – Options for distributed training.
- Returns: True if all steps in the epoch were invalid (i.e., no valid updates were made), otherwise False.
- Return type: bool
- Raises:
- NotImplementedError – If unsupported options are used (accum_grad > 1 or grad_noise).
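The meaning of the boolean return value can be illustrated with a small stand-alone loop. This is a sketch under stated assumptions: `run_epoch_sketch` is a hypothetical function, and treating a step as invalid when its loss is non-finite is an assumed criterion; the actual trainer may use a different check.

```python
import math
from typing import Iterable


def run_epoch_sketch(losses: Iterable[float]) -> bool:
    """Return True only if no step produced a usable (finite) loss."""
    all_steps_are_invalid = True
    for loss in losses:
        if not math.isfinite(loss):
            # Assumed validity criterion: skip the update for this step.
            continue
        # A real trainer would run the backward pass and optimizer step here.
        all_steps_are_invalid = False
    return all_steps_are_invalid


print(run_epoch_sketch([float("nan"), float("inf")]))
print(run_epoch_sketch([float("nan"), 0.5]))
```

A caller can use the returned flag to stop training early when an entire epoch produced no valid update.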
########### Examples
>>> options = GANTrainerOptions(generator_first=True,
... skip_discriminator_prob=0.1)
>>> result = GANTrainer.train_one_epoch(model, iterator,
... optimizers, schedulers,
... scaler, reporter,
... summary_writer, options,
... distributed_option)
NOTE
This method assumes that the model has a forward method that returns a dictionary containing the loss and statistics.
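The contract described in the note can be sketched without any deep learning dependency. Everything here is illustrative: `ToyGANModel`, the `forward_generator` flag, and the `"loss"` and `"stats"` keys are assumptions chosen for the example, and a real model would be a torch.nn.Module returning tensors rather than floats.

```python
from typing import Any, Dict


class ToyGANModel:
    """Hypothetical stand-in; a real model would be a torch.nn.Module."""

    def forward(self, forward_generator: bool = True, **batch: Any) -> Dict[str, Any]:
        values = batch["values"]
        if forward_generator:
            # Toy generator loss: mean of squares.
            loss = sum(v * v for v in values) / len(values)
            stats = {"generator_loss": loss}
        else:
            # Toy discriminator loss: mean of absolute values.
            loss = sum(abs(v) for v in values) / len(values)
            stats = {"discriminator_loss": loss}
        # The trainer reads the loss for the backward pass and the stats for logging.
        return {"loss": loss, "stats": stats}


model = ToyGANModel()
for turn in ("generator", "discriminator"):
    out = model.forward(forward_generator=(turn == "generator"), values=[1.0, -2.0])
    print(turn, out["loss"], out["stats"])
```

Returning a dictionary rather than a bare loss lets the trainer log per-turn statistics without knowing the model's internals.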
classmethod validate_one_epoch(model: Module, iterator: Iterable[Dict[str, Tensor]], reporter: SubReporter, options: GANTrainerOptions, distributed_option: DistributedOption) → None
Validate one epoch.