espnet3.components.trainers.trainer.ESPnet3LightningTrainer
class espnet3.components.trainers.trainer.ESPnet3LightningTrainer(model: ESPnetLightningModule = None, exp_dir: str = None, config: DictConfig | Namespace | Dict[str, Any] = None, best_model_criterion=None)
Bases: object
A wrapper around Lightning’s Trainer to provide ESPnet3-specific integration.
This trainer ensures compatibility with ESPnet’s dataloader, callbacks, and configuration system. It initializes the model, handles weight initialization, sets up the training strategy, logger, plugins, and integrates with ESPnet-specific callbacks and samplers.
config
Training configuration.
- Type: Union[DictConfig, Namespace, Dict[str, Any]]
model
ESPnet3 LightningModule instance.
- Type: ESPnetLightningModule
trainer
Underlying PyTorch Lightning trainer object.
- Type: lightning.Trainer
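Because config may arrive as an OmegaConf DictConfig, an argparse Namespace, or a plain dict, code that reads it often normalizes it first. The sketch below shows one way to do that. to_plain_dict is a hypothetical helper, not part of ESPnet3. It assumes DictConfig behaves as a Mapping. A fully resolved, nested conversion of a DictConfig would instead use OmegaConf.to_container.

```python
from argparse import Namespace
from collections.abc import Mapping
from typing import Any, Dict, Union


def to_plain_dict(config: Union[Namespace, Mapping]) -> Dict[str, Any]:
    """Return a shallow, plain-dict view of a Namespace or mapping config.

    Hypothetical helper for illustration only. A DictConfig is treated as a
    Mapping here, so only its top-level keys are copied (values are not
    recursively converted).
    """
    if isinstance(config, Namespace):
        # argparse.Namespace stores its fields in __dict__.
        return dict(vars(config))
    if isinstance(config, Mapping):
        # Covers plain dicts and other mapping types.
        return dict(config)
    raise TypeError(f"Unsupported config type: {type(config).__name__}")
```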
Initialize the trainer with model, configuration, and training setup.
Sets up weight initialization and the accelerator, strategy, logger, profiler, and plugins; applies ESPnet-specific dataloader constraints; prepares callbacks; and finally constructs the underlying Lightning Trainer.
- Parameters:
- model (ESPnetLightningModule, optional) – LightningModule to train.
- exp_dir (str, optional) – Experiment directory for logs and checkpoints.
- config (DictConfig | Namespace | Dict[str, Any], optional) – Training configuration.
- best_model_criterion (ListConfig, optional) – Criteria for selecting the best checkpoints.
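A checkpoint-selection criterion usually names a metric to monitor, how many checkpoints to keep, and whether lower or higher is better. The sketch below assumes that each entry of best_model_criterion is a (monitor, top_k, mode) triple such as ["valid/loss", 3, "min"]. That layout is an assumption. It maps the triples to keyword arguments named after the monitor, save_top_k, and mode parameters of Lightning's ModelCheckpoint callback. criteria_to_checkpoint_kwargs is a hypothetical helper, not the ESPnet3 implementation.

```python
from typing import Any, Dict, List, Sequence


def criteria_to_checkpoint_kwargs(
    best_model_criterion: Sequence[Sequence[Any]],
) -> List[Dict[str, Any]]:
    """Map assumed (monitor, top_k, mode) triples to ModelCheckpoint-style kwargs.

    Hypothetical helper for illustration only.
    """
    kwargs_list = []
    for monitor, top_k, mode in best_model_criterion:
        # ModelCheckpoint only accepts "min" or "max" for mode.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        kwargs_list.append(
            {"monitor": str(monitor), "save_top_k": int(top_k), "mode": mode}
        )
    return kwargs_list
```

Each resulting dict could then be unpacked into one checkpoint callback, so every criterion keeps its own top-k set of checkpoints.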
collect_stats(*args, **kwargs)
Collect dataset statistics using ESPnet3's parallel package.
- Parameters:
- *args – Positional arguments passed to model.collect_stats().
- **kwargs – Keyword arguments passed to model.collect_stats().
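collect_stats forwards its arguments unchanged to the model's own collect_stats(). The sketch below illustrates that pass-through with stand-in classes. StatsModelStub and TrainerSketch are hypothetical. The example argument names ("train", n_jobs) are made up for illustration and are not the real signature of model.collect_stats().

```python
from typing import Any


class StatsModelStub:
    """Hypothetical stand-in for an ESPnetLightningModule."""

    def collect_stats(self, *args: Any, **kwargs: Any) -> dict:
        # Echo the received arguments so the forwarding can be checked.
        return {"args": args, "kwargs": kwargs}


class TrainerSketch:
    """Minimal sketch of a wrapper that forwards collect_stats to its model."""

    def __init__(self, model: Any) -> None:
        self.model = model

    def collect_stats(self, *args: Any, **kwargs: Any) -> Any:
        # All positional and keyword arguments are passed through unchanged.
        return self.model.collect_stats(*args, **kwargs)
```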
fit(*args, **kwargs)
Start the training loop using Lightning’s fit method.
- Parameters:
- *args – Positional arguments passed to trainer.fit().
- **kwargs – Keyword arguments passed to trainer.fit().
NOTE
Always uses the internally stored model (self.model) when calling fit.
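Because fit always trains the stored model, callers pass only the remaining fit() arguments. The sketch below shows that delegation with a recording stand-in for lightning.Trainer. RecordingTrainer and FitWrapperSketch are hypothetical, and passing the model as the first positional argument is an assumption about how the wrapper calls the inner trainer. ckpt_path is a real keyword argument of Lightning's Trainer.fit.

```python
from typing import Any, List, Tuple


class RecordingTrainer:
    """Hypothetical stand-in for lightning.Trainer that records fit() calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, tuple, dict]] = []

    def fit(self, model: Any, *args: Any, **kwargs: Any) -> None:
        self.calls.append((model, args, kwargs))


class FitWrapperSketch:
    """Sketch of a wrapper whose fit() always trains the stored model."""

    def __init__(self, model: Any, trainer: Any) -> None:
        self.model = model
        self.trainer = trainer

    def fit(self, *args: Any, **kwargs: Any) -> None:
        # Inject the stored model; callers supply only the other fit() arguments.
        self.trainer.fit(self.model, *args, **kwargs)
```

For example, FitWrapperSketch(model, trainer).fit(ckpt_path="last") would resume training of the stored model without the caller passing it again.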
validate(*args, **kwargs)
Run validation using Lightning’s validate method.
- Parameters:
- *args – Positional arguments passed to trainer.validate().
- **kwargs – Keyword arguments passed to trainer.validate().
- Returns: Validation results.
- Return type: List[Dict[str, Any]]
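Lightning's Trainer.validate returns one dict of logged metrics per validation dataloader, so callers typically search that list for the metric they need. The sketch below does this with a small lookup. find_metric is a hypothetical helper, and the metric names used in any example are placeholders. Keys with more than one dataloader may carry a dataloader-index suffix, so the helper returns the index of the dict where the key was found.

```python
from typing import Any, Dict, List, Optional, Tuple


def find_metric(
    results: List[Dict[str, Any]], key: str
) -> Optional[Tuple[int, float]]:
    """Return (dataloader index, value) for the first result dict holding key.

    Hypothetical helper for illustration; returns None if no dict has the key.
    """
    for idx, metrics in enumerate(results):
        if key in metrics:
            return idx, float(metrics[key])
    return None
```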
