espnet2.train.trainer.Trainer
class espnet2.train.trainer.Trainer
Bases: object
Trainer module for managing the training process of ESPnet models.
The Trainer class provides an interface for training models with various options, optimizers, and schedulers. It allows for customization through inheritance and overriding methods. The training process includes logging, validation, and attention plotting.
Example
To create a custom trainer with two optimizers, inherit from Trainer and override the necessary methods:
>>> class TwoOptimizerTrainer(Trainer):
...     @classmethod
...     def add_arguments(cls, parser):
...         ...
...
...     @classmethod
...     def train_one_epoch(cls, model, optimizers, ...):
...         loss1 = model.model1(...)
...         loss1.backward()
...         optimizers[0].step()
...
...         loss2 = model.model2(...)
...         loss2.backward()
...         optimizers[1].step()
- Parameters: None
- Returns: None
- Raises: RuntimeError – If an instance of this class is created directly.
NOTE: This class cannot be instantiated directly; it is intended to be subclassed.
classmethod add_arguments(parser: ArgumentParser)
Reserved for future development of another Trainer.
This method is intended to allow subclasses to define their own command-line arguments for training configurations. It is currently a placeholder and does not implement any functionality.
- Parameters: parser (argparse.ArgumentParser) – The argument parser instance to which the arguments should be added.
Example
>>> import argparse
>>> class MyTrainer(Trainer):
...     @classmethod
...     def add_arguments(cls, parser):
...         parser.add_argument('--my_arg', type=int, default=42)
classmethod build_options(args: Namespace) → TrainerOptions
Build options consumed by train(), eval(), and plot_attention().
This method constructs a TrainerOptions dataclass from the given arguments. The options include various parameters that control the training process, such as the number of GPUs, whether to use automatic mixed precision (AMP), gradient clipping, logging intervals, and more.
- Parameters: args (argparse.Namespace) – Command line arguments parsed using argparse.
- Returns: A dataclass instance containing the training options.
- Return type: TrainerOptions
Example
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument('--ngpu', type=int, default=1)
>>> parser.add_argument('--resume', action='store_true')
>>> args = parser.parse_args(['--ngpu', '2', '--resume'])
>>> options = Trainer.build_options(args)
>>> print(options.ngpu) # Output: 2
>>> print(options.resume) # Output: True
classmethod plot_attention(model: Module, output_dir: Path | None, summary_writer, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], reporter: SubReporter, options: TrainerOptions) → None
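This method has no prose documentation here; judging from the signature, it renders attention weights for the batches yielded by iterator and writes them under output_dir and/or to the TensorBoard summary_writer. A minimal sketch of a call, assuming the model, an AbsIterFactory for attention plotting, a SubReporter, and TrainerOptions already exist; the output path is a hypothetical placeholder:
>>> from pathlib import Path
>>> # Hypothetical objects; "exp/att_ws" is only an illustrative output path.
>>> Trainer.plot_attention(
...     model=model,
...     output_dir=Path("exp/att_ws"),
...     summary_writer=None,
...     iterator=plot_attention_iter_factory.build_iter(1),
...     reporter=sub_reporter,
...     options=trainer_options,
... )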
static resume(checkpoint: str | Path, model: Module, reporter: Reporter, optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], scaler: GradScaler | None, ngpu: int = 0, strict: bool = True)
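This method restores the model, reporter, optimizer, scheduler, and (if given) AMP scaler states from a saved checkpoint. A minimal sketch of a call, assuming the surrounding objects already exist; the checkpoint path is a hypothetical placeholder:
>>> # "exp/checkpoint.pth" is a hypothetical path; model, reporter,
>>> # optimizers, and schedulers are assumed to be already constructed.
>>> Trainer.resume(
...     checkpoint="exp/checkpoint.pth",
...     model=model,
...     reporter=reporter,
...     optimizers=optimizers,
...     schedulers=schedulers,
...     scaler=None,
...     ngpu=1,
... )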
Trainer module.
This module provides the Trainer class for managing the training of machine learning models, specifically in the context of deep learning using PyTorch. The Trainer class includes methods for building options, resuming training from checkpoints, and executing the training and validation processes. It is designed to be extensible, allowing subclasses to implement custom training logic.
ngpu
Number of GPUs to use for training.
- Type: int
resume
Flag indicating whether to resume training from a checkpoint.
- Type: bool
use_amp
Flag to indicate if Automatic Mixed Precision is used.
- Type: bool
train_dtype
Data type for training.
- Type: str
grad_noise
Flag to indicate if gradient noise is added.
- Type: bool
accum_grad
Number of batches to accumulate gradients.
- Type: int
grad_clip
Maximum value for gradient clipping.
- Type: float
grad_clip_type
Type of norm for gradient clipping.
- Type: float
log_interval
Interval for logging training metrics.
- Type: Optional[int]
no_forward_run
Flag to skip forward run.
- Type: bool
use_matplotlib
Flag to indicate if Matplotlib is used for plotting.
- Type: bool
use_tensorboard
Flag to indicate if TensorBoard is used for logging.
- Type: bool
use_wandb
Flag to indicate if Weights & Biases is used for logging.
- Type: bool
adapter
Adapter type to be used.
- Type: str
use_adapter
Flag to indicate if an adapter is used.
- Type: bool
save_strategy
Strategy for saving models.
- Type: str
output_dir
Directory for saving output files.
- Type: Union[Path, str]
max_epoch
Maximum number of epochs for training.
- Type: int
seed
Random seed for reproducibility.
- Type: int
sharded_ddp
Flag to indicate if sharded DDP is used.
- Type: bool
patience
Patience for early stopping.
- Type: Optional[int]
keep_nbest_models
Number of best models to keep.
- Type: Union[int, List[int]]
nbest_averaging_interval
Interval for averaging n-best models.
- Type: int
early_stopping_criterion
Criteria for early stopping.
- Type: Sequence[str]
best_model_criterion
Criteria for determining the best model.
- Type: Sequence[Sequence[str]]
val_scheduler_criterion
Criteria for the validation scheduler.
- Type: Sequence[str]
unused_parameters
Flag to indicate if unused parameters are ignored.
- Type: bool
wandb_model_log_interval
Interval for logging model with Weights & Biases.
- Type: int
create_graph_in_tensorboard
Flag to indicate if graph is created in TensorBoard.
- Type: bool
classmethod run(model: AbsESPnetModel, optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], train_iter_factory: AbsIterFactory, valid_iter_factory: AbsIterFactory, plot_attention_iter_factory: AbsIterFactory | None, trainer_options, distributed_option: DistributedOption) → None
Perform training.
This method executes the main process of training a model, which includes training and validation for multiple epochs. It handles various aspects of training such as gradient clipping, mixed precision, logging, and model saving.
- Parameters:
- model (AbsESPnetModel) – The model to be trained.
- optimizers (Sequence[torch.optim.Optimizer]) – A sequence of optimizers to be used for training.
- schedulers (Sequence[Optional[AbsScheduler]]) – A sequence of schedulers to adjust the learning rate.
- train_iter_factory (AbsIterFactory) – Factory to create training data iterators.
- valid_iter_factory (AbsIterFactory) – Factory to create validation data iterators.
- plot_attention_iter_factory (Optional[AbsIterFactory]) – Factory to create data iterators for plotting attention, if any.
- trainer_options – Options for training; must be a dataclass of type TrainerOptions.
- distributed_option (DistributedOption) – Options related to distributed training.
- Raises: RuntimeError – If any required component for training is missing or if an error occurs during the training process.
Example
>>> trainer_options = TrainerOptions(ngpu=1, resume=False, use_amp=True,
... train_dtype='float32', grad_noise=False,
... accum_grad=1, grad_clip=5.0,
... grad_clip_type=2.0, log_interval=100,
... no_forward_run=False, use_matplotlib=True,
... use_tensorboard=True, use_wandb=False,
... adapter='none', use_adapter=False,
... save_strategy='all', output_dir='./output',
... max_epoch=10, seed=42, sharded_ddp=False,
... patience=None, keep_nbest_models=1,
... nbest_averaging_interval=1,
... early_stopping_criterion=['valid_loss'],
... best_model_criterion=[['valid', 'loss', 'min']],
... val_scheduler_criterion=['valid_loss'],
... unused_parameters=False,
... wandb_model_log_interval=1,
... create_graph_in_tensorboard=False)
>>> Trainer.run(model, optimizers, schedulers, train_iter_factory,
... valid_iter_factory, plot_attention_iter_factory,
... trainer_options, distributed_option)
NOTE: This method should be called within a context where the model, optimizers, and data iterators are properly initialized.
classmethod train_one_epoch(model: Module, iterator: Iterable[Tuple[List[str], Dict[str, Tensor]]], optimizers: Sequence[Optimizer], schedulers: Sequence[AbsScheduler | None], scaler: GradScaler | None, reporter: SubReporter, summary_writer, options: TrainerOptions, distributed_option: DistributedOption) → bool
Train the model for one epoch.
This method handles the training loop for a single epoch, including forward and backward passes, gradient updates, and logging. It also supports distributed training.
- Parameters:
- model (torch.nn.Module) – The model to be trained.
- iterator (Iterable[Tuple[List[str], Dict[str, torch.Tensor]]]) – An iterable that provides batches of training data.
- optimizers (Sequence[torch.optim.Optimizer]) – A sequence of optimizers for the model.
- schedulers (Sequence[Optional[AbsScheduler]]) – A sequence of learning rate schedulers corresponding to the optimizers.
- scaler (Optional[GradScaler]) – A GradScaler for mixed precision training.
- reporter (SubReporter) – A reporter to log training statistics.
- summary_writer – A writer for logging to TensorBoard.
- options (TrainerOptions) – Options for the training process.
- distributed_option (DistributedOption) – Options for distributed training.
- Returns: Returns True if all gradient steps were invalid, otherwise False.
- Return type: bool
Example
>>> model = MyModel()
>>> iterator = data_loader()
>>> optimizers = [torch.optim.Adam(model.parameters())]
>>> schedulers = [torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)]
>>> reporter = SubReporter()
>>> options = TrainerOptions(...)
>>> distributed_option = DistributedOption(...)
>>> all_steps_invalid = Trainer.train_one_epoch(
... model, iterator, optimizers, schedulers, None, reporter, None, options, distributed_option
... )
NOTE: This method modifies the model’s parameters and may have side effects on the model’s state.
- Raises: AssertionError – If the batch is not of type dict.
classmethod validate_one_epoch(model: Module, iterator: Iterable[Dict[str, Tensor]], reporter: SubReporter, options: TrainerOptions, distributed_option: DistributedOption) → None
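Validate the model for one epoch: run forward passes over the validation data and report the resulting statistics through the reporter, without updating parameters. A minimal sketch of a call, reusing the hypothetical objects from the train_one_epoch example above:
>>> # Hypothetical objects; see the train_one_epoch example above.
>>> Trainer.validate_one_epoch(
...     model, valid_iterator, reporter, options, distributed_option
... )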