espnet3.components.callbacks.default_callbacks.AverageCheckpointsCallback
class espnet3.components.callbacks.default_callbacks.AverageCheckpointsCallback(output_dir, best_ckpt_callbacks)
Bases: Callback
A custom callback for weight averaging over the top-K checkpoints.
Averaging can smooth out weight fluctuations across the best-performing checkpoints and may improve generalization at inference time.
Behavior:
- Loads the state_dict from each of the top-K checkpoints saved by the given ModelCheckpoint callbacks.
- Averages the model parameters (keys starting with "model.").
- Does not average integer-type parameters (e.g., BatchNorm's num_batches_tracked); these are ignored or simply accumulated.
- Saves the averaged model as a .pth file in output_dir.
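The averaging steps above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not ESPnet's implementation: each checkpoint's state_dict is modeled as a dict of Python scalars rather than torch tensors, and the function name average_state_dicts is hypothetical. Float entries under the "model." prefix are summed and divided by the number of checkpoints; int entries (standing in for counters such as num_batches_tracked) are only summed, which is the "accumulate" option from the list above.

```python
from typing import Dict, List, Union

# Hypothetical simplification: real state_dicts hold torch tensors, not scalars.
Number = Union[int, float]


def average_state_dicts(
    state_dicts: List[Dict[str, Number]], prefix: str = "model."
) -> Dict[str, Number]:
    """Average entries whose keys start with `prefix` over all state_dicts."""
    if not state_dicts:
        raise ValueError("need at least one state_dict to average")
    n = len(state_dicts)
    totals: Dict[str, Number] = {}
    for sd in state_dicts:
        for key, value in sd.items():
            if not key.startswith(prefix):
                continue  # entries outside the model prefix are skipped
            totals[key] = totals.get(key, 0) + value
    averaged: Dict[str, Number] = {}
    for key, total in totals.items():
        if isinstance(total, int):
            # Integer counters (e.g. num_batches_tracked) are accumulated, not averaged.
            averaged[key] = total
        else:
            averaged[key] = total / n
    return averaged


sd1 = {"model.w": 1.0, "model.bn.num_batches_tracked": 10, "optimizer.lr": 0.1}
sd2 = {"model.w": 3.0, "model.bn.num_batches_tracked": 20, "optimizer.lr": 0.2}
print(average_state_dicts([sd1, sd2]))
# → {'model.w': 2.0, 'model.bn.num_batches_tracked': 30}
```

With real checkpoints the same arithmetic would apply elementwise to tensors, and the integer check would look at each tensor's dtype rather than its Python type.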
Parameters:
- output_dir (str or Path) – The directory where the averaged model will be saved.
- best_ckpt_callbacks (List[ModelCheckpoint]) – A list of ModelCheckpoint callbacks whose top-K checkpoints will be used for averaging. Each callback must have best_k_models populated.
Notes
- Only keys that start with "model." are included in the averaging.
- The final filename will be: {monitor_name}.ave_{K}best.pth
- This callback only runs on the global rank 0 process (for distributed training).
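The naming and rank-0 notes can be illustrated with a second sketch. CheckpointCallbackStub is a hypothetical stand-in that models only two attributes of Lightning's ModelCheckpoint: monitor and best_k_models (a mapping from checkpoint path to score). plan_averaging is also hypothetical; it only decides which checkpoints would be averaged and where the result would be written. Skipping callbacks that have not saved any checkpoint yet is an assumption of this sketch, not documented behavior.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Tuple


@dataclass
class CheckpointCallbackStub:
    """Hypothetical stand-in for a ModelCheckpoint callback (two attributes only)."""

    monitor: str
    best_k_models: Dict[str, float] = field(default_factory=dict)


def plan_averaging(
    callbacks: List[CheckpointCallbackStub], output_dir: Path, global_rank: int
) -> List[Tuple[List[str], Path]]:
    """Return one (checkpoint_paths, output_path) pair per callback to average."""
    if global_rank != 0:
        return []  # only the global rank 0 process averages and writes files
    plan: List[Tuple[List[str], Path]] = []
    for cb in callbacks:
        ckpt_paths = list(cb.best_k_models)  # paths of the top-K checkpoints
        k = len(ckpt_paths)
        if k == 0:
            continue  # assumption: nothing to average before any checkpoint is saved
        out_path = output_dir / f"{cb.monitor}.ave_{k}best.pth"
        plan.append((ckpt_paths, out_path))
    return plan


cbs = [
    CheckpointCallbackStub("val_loss", {"exp/epoch=3.ckpt": 0.41, "exp/epoch=5.ckpt": 0.39}),
    CheckpointCallbackStub("acc"),
]
for paths, out in plan_averaging(cbs, Path("exp"), global_rank=0):
    print(out, paths)
```

Because K is taken from the number of entries in best_k_models, the filename reflects how many checkpoints were actually averaged, which can be fewer than save_top_k early in training.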
Example
>>> avg_ckpt_cb = AverageCheckpointsCallback(
... output_dir="checkpoints/",
... best_ckpt_callbacks=[val_loss_ckpt_cb, acc_ckpt_cb]
... )
>>> trainer = Trainer(callbacks=[val_loss_ckpt_cb, acc_ckpt_cb, avg_ckpt_cb])

Initialize AverageCheckpointsCallback object.
on_validation_end(trainer, pl_module)
At the end of validation, average the top-K checkpoints and save the averaged model.
