espnet3.components.callbacks.AverageCheckpointsCallback
class espnet3.components.callbacks.AverageCheckpointsCallback(output_dir, best_ckpt_callbacks)
Bases: Callback
A custom PyTorch Lightning callback that performs weight averaging over the top-K checkpoints (according to specified metrics) at the end of training.
This can be useful to smooth out fluctuations in weights across the best-performing models and can lead to improved generalization performance at inference time.
Behavior:
- Loads the state_dict from each of the top-K checkpoints saved by the given ModelCheckpoint callbacks.
- Averages the floating-point model parameters (keys starting with model.).
- Accumulates integer-type parameters (e.g., BatchNorm's num_batches_tracked) rather than averaging them.
- Saves the averaged model as a .pth file in output_dir.
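The behavior above can be illustrated with a minimal, framework-free sketch of the averaging logic (plain Python floats and ints stand in for tensors, and each "checkpoint" is a dict mimicking a saved state_dict; the helper name average_checkpoints is illustrative, not part of the API):

```python
# Sketch of top-K checkpoint averaging: average float parameters whose
# keys start with "model.", accumulate integer parameters (e.g.
# num_batches_tracked) without dividing, and skip all other keys.

def average_checkpoints(state_dicts):
    avg = {}
    n = len(state_dicts)
    for sd in state_dicts:
        for key, value in sd.items():
            if not key.startswith("model."):
                continue  # non-model keys (optimizer state, etc.) are ignored
            avg[key] = avg.get(key, 0) + value
    for key, total in avg.items():
        if isinstance(total, int):
            continue  # integer-type params stay accumulated, not averaged
        avg[key] = total / n
    return avg

ckpts = [
    {"model.w": 1.0, "model.num_batches_tracked": 10, "optimizer.lr": 0.1},
    {"model.w": 3.0, "model.num_batches_tracked": 12, "optimizer.lr": 0.1},
]
print(average_checkpoints(ckpts))
# → {'model.w': 2.0, 'model.num_batches_tracked': 22}
```

In the real callback the values are torch tensors loaded from the checkpoints' state_dicts, but the key filtering and integer handling follow the same pattern.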
Parameters:
- output_dir (str or Path) – The directory where the averaged model will be saved.
- best_ckpt_callbacks (List[ModelCheckpoint]) – A list of ModelCheckpoint callbacks whose top-K checkpoints will be used for averaging. Each callback must have best_k_models populated.
Notes
- Only keys that start with model. are included in the averaging.
- The final filename will be {monitor_name}.ave_{K}best.pth.
- This callback runs only on the global rank 0 process (for distributed training).
Example
>>> avg_ckpt_cb = AverageCheckpointsCallback(
... output_dir="checkpoints/",
... best_ckpt_callbacks=[val_loss_ckpt_cb, acc_ckpt_cb]
... )
>>> trainer = Trainer(callbacks=[avg_ckpt_cb])

on_fit_end(trainer, pl_module)
Called when fit ends. Performs the checkpoint averaging on the global rank 0 process and saves the averaged model to output_dir.
