espnet3.components.modeling.lightning_module.ESPnetLightningModule
class espnet3.components.modeling.lightning_module.ESPnetLightningModule(model, config)
Bases: LightningModule
ESPnet3 LightningModule wrapper for model training and data integration.
The common ESPnet3 model contract is shown below.
loss, stats, weight = model(**batch)

Most models should continue to return a single scalar loss tensor. The training loop then behaves exactly like conventional Lightning single-optimizer training.
When multiple optimizers are configured, the same return value is expected, but the loss field must carry optimizer routing information through OptimizationStep. The model still returns one stats dict and one optional weight value; only the type of loss changes.
Example
Single optimizer path.

.. code-block:: python

   def forward(self, **batch):
       loss = ...
       stats = {"loss": loss.detach(), "acc": acc.detach()}
       weight = torch.tensor(batch_size, device=loss.device)
       return loss, stats, weight
GAN-style path updating both optimizers in a single batch.

.. code-block:: python

   def forward(self, **batch):
       g_loss = ...
       d_loss = ...
       stats = {
           "generator_loss": g_loss.detach(),
           "discriminator_loss": d_loss.detach(),
       }
       return [
           OptimizationStep(loss=g_loss, name="generator"),
           OptimizationStep(loss=d_loss, name="discriminator"),
       ], stats, None
GAN-style path updating only the generator for one batch.

.. code-block:: python

   def forward(self, **batch):
       g_loss = ...
       stats = {"generator_loss": g_loss.detach()}
       return OptimizationStep(loss=g_loss, name="generator"), stats, None
Notes
- Returning an OptimizationStep in the single-optimizer path is forbidden.
- In the multi-optimizer path, only optimizers named by returned OptimizationStep objects are touched for that batch. Optimizers omitted from the list are left untouched entirely.
- The order of OptimizationStep entries is the exact backward/step order.
- NaN or Inf in any returned loss causes the whole batch to be skipped on all workers.
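These Notes can be sketched in a few lines. OptimizationStep's real definition lives in ESPnet3; the dataclass and `route_steps` helper below are illustrative stand-ins (a plain float loss instead of a tensor, and a bare `step()` call instead of the full backward/clip/step sequence) showing the ordered stepping, the untouched omitted optimizers, and the NaN/Inf batch skip.

```python
import math
from dataclasses import dataclass


@dataclass
class OptimizationStep:
    # Assumed stand-in for ESPnet3's OptimizationStep; only the two
    # fields used in this document are modeled here.
    loss: float  # the real class carries a torch.Tensor
    name: str    # routing key matching a configured optimizer


def route_steps(steps, optimizers):
    """Apply steps in the exact order returned by the model.

    `optimizers` maps name -> object with a step() method. Optimizers
    not named by any OptimizationStep are left untouched; a non-finite
    loss skips the whole batch, mirroring the Notes above.
    """
    if any(not math.isfinite(s.loss) for s in steps):
        return []  # whole batch skipped
    applied = []
    for s in steps:
        optimizers[s.name].step()
        applied.append(s.name)
    return applied
```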
Initialize the ESPnet LightningModule wrapper.
collect_stats()
Collect training and validation statistics using ESPnet’s collect_stats.
Requires config.stats_dir to be defined. Saves stats under this directory.
- Raises:AssertionError – If config.stats_dir is not provided.
configure_optimizers()
Configure single-optimizer-path or named multi-path optimizers.
This includes the paired scheduler configuration.
Single-optimizer-path training.
optimizer:
_target_: torch.optim.Adam
lr: 0.001
scheduler:
_target_: torch.optim.lr_scheduler.StepLR
step_size: 10
gamma: 0.5
scheduler_interval: step

Matching model return.
def forward(self, **batch):
    loss = ...
    stats = {"loss": loss.detach(), "acc": acc.detach()}
    weight = torch.tensor(batch_size, device=loss.device)
    return loss, stats, weight

Single-optimizer-path schedulers follow standard Lightning behavior. Validation-monitored `ReduceLROnPlateau` example.
optimizer:
_target_: torch.optim.Adam
lr: 0.001
scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
patience: 2
factor: 0.5
scheduler_interval: epoch
scheduler_monitor: valid/loss

Lightning receives:
- interval="epoch"
- monitor="valid/loss"
so it steps the scheduler after validation using the logged valid/loss metric.
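As a sketch of the standard Lightning return shape that produces this behavior (the helper name is illustrative; `optimizer` and `scheduler` stand in for the objects instantiated from the config above):

```python
def make_lightning_config(optimizer, scheduler):
    # Standard Lightning return shape from configure_optimizers() for a
    # validation-monitored, epoch-stepped scheduler such as
    # ReduceLROnPlateau.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",      # from scheduler_interval: epoch
            "monitor": "valid/loss",  # from scheduler_monitor: valid/loss
        },
    }
```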
Multiple-path training is enabled by configuring named optimizers and schedulers. The names become the routing keys used by OptimizationStep(name=…).
optimizers:
generator:
optimizer:
_target_: torch.optim.Adam
lr: 0.0002
params: generator
accum_grad_steps: 1
step_every_n_iters: 1
gradient_clip_val: 1.0
gradient_clip_algorithm: norm
discriminator:
optimizer:
_target_: torch.optim.Adam
lr: 0.0002
params: discriminator
accum_grad_steps: 1
step_every_n_iters: 1
schedulers:
generator:
scheduler:
_target_: torch.optim.lr_scheduler.LinearLR
start_factor: 1.0
end_factor: 0.5
total_iters: 1000
interval: step
discriminator:
scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
patience: 2
factor: 0.5
interval: epoch
monitor: valid/discriminator/loss

Update both branches in one batch.
return [
OptimizationStep(loss=g_loss, name="generator"),
OptimizationStep(loss=d_loss, name="discriminator"),
], {
"generator_loss": g_loss.detach(),
"discriminator_loss": d_loss.detach(),
}, None

Update only one branch.
return OptimizationStep(loss=g_loss, name="generator"), {
"generator_loss": g_loss.detach(),
}, None

Important rules.
- Single-optimizer-path training must return a tensor loss directly.
- Multiple-path training must return OptimizationStep or list[OptimizationStep] as loss so that ESPnet3 knows which optimizer should be used to update parameters.
Optimizer and scheduler names must match exactly. Valid.
optimizers: {generator: {...}, discriminator: {...}}
schedulers: {generator: {...}, discriminator: {...}}

Error example.
optimizers: {generator: {...}, discriminator: {...}}
schedulers: {generator: {...}, decoder: {...}}

In the multiple-path configuration, gradient clipping is configured per optimizer via gradient_clip_val and gradient_clip_algorithm. Trainer-level global clipping settings must not be used.
monitor names must refer to logged metric keys of the form train/<stats-key> or valid/<stats-key>, including automatically logged multi-path losses such as valid/generator/loss and valid/discriminator/loss.
monitor is only used with epoch-based schedulers. Step-based schedulers are always called as scheduler.step() after a successful optimizer update and do not receive metric inputs.
DeepSpeed is rejected in the multi-path configuration because Lightning’s DeepSpeed strategy does not support multiple optimizers/schedulers.
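The name-matching rule above can be expressed as a small validation check (the helper name is hypothetical, not ESPnet3's actual API):

```python
def check_named_configs(optimizers, schedulers):
    """Ensure every scheduler name pairs with an identically named optimizer.

    Mirrors the rule above: `optimizers` and `schedulers` are dicts keyed
    by routing name, e.g. {"generator": ..., "discriminator": ...}.
    """
    unknown = set(schedulers) - set(optimizers)
    if unknown:
        raise ValueError(
            f"schedulers without a matching optimizer: {sorted(unknown)}"
        )
```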
load_state_dict(state_dict, strict=True)
Load state dict into the model.
on_load_checkpoint(checkpoint: Dict[str, object]) → None
Restore custom per-optimizer runtime state from checkpoints.
on_save_checkpoint(checkpoint: Dict[str, object]) → None
Persist custom per-optimizer runtime state in checkpoints.
Lightning already saves and restores the instantiated optimizer and scheduler state_dict() objects, so this hook only stores the extra runtime state introduced by ESPnet3’s named multi-optimizer path.
- accum_counter
- update_step
Scheduler state is not saved here because Lightning’s checkpoint already contains each scheduler’s internal state. Additional scheduler fields only need to be added here if ESPnet3 introduces custom scheduler-side runtime state that Lightning does not know about.
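A minimal sketch of what these hooks persist, assuming the counters are keyed by optimizer name (the checkpoint key and helper names below are illustrative, not ESPnet3's actual checkpoint schema):

```python
def save_runtime_state(checkpoint, accum_counter, update_step):
    # Persist ESPnet3's extra per-optimizer counters. Optimizer and
    # scheduler state_dicts are already handled by Lightning itself.
    checkpoint["espnet3_optim_state"] = {
        "accum_counter": dict(accum_counter),  # name -> pending grad-accum count
        "update_step": dict(update_step),      # name -> completed update count
    }


def load_runtime_state(checkpoint):
    # Restore the counters, tolerating checkpoints written before this
    # state existed.
    state = checkpoint.get("espnet3_optim_state", {})
    return state.get("accum_counter", {}), state.get("update_step", {})
```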
on_train_epoch_end() → None
Step epoch-based schedulers after metrics have been aggregated.
This hook primarily exists for the multiple-loss / multiple-optimizer path, where ESPnet3 owns the optimizer and scheduler orchestration. The single-optimizer path keeps Lightning automatic optimization enabled, so Lightning handles epoch-end scheduler stepping there.
state_dict(*args, **kwargs)
Return the state dict of the model.
train_dataloader()
Build the training DataLoader using ESPnet’s DataLoaderBuilder.
- Returns: The training DataLoader.
- Return type: DataLoader
training_step(batch, batch_idx)
Training step logic.
val_dataloader()
Build the validation DataLoader using ESPnet’s DataLoaderBuilder.
- Returns: The validation DataLoader.
- Return type: DataLoader
validation_step(batch, batch_idx)
Run the validation step logic.
