espnet3.components.trainers.trainer.ESPnet3LightningTrainer
class espnet3.components.trainers.trainer.ESPnet3LightningTrainer(model: ESPnetLightningModule = None, exp_dir: str = None, config: DictConfig | Namespace | Dict[str, Any] = None, best_model_criterion=None)
Bases: object
A wrapper around Lightning’s Trainer to provide ESPnet3-specific integration.
This trainer ensures compatibility with ESPnet’s dataloader, callbacks, and configuration system. It initializes the model, handles weight initialization, sets up the training strategy, logger, and plugins, and integrates with ESPnet-specific callbacks and samplers.
config
Training configuration.
- Type: Union[DictConfig, Namespace, Dict[str, Any]]
model
ESPnet3 LightningModule instance.
- Type: ESPnetLightningModule
trainer
Underlying PyTorch Lightning trainer object.
- Type: lightning.Trainer
Initialize the trainer with model, configuration, and training setup.
Sets up weight initialization, accelerator/strategy/logger/profiler/plugins, applies ESPnet-specific dataloader constraints, prepares callbacks, and finally constructs the underlying Lightning Trainer.
- Parameters:
- model (ESPnetLightningModule, optional) – LightningModule to train.
- exp_dir (str, optional) – Experiment directory for logs and checkpoints.
- config (DictConfig | Namespace | Dict[str, Any], optional) – Training config.
- best_model_criterion (ListConfig, optional) – Criteria for selecting checkpoints.
collect_stats(*args, **kwargs)
Collect dataset statistics using ESPnet3’s parallel package.
- Parameters:
- *args – Positional arguments passed to model.collect_stats().
- **kwargs – Keyword arguments passed to model.collect_stats().
fit(*args, **kwargs)
Start the training loop using Lightning’s fit method.
- Parameters:
- *args – Positional arguments passed to trainer.fit().
- **kwargs – Keyword arguments passed to trainer.fit().
NOTE
Always uses the internally stored model (self.model) when calling fit.
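The note above describes a delegation pattern: the wrapper supplies its stored model and forwards everything else to the underlying trainer. The following is a minimal sketch of that pattern only, not the ESPnet3 source. `StubLightningTrainer` and `TrainerWrapperSketch` are hypothetical stand-ins so the example runs without Lightning installed; `ckpt_path` is used because it is a keyword argument of Lightning's `Trainer.fit`.

```python
from typing import Any, Dict, List, Tuple


class StubLightningTrainer:
    """Hypothetical stand-in for lightning.Trainer; it only records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any, tuple, Dict[str, Any]]] = []

    def fit(self, model: Any, *args: Any, **kwargs: Any) -> None:
        self.calls.append(("fit", model, args, kwargs))


class TrainerWrapperSketch:
    """Illustrative wrapper (assumption), not the ESPnet3 implementation."""

    def __init__(self, model: Any, trainer: StubLightningTrainer) -> None:
        self.model = model
        self.trainer = trainer

    def fit(self, *args: Any, **kwargs: Any) -> None:
        # The stored model always goes first; callers pass only the
        # remaining arguments, such as ckpt_path.
        return self.trainer.fit(self.model, *args, **kwargs)


wrapper = TrainerWrapperSketch(model="my-model", trainer=StubLightningTrainer())
wrapper.fit(ckpt_path="last")
```

Under this pattern, callers do not pass a model to `fit` themselves; doing so would hand the underlying trainer two model arguments.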
validate(*args, **kwargs)
Run validation using Lightning’s validate method.
- Parameters:
- *args – Positional arguments passed to trainer.validate().
- **kwargs – Keyword arguments passed to trainer.validate().
- Returns: Validation results.
- Return type: List[Dict[str, Any]]
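Lightning's `validate` returns one metrics dictionary per validation dataloader, which matches the `List[Dict[str, Any]]` return type above. As a rough sketch of how such a result might be consumed, the helper below averages each metric across dataloaders. The function name `summarize_validation`, the metric key `valid/loss`, and the numbers are all made up for illustration; real keys depend on what the model logs.

```python
from typing import Any, Dict, List


def summarize_validation(results: List[Dict[str, Any]]) -> Dict[str, float]:
    """Average each metric across the per-dataloader result dicts."""
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for per_loader in results:
        for name, value in per_loader.items():
            totals[name] = totals.get(name, 0.0) + float(value)
            counts[name] = counts.get(name, 0) + 1
    return {name: totals[name] / counts[name] for name in totals}


# Illustrative values only; in practice this list would come from validate().
example_results = [{"valid/loss": 0.5}, {"valid/loss": 1.5}]
summary = summarize_validation(example_results)
```

An unweighted mean treats every dataloader equally regardless of size; weighting by sample count would be a reasonable alternative when the validation sets differ greatly in length.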
