espnet3.components.callbacks.default_callbacks.get_default_callbacks
espnet3.components.callbacks.default_callbacks.get_default_callbacks(exp_dir: str = './exp', log_interval: int = 500, best_model_criterion: List[Tuple[str, int, str]] | List[List] = [('valid/loss', 3, 'min')]) → List[Callback]
Return a list of default callbacks suitable for most training workflows.
Includes:
- ModelCheckpoint for saving the last model checkpoint (save_last)
- One or more ModelCheckpoint instances for saving the top-K checkpoints according to specific metrics
- AverageCheckpointsCallback to compute and save the average model from the top-K checkpoints
- LearningRateMonitor to track and log learning rates during training
- TQDMProgressBar to show a tqdm-based progress bar during training
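AverageCheckpointsCallback is ESPnet-specific, and its internals are not described on this page. Assuming it follows the usual checkpoint-averaging approach (an element-wise arithmetic mean of each parameter across the top-K checkpoints), the idea can be sketched in plain Python. The helper `average_state_dicts` and the toy state dicts are hypothetical; the real callback operates on PyTorch tensors loaded from checkpoint files.

```python
from typing import Dict, List

# Hypothetical stand-in for a model state dict: parameter name -> flat list of floats.
StateDict = Dict[str, List[float]]


def average_state_dicts(state_dicts: List[StateDict]) -> StateDict:
    """Hypothetical sketch: element-wise mean of parameters across checkpoints."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint to average")
    n = len(state_dicts)
    averaged: StateDict = {}
    for name in state_dicts[0]:
        # Sum the same parameter across all checkpoints, then divide by the count.
        summed = [0.0] * len(state_dicts[0][name])
        for sd in state_dicts:
            for i, value in enumerate(sd[name]):
                summed[i] += value
        averaged[name] = [v / n for v in summed]
    return averaged


# Three toy "top-K" checkpoints of a tiny linear layer.
top_k_checkpoints = [
    {"linear.weight": [1.0, 2.0], "linear.bias": [0.0]},
    {"linear.weight": [3.0, 4.0], "linear.bias": [1.0]},
    {"linear.weight": [5.0, 6.0], "linear.bias": [2.0]},
]
print(average_state_dicts(top_k_checkpoints))
# {'linear.weight': [3.0, 4.0], 'linear.bias': [1.0]}
```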
- Parameters:
exp_dir (str) – Directory to store checkpoints and logs.
log_interval (int) – Frequency (in training steps) to refresh the progress bar.
best_model_criterion (List[Tuple[str, int, str]] | List[List]) –
A list of criteria for saving top-K checkpoints. Each item is a tuple (name, top_k, mode), where:
- name (str): The name of the validation metric to monitor (e.g., "valid/loss").
- top_k (int): Number of best models to keep.
- mode (str): “min” to keep models with lowest value, “max” for highest.
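To make the (name, top_k, mode) semantics concrete, here is a minimal sketch of the selection rule: given a history of (epoch, value) pairs for one monitored metric, keep the `top_k` best under `mode`. The function `keep_top_k` and the sample values are hypothetical; Lightning's ModelCheckpoint applies this rule incrementally during training and deletes checkpoint files that fall out of the top K.

```python
import heapq
from typing import List, Tuple


def keep_top_k(
    scores: List[Tuple[int, float]], top_k: int, mode: str
) -> List[Tuple[int, float]]:
    """Hypothetical sketch: return the top_k (epoch, value) pairs, best first."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if mode == "min":
        # Lowest values are best (e.g. a loss); result is sorted ascending.
        return heapq.nsmallest(top_k, scores, key=lambda item: item[1])
    # Highest values are best (e.g. an accuracy); result is sorted descending.
    return heapq.nlargest(top_k, scores, key=lambda item: item[1])


val_loss = [(1, 2.3), (2, 1.9), (3, 2.0), (4, 1.7), (5, 1.8)]
print(keep_top_k(val_loss, 3, "min"))  # [(4, 1.7), (5, 1.8), (2, 1.9)]

val_acc = [(1, 0.61), (2, 0.70), (3, 0.68)]
print(keep_top_k(val_acc, 2, "max"))  # [(2, 0.7), (3, 0.68)]
```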
- Returns: A list of callbacks to be passed to the PyTorch Lightning Trainer.
- Return type: List[Callback]
Example
>>> from lightning.pytorch import Trainer
>>> from espnet3.components.callbacks.default_callbacks import get_default_callbacks
>>> callbacks = get_default_callbacks(
...     exp_dir="./exp",
...     log_interval=100,
...     best_model_criterion=[("valid/loss", 5, "min"), ("valid/acc", 3, "max")],
... )
>>> trainer = Trainer(callbacks=callbacks)  # add other Trainer arguments as needed
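The type hint also accepts List[List], presumably so that criteria loaded from a YAML config (where tuples arrive as lists) can be passed directly. How the function handles such entries internally is not documented here; the hypothetical `normalize_criteria` below sketches one way to coerce list entries into (name, top_k, mode) tuples and reject an invalid mode before building checkpoint callbacks.

```python
from typing import List, Sequence, Tuple

# A criterion is (metric name, number of checkpoints to keep, "min" or "max").
Criterion = Tuple[str, int, str]


def normalize_criteria(criteria: Sequence[Sequence]) -> List[Criterion]:
    """Hypothetical helper: coerce list- or tuple-style criteria into tuples."""
    normalized: List[Criterion] = []
    for item in criteria:
        if len(item) != 3:
            raise ValueError(f"expected (name, top_k, mode), got {item!r}")
        name, top_k, mode = item
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        normalized.append((str(name), int(top_k), str(mode)))
    return normalized


# Lists, as they would come out of a parsed YAML config.
from_yaml = [["valid/loss", 3, "min"], ["valid/acc", 2, "max"]]
print(normalize_criteria(from_yaml))
# [('valid/loss', 3, 'min'), ('valid/acc', 2, 'max')]
```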