espnet3.components.modeling.lightning_module.ESPnetLightningModule
class espnet3.components.modeling.lightning_module.ESPnetLightningModule(model, config)
Bases: LightningModule
ESPnet3 LightningModule wrapper for model training and data integration.
This wrapper keeps the common ESPnet3 model contract unchanged:
```python
loss, stats, weight = model(**batch)
```
Most models should continue to return a single scalar loss tensor. The training loop then behaves exactly like conventional Lightning single-optimizer training.
When multiple optimizers are configured, the same return value is expected, but the loss field must carry optimizer routing information through OptimizationStep. The model still returns one stats dict and one optional weight value; only the type of loss changes.
Example
Single optimizer path:
```python
def forward(self, **batch):
    loss = ...
    stats = {"loss": loss.detach(), "acc": acc.detach()}
    weight = torch.tensor(batch_size, device=loss.device)
    return loss, stats, weight
```

GAN-style path updating both optimizers in a single batch:

```python
def forward(self, **batch):
    g_loss = ...
    d_loss = ...
    stats = {
        "generator_loss": g_loss.detach(),
        "discriminator_loss": d_loss.detach(),
    }
    return [
        OptimizationStep(loss=g_loss, name="generator"),
        OptimizationStep(loss=d_loss, name="discriminator"),
    ], stats, None
```

GAN-style path updating only the generator for one batch:

```python
def forward(self, **batch):
    g_loss = ...
    stats = {"generator_loss": g_loss.detach()}
    return OptimizationStep(loss=g_loss, name="generator"), stats, None
```

Notes
- Returning OptimizationStep when only a single optimizer is configured is forbidden.
- In the multi-optimizer path, only optimizers named by returned OptimizationStep objects are touched for that batch. Optimizers omitted from the list are left untouched entirely.
- The order of OptimizationStep entries is the exact backward/step order.
- NaN or Inf in any returned loss causes the whole batch to be skipped on all workers.
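The notes above can be summarized in a small sketch. This is a hypothetical rendering of the OptimizationStep contract and the NaN/Inf skip rule, not ESPnet3's actual implementation; the real `loss` is a torch.Tensor, simplified to a plain float here for illustration.

```python
import math
from dataclasses import dataclass


@dataclass
class OptimizationStep:
    loss: float  # a torch.Tensor in ESPnet3; a plain float here for illustration
    name: str    # routing key matching a configured optimizer name


def batch_is_finite(steps):
    """Mirror the skip rule above: any NaN/Inf loss skips the whole batch."""
    return all(math.isfinite(step.loss) for step in steps)
```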
Initialize the ESPnet LightningModule wrapper.
collect_stats()
Collect training and validation statistics using ESPnet’s collect_stats.
Requires config.stats_dir to be defined. Saves stats under this directory.
- Raises: AssertionError – If config.stats_dir is not provided.
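The precondition above can be sketched as a small guard. The helper name `require_stats_dir` is hypothetical, used only to illustrate the documented assertion.

```python
def require_stats_dir(config):
    """Mirror the documented precondition: config.stats_dir must be set."""
    assert getattr(config, "stats_dir", None), "config.stats_dir must be set"
    return config.stats_dir
```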
configure_optimizers()
Configure single-optimizer-path or named multi-path optimizers.
This includes the paired scheduler configuration.
Single-optimizer-path training keeps the traditional ESPnet contract:
```yaml
optimizer:
  _target_: torch.optim.Adam
  lr: 0.001
scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 10
  gamma: 0.5
scheduler_interval: step
```

The model keeps returning a plain tensor loss:
```python
def forward(self, **batch):
    loss = ...
    stats = {"loss": loss.detach(), "acc": acc.detach()}
    weight = torch.tensor(batch_size, device=loss.device)
    return loss, stats, weight
```

Single-optimizer-path schedulers follow standard Lightning behavior. Example with a validation-monitored ReduceLROnPlateau:
```yaml
optimizer:
  _target_: torch.optim.Adam
  lr: 0.001
scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  patience: 2
  factor: 0.5
scheduler_interval: epoch
scheduler_monitor: valid/loss
```

In this case Lightning receives:
- interval="epoch"
- monitor="valid/loss"
so it steps the scheduler after validation using the logged valid/loss metric.
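The dictionary shape Lightning receives for the ReduceLROnPlateau example above can be sketched as follows. This is a hedged sketch, not ESPnet3's actual configure_optimizers code; the helper name `scheduler_config` is illustrative.

```python
def scheduler_config(optimizer, scheduler, interval="epoch", monitor=None):
    """Build the Lightning-style return value of configure_optimizers()."""
    lr_scheduler = {"scheduler": scheduler, "interval": interval}
    if monitor is not None:
        lr_scheduler["monitor"] = monitor  # required for metric-driven schedulers
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
```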
Multiple-path training is enabled by configuring named optimizers and schedulers. The names become the routing keys used by OptimizationStep(name=…).
```yaml
optimizers:
  generator:
    optimizer:
      _target_: torch.optim.Adam
      lr: 0.0002
    params: generator
    accum_grad_steps: 1
    step_every_n_iters: 1
    gradient_clip_val: 1.0
    gradient_clip_algorithm: norm
  discriminator:
    optimizer:
      _target_: torch.optim.Adam
      lr: 0.0002
    params: discriminator
    accum_grad_steps: 1
    step_every_n_iters: 1
schedulers:
  generator:
    scheduler:
      _target_: torch.optim.lr_scheduler.LinearLR
      start_factor: 1.0
      end_factor: 0.5
      total_iters: 1000
    interval: step
  discriminator:
    scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      patience: 2
      factor: 0.5
    interval: epoch
    monitor: valid/discriminator/loss
```

The matching model return may update both branches in one batch:
```python
return [
    OptimizationStep(loss=g_loss, name="generator"),
    OptimizationStep(loss=d_loss, name="discriminator"),
], {
    "generator_loss": g_loss.detach(),
    "discriminator_loss": d_loss.detach(),
}, None
```

Or update only one branch:
```python
return OptimizationStep(loss=g_loss, name="generator"), {
    "generator_loss": g_loss.detach(),
}, None
```

Important rules:
- Single-optimizer-path training must return a tensor loss directly.
- Multiple-path training must return OptimizationStep or list[OptimizationStep] as loss so that ESPnet3 knows which optimizer should be used to update parameters.
- Optimizer and scheduler names must match exactly. Valid:
  ```yaml
  optimizers: {generator: {...}, discriminator: {...}}
  schedulers: {generator: {...}, discriminator: {...}}
  ```
  Error example:
  ```yaml
  optimizers: {generator: {...}, discriminator: {...}}
  schedulers: {generator: {...}, decoder: {...}}
  ```
- In the multiple-path configuration, gradient clipping is configured per optimizer via gradient_clip_val and gradient_clip_algorithm. Trainer-level global clipping settings must not be used.
- monitor names must refer to logged metric keys of the form train/<stats-key> or valid/<stats-key>, including automatically logged multi-path losses such as valid/generator/loss and valid/discriminator/loss.
- monitor is only used with epoch-based schedulers. Step-based schedulers are always called as scheduler.step() after a successful optimizer update and do not receive metric inputs.
- DeepSpeed is rejected in the multi-path configuration because Lightning’s DeepSpeed strategy does not support multiple optimizers/schedulers.
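The name-matching rule above can be sketched as a validation helper. The function name `check_scheduler_names` is hypothetical; it only illustrates the documented constraint that every scheduler key must name a configured optimizer.

```python
def check_scheduler_names(optimizers, schedulers):
    """Raise if any scheduler key does not match a configured optimizer name."""
    unknown = set(schedulers) - set(optimizers)
    if unknown:
        raise ValueError(
            f"schedulers reference unknown optimizers: {sorted(unknown)}"
        )
```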
load_state_dict(state_dict, strict=True)
Load state dict into the model.
on_load_checkpoint(checkpoint: Dict[str, object]) → None
Restore custom per-optimizer runtime state from checkpoints.
on_save_checkpoint(checkpoint: Dict[str, object]) → None
Persist custom per-optimizer runtime state in checkpoints.
Lightning already saves and restores the instantiated optimizer and scheduler state_dict() objects, so this hook only stores the extra runtime state introduced by ESPnet3’s named multi-optimizer path:
- accum_counter
- update_step
Scheduler state is not saved here because Lightning’s checkpoint already contains each scheduler’s internal state. Additional scheduler fields only need to be added here if ESPnet3 introduces custom scheduler-side runtime state that Lightning does not know about.
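The save/load round-trip for this extra runtime state can be sketched as below. The class and checkpoint key names are illustrative assumptions, not ESPnet3's actual identifiers; only the two documented fields (accum_counter, update_step) are taken from the text above.

```python
class MultiOptimState:
    """Illustrative holder for the per-optimizer runtime state described above."""

    def __init__(self):
        self.accum_counter = {}  # per-optimizer gradient-accumulation count
        self.update_step = {}    # per-optimizer completed update count

    def on_save_checkpoint(self, checkpoint):
        # Store only the extra state; Lightning already saves optimizer/scheduler
        # state_dict() objects itself.
        checkpoint["espnet3_multi_optim"] = {
            "accum_counter": dict(self.accum_counter),
            "update_step": dict(self.update_step),
        }

    def on_load_checkpoint(self, checkpoint):
        state = checkpoint.get("espnet3_multi_optim", {})
        self.accum_counter = dict(state.get("accum_counter", {}))
        self.update_step = dict(state.get("update_step", {}))
```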
on_train_epoch_end() → None
Step epoch-based schedulers after metrics have been aggregated.
This hook primarily exists for the multiple-loss / multiple-optimizer path, where ESPnet3 owns the optimizer and scheduler orchestration. The single-optimizer path keeps Lightning automatic optimization enabled, so Lightning handles epoch-end scheduler stepping there.
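The epoch-end stepping logic described above can be sketched as follows. This is not the actual ESPnet3 implementation; the helper name and the dict layout mirror the scheduler configuration shown earlier (interval, monitor) as an assumption.

```python
def step_epoch_schedulers(schedulers, logged_metrics):
    """Step only epoch-interval schedulers, passing a metric when monitored."""
    for name, cfg in schedulers.items():
        if cfg.get("interval") != "epoch":
            continue  # step-based schedulers are stepped after optimizer updates
        monitor = cfg.get("monitor")
        if monitor is not None:
            cfg["scheduler"].step(logged_metrics[monitor])  # e.g. ReduceLROnPlateau
        else:
            cfg["scheduler"].step()
```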
state_dict(*args, **kwargs)
Return the state dict of the model.
train_dataloader()
Build the training DataLoader using ESPnet’s DataLoaderBuilder.
- Returns: The training DataLoader.
- Return type: DataLoader
training_step(batch, batch_idx)
Training step logic.
val_dataloader()
Build the validation DataLoader using ESPnet’s DataLoaderBuilder.
- Returns: The validation DataLoader.
- Return type: DataLoader
validation_step(batch, batch_idx)
Run the validation step logic.
