espnet2.train.reporter.Reporter
class espnet2.train.reporter.Reporter(epoch: int = 0)
Bases: object
Reporter class.
This class collects training and evaluation statistics during model training and aggregates them per epoch and per key (e.g., ‘train’, ‘eval’). It also provides helpers for text logging, TensorBoard and Weights & Biases logging, Matplotlib plotting, model selection, and early stopping.
epoch
The current epoch number.
- Type: int
stats
A nested dictionary of per-epoch statistics, indexed as stats[epoch][key][key2] (e.g., stats[1]['train']['loss']).
- Type: dict
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     for batch in iterator:
...         stats = dict(loss=0.2)
...         sub_reporter.register(stats)
...         sub_reporter.next()  # close the current step before the next batch
get_epoch() -> int: Returns the current epoch number.
set_epoch(epoch: int) -> None: Sets the current epoch number.
observe(key: str, epoch: Optional[int] = None) -> ContextManager[SubReporter]: Context manager to observe metrics for a specific key.
start_epoch(key: str, epoch: Optional[int] = None) -> SubReporter: Initializes a new SubReporter for the specified key and epoch.
finish_epoch(sub_reporter: SubReporter) -> None: Finalizes and records the statistics collected in the given SubReporter.
sort_epochs_and_values(key: str, key2: str, mode: str) -> List[Tuple[int, float]]: Returns a list of epochs and their corresponding values sorted by mode.
check_early_stopping(patience: int, key1: str, key2: str, mode: str, epoch: Optional[int] = None, logger: Optional[logging.Logger] = None) -> bool: Checks if early stopping criteria are met based on the specified key.
has(key: str, key2: str, epoch: Optional[int] = None) -> bool: Checks if the specified keys exist in the statistics for the given epoch.
log_message(epoch: Optional[int] = None) -> str: Generates a formatted log message for the specified epoch.
get_value(key: str, key2: str, epoch: Optional[int] = None): Retrieves the value for the specified keys and epoch.
get_keys(epoch: Optional[int] = None) -> Tuple[str, ...]: Returns the first-level keys (e.g., ‘train’, ‘eval’) for the specified epoch.
get_keys2(key: str, epoch: Optional[int] = None) -> Tuple[str, ...]: Returns the second-level keys (e.g., ‘loss’, ‘acc’) for the specified key and epoch.
get_all_keys(epoch: Optional[int] = None) -> Tuple[Tuple[str, str], ...]: Returns all key pairs for the specified epoch.
matplotlib_plot(output_dir: Union[str, Path]) -> None: Plots statistics using Matplotlib and saves the images to the specified directory.
tensorboard_add_scalar(summary_writer, epoch: Optional[int] = None, key1: Optional[str] = None) -> None: Adds scalar values to TensorBoard for visualization.
wandb_log(epoch: Optional[int] = None) -> None: Logs metrics to Weights & Biases for tracking.
state_dict() -> dict: Returns the state of the Reporter as a dictionary.
load_state_dict(state_dict: dict) -> None: Loads the state of the Reporter from a given dictionary.
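check_early_stopping(patience: int, key1: str, key2: str, mode: str, epoch: Optional[int] = None, logger: Optional[logging.Logger] = None) → bool
Checks whether the early-stopping criterion is met for the metric identified by key1 and key2.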
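The parameters suggest that check_early_stopping compares the best epoch for key1/key2, ranked by mode, with the current epoch. The sketch below assumes the rule "stop once the best epoch lies more than patience epochs behind the current one". The function should_stop_early and its arguments are hypothetical and are not part of ESPnet.

```python
from typing import Dict


def should_stop_early(
    history: Dict[int, float], current_epoch: int, patience: int, mode: str
) -> bool:
    """Hypothetical stand-in for the early-stopping rule (not ESPnet code).

    history maps epoch -> metric value (e.g. eval loss per epoch).
    Assumption: stop once the best epoch is more than `patience` epochs old.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be min or max: {mode}")
    pick = min if mode == "min" else max
    # The best epoch has the smallest ('min') or largest ('max') value.
    best_epoch = pick(history, key=history.__getitem__)
    return current_epoch - best_epoch > patience


# eval loss improves until epoch 2, then stalls
losses = {1: 0.9, 2: 0.5, 3: 0.6, 4: 0.7, 5: 0.8}
print(should_stop_early(losses, current_epoch=4, patience=2, mode="min"))  # False
print(should_stop_early(losses, current_epoch=5, patience=2, mode="min"))  # True
```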
finish_epoch(sub_reporter: SubReporter) → None
Finalizes the statistics for the current epoch and updates the report.
This method calculates the mean of each statistic collected during the epoch and stores the results in the overall stats dictionary. It also records the elapsed time for the epoch (‘time’) and the cumulative iteration count (‘total_count’), and adds GPU memory metrics when applicable.
- Parameters: sub_reporter (SubReporter) – The sub-reporter instance that contains the statistics for the current epoch.
- Raises: RuntimeError – If the epoch of the sub-reporter does not match the current epoch of the reporter.
#### Examples
>>> reporter = Reporter()
>>> sub_reporter = reporter.start_epoch('train')
>>> sub_reporter.register({'loss': 0.5})  # register statistics during training
>>> sub_reporter.next()
>>> reporter.finish_epoch(sub_reporter)
#### NOTE
This method is intended to be called after the observation of the epoch has been completed. It ensures that the statistics are properly aggregated and stored. When the observe() context manager is used, it is called automatically on exit, so call it directly only together with start_epoch().
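The aggregation step can be sketched as follows. The sketch assumes a plain mean that skips non-finite values; ESPnet's own aggregation may also support weighted averages. For simplicity, it stores ‘time’ as elapsed seconds. summarize_epoch is a hypothetical helper, not the library's implementation.

```python
import math
import time
from typing import Dict, List


def summarize_epoch(
    batch_stats: Dict[str, List[float]], start_time: float, total_count: int
) -> Dict[str, float]:
    """Hypothetical sketch of the per-epoch summary that finish_epoch stores.

    Assumption: each metric is reduced to the mean of its finite values.
    """
    summary: Dict[str, float] = {}
    for key2, values in batch_stats.items():
        finite = [v for v in values if math.isfinite(v)]  # skip NaN/inf entries
        summary[key2] = sum(finite) / len(finite) if finite else math.nan
    # Bookkeeping entries stored next to the metrics
    summary["time"] = time.perf_counter() - start_time  # elapsed seconds
    summary["total_count"] = total_count
    return summary


start = time.perf_counter()
epoch_summary = summarize_epoch({"loss": [0.5, 0.25, math.nan]}, start, total_count=3)
print(epoch_summary["loss"])  # 0.375
```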
get_all_keys(epoch: Optional[int] = None) → Tuple[Tuple[str, str], ...]
Returns all (key, key2) pairs recorded for the specified epoch.
The returned keys are tuples containing two elements: the first element corresponds to the main key (e.g., ‘train’, ‘eval’), and the second element corresponds to the specific metric (e.g., ‘loss’, ‘acc’).
- Parameters: epoch (int, optional) – The epoch from which to retrieve the keys. If None, the current epoch is used. Defaults to None.
- Returns: A tuple of tuples containing all key pairs from the specified epoch.
- Return type: Tuple[Tuple[str, str], …]
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.2})
...     sub_reporter.next()
>>> ('train', 'loss') in reporter.get_all_keys()
True
#### NOTE
The result may also include bookkeeping entries recorded by finish_epoch(), such as ('train', 'time'). Statistics must already have been recorded for the specified epoch; otherwise, a KeyError is raised.
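The following is a minimal sketch of the flattening described above. It assumes the stats[epoch][key][key2] layout. all_key_pairs is a hypothetical helper that works on a plain dictionary rather than on a Reporter.

```python
from typing import Dict, Tuple

# Assumed layout: stats[epoch][key][key2] -> value
Stats = Dict[int, Dict[str, Dict[str, float]]]


def all_key_pairs(stats: Stats, epoch: int) -> Tuple[Tuple[str, str], ...]:
    """Hypothetical helper mirroring the pairs get_all_keys() describes."""
    # stats[epoch] raises KeyError if the epoch has no recorded statistics.
    return tuple((key, key2) for key, d in stats[epoch].items() for key2 in d)


stats = {1: {"train": {"loss": 0.2, "acc": 0.9}, "eval": {"loss": 0.3}}}
print(all_key_pairs(stats, 1))
# (('train', 'loss'), ('train', 'acc'), ('eval', 'loss'))
```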
get_best_epoch(key: str, key2: str, mode: str, nbest: int = 0) → int
Retrieve the best epoch based on a specified metric.
This method sorts the epochs based on the specified metric and returns the epoch corresponding to the best value. The best value can be either the minimum or maximum, depending on the specified mode. The nbest argument allows you to retrieve the nth best epoch (e.g., 0 for the best, 1 for the second best, etc.).
- Parameters:
- key – The main key for the metric (e.g., ‘train’, ‘eval’).
- key2 – The sub-key for the specific metric (e.g., ‘loss’, ‘acc’).
- mode – The mode for selecting the best value; it should be either ‘min’ or ‘max’.
- nbest – The index of the best epoch to retrieve. Defaults to 0, which retrieves the best epoch.
- Returns: The epoch number that corresponds to the best value based on the specified criteria.
#### Examples
>>> # Assuming reporter has recorded 'train' and 'eval' statistics for several epochs
>>> best_epoch = reporter.get_best_epoch('eval', 'loss', 'min')
>>> second_best_epoch = reporter.get_best_epoch('train', 'acc', 'max', nbest=1)
- Raises:
- ValueError – If mode is not ‘min’ or ‘max’.
- KeyError – If the specified key or sub-key does not exist in the reported statistics.
get_epoch() → int
Returns the current epoch number.
This method retrieves the epoch number that the reporter is currently in.
- Returns: The current epoch number.
- Return type: int
#### Examples
>>> reporter = Reporter(epoch=5)
>>> current_epoch = reporter.get_epoch()
>>> print(current_epoch)
5
get_keys(epoch: Optional[int] = None) → Tuple[str, ...]
Returns the first-level keys (keys1), e.g., ‘train’ and ‘eval’.
- Parameters: epoch (int, optional) – The epoch number for which to retrieve the keys. If None, the current epoch is used.
- Returns: A tuple of the first-level keys (e.g., ‘train’, ‘eval’) in the stats dictionary for the specified epoch.
- Return type: Tuple[str, …]
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.2})
>>> reporter.get_keys()
('train',)
>>> with reporter.observe('eval') as sub_reporter:
...     sub_reporter.register({'loss': 0.1})
>>> reporter.get_keys()
('train', 'eval')
get_keys2(key: str, epoch: Optional[int] = None) → Tuple[str, ...]
Returns the second-level keys (keys2), e.g., ‘loss’ and ‘acc’.
This method retrieves the second-level keys associated with a specified first-level key (e.g., ‘train’ or ‘eval’) for a given epoch. The keys are filtered to exclude reserved keys like ‘time’ and ‘total_count’.
- Parameters:
- key (str) – The first-level key to retrieve second-level keys from.
- epoch (int, optional) – The epoch number to retrieve keys for. If None, the current epoch is used.
- Returns: A tuple of second-level keys corresponding to the specified first-level key for the specified epoch.
- Return type: Tuple[str, …]
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.2, 'acc': 0.8})
>>> reporter.get_keys2('train')
('loss', 'acc')
- Raises:
- KeyError – If the specified key does not exist in the stats for the specified epoch.
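The reserved-key filtering can be sketched on a plain dictionary that holds one epoch's statistics. RESERVED lists only the two reserved keys named above. metric_keys is a hypothetical helper.

```python
from typing import Dict, Tuple

# The reserved (non-metric) keys named in the description above
RESERVED = ("time", "total_count")


def metric_keys(epoch_stats: Dict[str, Dict[str, float]], key: str) -> Tuple[str, ...]:
    """Hypothetical sketch of get_keys2(): second-level keys minus reserved ones."""
    # epoch_stats[key] raises KeyError if `key` was not observed in this epoch.
    return tuple(k for k in epoch_stats[key] if k not in RESERVED)


epoch_stats = {"train": {"loss": 0.2, "acc": 0.8, "time": 12.5, "total_count": 100}}
print(metric_keys(epoch_stats, "train"))  # ('loss', 'acc')
```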
get_value(key: str, key2: str, epoch: Optional[int] = None)
Retrieves the value associated with a specified key and key2 for a given epoch.
- Parameters:
- key (str) – The primary key representing a specific metric category, such as ‘train’ or ‘eval’.
- key2 (str) – The secondary key representing a specific metric within the category, such as ‘loss’ or ‘accuracy’.
- epoch (int, optional) – The epoch number to retrieve the value from. If not provided, the current epoch is used.
- Returns: The value associated with the specified key and key2 for the given epoch.
- Raises: KeyError – If the specified key or key2 does not exist in the statistics.
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.2})
>>> value = reporter.get_value('train', 'loss')
>>> print(value)  # Output: 0.2
#### NOTE
Ensure that the statistics for the specified key and key2 have been registered before attempting to retrieve the value.
has(key: str, key2: str, epoch: Optional[int] = None) → bool
Check if the specified key and key2 exist in the stats for a given epoch.
This method verifies whether the specified statistics are recorded for the current epoch or a specific epoch. It checks if both the primary key and the secondary key exist in the stats dictionary.
- Parameters:
- key (str) – The primary key, typically representing a phase (e.g., ‘train’).
- key2 (str) – The secondary key, usually representing a specific metric (e.g., ‘loss’).
- epoch (int, optional) – The epoch to check. If None, the current epoch is used. Defaults to None.
- Returns: True if the keys exist in the stats for the specified epoch, otherwise False.
- Return type: bool
#### Examples
>>> reporter = Reporter()
>>> reporter.has('train', 'loss')  # Check for the current epoch
False
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.2})
>>> reporter.has('train', 'loss')  # Check after registration
True
#### NOTE
Missing keys or epochs yield False rather than raising a KeyError, as the first call above shows.
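The membership check can be sketched as a chain of short-circuiting tests over the nested stats[epoch][key][key2] layout. This is why a missing entry produces False rather than an exception. has_stat is a hypothetical helper that works on a plain dictionary.

```python
from typing import Dict

Stats = Dict[int, Dict[str, Dict[str, float]]]


def has_stat(stats: Stats, key: str, key2: str, epoch: int) -> bool:
    """Hypothetical sketch of has(): check each nesting level in turn."""
    # Short-circuiting `and` avoids the KeyError that direct indexing would raise.
    return epoch in stats and key in stats[epoch] and key2 in stats[epoch][key]


stats = {1: {"train": {"loss": 0.2}}}
print(has_stat(stats, "train", "loss", 1))  # True
print(has_stat(stats, "eval", "loss", 1))   # False
print(has_stat(stats, "train", "loss", 2))  # False
```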
load_state_dict(state_dict: dict) → None
Loads the state dictionary into the Reporter instance.
This method updates the internal state of the Reporter instance with the provided state dictionary. It expects the state dictionary to contain the epoch and stats of the Reporter. It is typically used to restore the reporter when resuming training from a checkpoint.
- Parameters: state_dict (dict) – A dictionary containing the state of the Reporter. It should have the following keys:
- "epoch" (int): The current epoch of the Reporter.
- "stats" (dict): The statistics collected during training, indexed by epoch and other relevant keys.
#### Examples
>>> reporter = Reporter()
>>> state = {"epoch": 5, "stats": {0: {"train": {"loss": 0.1}}}}
>>> reporter.load_state_dict(state)
>>> print(reporter.get_epoch()) # Output: 5
>>> print(reporter.stats) # Output: {0: {'train': {'loss': 0.1}}}
log_message(epoch: Optional[int] = None) → str
Generate a formatted log message for the specified epoch.
This method constructs a one-line summary of the statistics recorded for each first-level key (e.g., ‘train’, ‘eval’) in the given epoch.
- Parameters: epoch (int, optional) – The epoch to summarize. If None, the current epoch is used.
- Returns: A formatted log message summarizing the statistics of the epoch.
- Return type: str
SubReporter offers a per-batch counterpart, log_message(start=None, end=None), which summarizes the values registered for a range of batches within the current epoch:
- start (int, optional) – The starting index of the batch range. If None, defaults to 0. If negative, counts from the end.
- end (int, optional) – The ending index of the batch range. If None, defaults to the current count of batches.
- Raises: RuntimeError – If the SubReporter has already been finished.
#### Examples
>>> sub_reporter = SubReporter("train", 1, 10)
>>> sub_reporter.register({"loss": 0.5})
>>> sub_reporter.next()
>>> print(sub_reporter.log_message(0, 1))
1epoch:train:1-1batch: loss=0.500
>>> sub_reporter.register({"loss": 0.3})
>>> sub_reporter.next()
>>> print(sub_reporter.log_message(0, 2))
1epoch:train:1-2batch: loss=0.400
#### NOTE
The start and end indices must be within the bounds of the registered statistics.
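The example output above prints the averaged loss with three decimal places. A sketch of such a formatting rule follows. It assumes that magnitudes above 1e3, or at most 1e-3, switch to scientific notation. The thresholds and the name format_stat are assumptions, not documented behavior.

```python
def format_stat(name: str, v: float) -> str:
    """Hypothetical sketch of number formatting for log messages.

    Assumption: very large or very small magnitudes use scientific notation.
    """
    if abs(v) > 1.0e3:
        return f"{name}={v:.3e}"
    if abs(v) > 1.0e-3:
        return f"{name}={v:.3f}"
    return f"{name}={v:.3e}"


print(format_stat("loss", 0.4))      # loss=0.400
print(format_stat("grad", 54321.0))  # grad=5.432e+04
print(format_stat("lr", 2e-4))       # lr=2.000e-04
```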
matplotlib_plot(output_dir: Union[str, Path]) → None
Plot stats using Matplotlib and save images.
This method generates one plot per metric (key2) collected during training or evaluation, with one line per first-level key (e.g., ‘train’, ‘eval’). It saves each plot as a PNG image named after the metric in the specified output directory.
- Parameters: output_dir (Union[str, Path]) – The directory where the plots will be saved.
#### Examples
>>> from pathlib import Path
>>> reporter = Reporter()
>>> # Assume stats have been registered in reporter
>>> reporter.matplotlib_plot(Path('output/plots'))
#### NOTE
Ensure that Matplotlib is installed and available in the environment where this method is executed.
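Each line in such a plot needs one value per epoch, even when a key was not observed in every epoch. The sketch below assumes the x-axis runs over epochs 1 to the current epoch. It fills gaps with NaN, which Matplotlib draws as a break in the line. metric_series is a hypothetical helper, and no plotting is performed.

```python
import math
from typing import Dict, List

Stats = Dict[int, Dict[str, Dict[str, float]]]


def metric_series(stats: Stats, key: str, key2: str, last_epoch: int) -> List[float]:
    """Hypothetical sketch: one y-value per epoch 1..last_epoch for a plot line.

    Epochs where the metric was not recorded become NaN.
    """
    return [
        stats.get(e, {}).get(key, {}).get(key2, math.nan)
        for e in range(1, last_epoch + 1)
    ]


stats = {1: {"train": {"loss": 0.9}}, 3: {"train": {"loss": 0.5}}}
print(metric_series(stats, "train", "loss", 3))  # [0.9, nan, 0.5]
```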
observe(key: str, epoch: Optional[int] = None) → ContextManager[SubReporter]
Observe a specific key during a training or evaluation epoch.
This context manager allows for the registration of statistics for a specific key (e.g., ‘train’ or ‘eval’) during an epoch. It will yield a SubReporter instance that can be used to collect statistics, which will be finalized once the context manager exits.
- Parameters:
- key (str) – The key for which statistics are being collected.
- epoch (int, optional) – The epoch number. If None, the current epoch will be used.
- Yields: SubReporter – An instance of SubReporter to register statistics.
- Raises: ValueError – If the provided epoch is negative.
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     for batch in iterator:
...         stats = dict(loss=0.2)
...         sub_reporter.register(stats)
...         sub_reporter.next()  # close the current step before the next batch
#### NOTE
The statistics collected during this observation will be finalized and stored in the main Reporter instance once the context manager exits.
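The enter/exit behavior described above maps naturally onto contextlib.contextmanager: set up on entry, yield the collector, and aggregate after the with block ends. TinyReporter is a hypothetical stand-in that uses a plain list instead of a SubReporter and tracks a single ‘loss’ metric.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List


class TinyReporter:
    """Hypothetical stand-in showing the observe() pattern, not ESPnet code."""

    def __init__(self, epoch: int = 1) -> None:
        self.epoch = epoch
        self.stats: Dict[int, Dict[str, Dict[str, float]]] = {}

    @contextmanager
    def observe(self, key: str) -> Iterator[List[float]]:
        losses: List[float] = []  # stands in for the SubReporter
        yield losses              # the caller registers values here
        # Runs on normal exit from the with block (the finish_epoch step).
        mean = sum(losses) / len(losses) if losses else float("nan")
        self.stats.setdefault(self.epoch, {})[key] = {"loss": mean}


reporter = TinyReporter(epoch=1)
with reporter.observe("train") as sub:
    for loss in (0.5, 0.25):
        sub.append(loss)
print(reporter.stats)  # {1: {'train': {'loss': 0.375}}}
```

Because the aggregation code follows the yield without a try/finally, this sketch stores nothing if the with body raises.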
set_epoch(epoch: int) → None
Sets the current epoch of the reporter.
This method updates the epoch counter to the specified value. The epoch must be a non-negative integer. If a negative value is provided, a ValueError will be raised.
- Parameters: epoch (int) – The new epoch number to set. Must be 0 or more.
- Raises: ValueError – If the provided epoch is less than 0.
#### Examples
>>> reporter = Reporter()
>>> reporter.set_epoch(1)
>>> reporter.get_epoch()
1
>>> reporter.set_epoch(-1) # This will raise a ValueError
sort_epochs(key: str, key2: str, mode: str) → List[int]
Sort and return the epochs based on the specified metric.
This method retrieves the epochs associated with the specified key and key2, and sorts them in either ascending or descending order based on the mode specified. The mode can be ‘min’ for finding the epochs with the minimum value of the specified metric or ‘max’ for finding the epochs with the maximum value.
- Parameters:
- key (str) – The primary key representing the type of data (e.g., ‘train’, ‘eval’).
- key2 (str) – The secondary key representing the specific metric to evaluate (e.g., ‘loss’, ‘accuracy’).
- mode (str) – Specifies the sorting order. Should be either ‘min’ or ‘max’.
- Returns: A list of epochs sorted based on the specified metric.
- Return type: List[int]
- Raises:
- ValueError – If the mode is not ‘min’ or ‘max’.
- KeyError – If the specified key or key2 does not exist in the recorded statistics.
#### Examples
>>> reporter = Reporter()
>>> # Assuming reporter has recorded stats for 'train' and 'loss'
>>> sorted_epochs = reporter.sort_epochs('train', 'loss', 'min')
>>> print(sorted_epochs) # Output: [1, 2, 3] (example output)
sort_epochs_and_values(key: str, key2: str, mode: str) → List[Tuple[int, float]]
Return the recorded epochs ordered from the best value to the worst.
This method sorts the epochs based on the specified metric and mode, returning the epochs along with their corresponding values.
- Parameters:
- key (str) – The primary key to access the statistics (e.g., ‘train’ or ‘eval’).
- key2 (str) – The secondary key to specify the metric (e.g., ‘loss’ or ‘accuracy’).
- mode (str) – The mode for sorting; must be either ‘min’ or ‘max’. With ‘min’, the epoch with the smallest value comes first (ascending order); with ‘max’, the epoch with the largest value comes first (descending order).
- Returns: A list of tuples, each containing an epoch number and its corresponding value, sorted according to the specified mode.
- Return type: List[Tuple[int, float]]
- Raises:
- ValueError – If the mode is not ‘min’ or ‘max’.
- KeyError – If the specified key or key2 is not found in the statistics.
#### Examples
>>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min')
>>> e_1best, v_1best = val[0]
>>> e_2best, v_2best = val[1]
#### NOTE
Ensure that the statistics for the specified keys have been registered before calling this method.
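The ordering can be sketched with Python's stable sorted(), using ascending order for ‘min’ and descending order for ‘max’. rank_epochs is a hypothetical helper that works on a plain epoch-to-value dictionary. The epochs-only and values-only lists at the end correspond to what sort_epochs and sort_values are documented to return. Indexing the epochs list mirrors get_best_epoch's nbest argument.

```python
from typing import Dict, List, Tuple


def rank_epochs(history: Dict[int, float], mode: str) -> List[Tuple[int, float]]:
    """Hypothetical sketch: (epoch, value) pairs ordered from best to worst."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be min or max: {mode}")
    # 'min' -> ascending values, 'max' -> descending; ties keep epoch order.
    return sorted(history.items(), key=lambda ev: ev[1], reverse=(mode == "max"))


acc = {1: 0.70, 2: 0.85, 3: 0.80}
pairs = rank_epochs(acc, "max")
epochs = [e for e, _ in pairs]  # epochs only, like sort_epochs
values = [v for _, v in pairs]  # values only, like sort_values
print(pairs)      # [(2, 0.85), (3, 0.8), (1, 0.7)]
print(epochs[1])  # 3, the second-best epoch (nbest=1)
```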
sort_values(key: str, key2: str, mode: str) → List[float]
Return a list of values sorted by the specified key and mode.
This method retrieves values for a specified key (e.g., ‘train’) and a second key (e.g., ‘loss’) from the stats, sorts them according to the specified mode (‘min’ or ‘max’), and returns the sorted values as a list.
- Parameters:
- key (str) – The primary key for which to retrieve values.
- key2 (str) – The secondary key for which to retrieve values.
- mode (str) – The sorting mode, either ‘min’ or ‘max’.
- Returns: A list of values sorted according to the specified mode.
- Return type: List[float]
- Raises:
- ValueError – If mode is not ‘min’ or ‘max’.
- KeyError – If the specified key or key2 is not found in the stats.
#### Examples
>>> # Assuming reporter has recorded 'train' and 'eval' statistics for three epochs
>>> reporter.sort_values('train', 'loss', 'min')
[0.1, 0.2, 0.3]
>>> reporter.sort_values('eval', 'accuracy', 'max')
[0.95, 0.93, 0.9]
start_epoch(key: str, epoch: Optional[int] = None) → SubReporter
Starts reporting for a new epoch.
This method initializes a new SubReporter that tracks statistics for the given key (e.g., ‘train’ or ‘eval’) during the current epoch. If epoch is specified, it is first set as the current epoch of the Reporter.
- Parameters:
- key (str) – A string identifier for the reporting context, such as ‘train’ or ‘eval’.
- epoch (int, optional) – The epoch number to start. If not provided, the current epoch will be used.
- Returns: An instance of SubReporter for the current epoch reporting.
- Return type: SubReporter
- Raises: ValueError – If epoch is less than 0.
If the statistics of the previous epoch for this key are missing, and the previous epoch is not 0, a warning is issued rather than an error. In either case, the cumulative iteration count (‘total_count’) restarts from 0.
#### Examples
>>> reporter = Reporter()
>>> sub_reporter = reporter.start_epoch('train', 1)
>>> sub_reporter.key
'train'
>>> sub_reporter.epoch
1
#### NOTE
It is essential to call finish_epoch after using the SubReporter to finalize and store the statistics for the epoch.
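The cumulative iteration count lets per-iteration logs keep increasing across epochs. The sketch below shows how a new epoch could pick up the previous epoch's ‘total_count’. It assumes that missing statistics for the previous epoch lead to a restart from 0, with a warning unless the previous epoch is 0, the expected starting point. initial_total_count is a hypothetical helper.

```python
import warnings
from typing import Any, Dict

Stats = Dict[int, Dict[str, Dict[str, Any]]]


def initial_total_count(stats: Stats, key: str, epoch: int) -> int:
    """Hypothetical sketch: carry the cumulative iteration count into a new epoch."""
    prev = stats.get(epoch - 1, {}).get(key, {})
    if "total_count" not in prev:
        if epoch - 1 != 0:  # nothing recorded before the first epoch is expected
            warnings.warn(f"stats of the previous epoch={epoch - 1} do not exist")
        return 0
    return int(prev["total_count"])


stats = {1: {"train": {"loss": 0.3, "total_count": 100}}}
print(initial_total_count(stats, "train", 2))  # 100
print(initial_total_count(stats, "train", 1))  # 0, without a warning
```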
state_dict() → dict
Returns the current state of the Reporter instance.
The state is represented as a dictionary containing the following keys:
- ‘stats’: A dictionary of recorded statistics for each epoch and key.
- ‘epoch’: The current epoch number.
This method is useful for saving the current state of the Reporter, allowing for resuming from a checkpoint later.
#### Examples
>>> reporter = Reporter()
>>> # Simulate some training and logging
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.5})
>>> state = reporter.state_dict()
>>> print(state['epoch'])  # Output: 0
>>> print(state['stats'])  # Output (abbreviated): {0: {'train': {'loss': 0.5, ...}}}
- Returns: A dictionary containing the state of the Reporter.
- Return type: dict
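A checkpoint round-trip of this state can be sketched with pickle, the serialization format that torch.save builds on. MiniReporter is a hypothetical stand-in that reproduces only the ‘stats’/‘epoch’ layout described above.

```python
import pickle
from typing import Any, Dict


class MiniReporter:
    """Hypothetical stand-in with the same state layout, not ESPnet code."""

    def __init__(self, epoch: int = 0) -> None:
        self.epoch = epoch
        self.stats: Dict[int, Dict[str, Dict[str, Any]]] = {}

    def state_dict(self) -> Dict[str, Any]:
        return {"stats": self.stats, "epoch": self.epoch}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epoch = state_dict["epoch"]
        self.stats = state_dict["stats"]


src = MiniReporter(epoch=3)
src.stats[3] = {"train": {"loss": 0.25, "total_count": 300}}
blob = pickle.dumps(src.state_dict())  # e.g. stored inside a checkpoint file

dst = MiniReporter()
dst.load_state_dict(pickle.loads(blob))
print(dst.epoch, dst.stats)  # 3 {3: {'train': {'loss': 0.25, 'total_count': 300}}}
```

Pickle keeps the integer epoch keys intact. A JSON round-trip would turn them into strings (3 becomes "3"), so a JSON-based checkpoint would need an extra conversion step.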
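tensorboard_add_scalar(summary_writer, epoch: Optional[int] = None, key1: Optional[str] = None) → None
Adds scalar values to TensorBoard for visualization.
This method logs the epoch-level statistics to the provided TensorBoard summary writer. The statistics are taken from the current epoch or from a specified epoch.
- Parameters:
- summary_writer – A TensorBoard summary writer instance used to log scalar values.
- epoch (int, optional) – The epoch whose statistics are logged. If None, the current epoch is used.
- key1 (str, optional) – If given, only the statistics under this first-level key (e.g., ‘train’) are logged; otherwise, all first-level keys are logged.
SubReporter offers a per-iteration counterpart, tensorboard_add_scalar(summary_writer, start=None), which aggregates the values registered in the current epoch:
- start (int, optional) – The starting index from which to log the statistics. If None, it defaults to 0. If negative, it starts from self.count + start.
- Raises: AssertionError – If the length of the statistics list does not match the expected count.
#### Examples
>>> from torch.utils.tensorboard import SummaryWriter
>>> writer = SummaryWriter()
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     for batch in iterator:
...         stats = dict(loss=0.2)
...         sub_reporter.register(stats)
...         sub_reporter.next()
...         sub_reporter.tensorboard_add_scalar(writer)
>>> reporter.tensorboard_add_scalar(writer)
>>> writer.close()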
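The start argument described above selects which registered values are aggregated before logging. The following is a minimal sketch that assumes a plain mean over the selected values. window_mean is a hypothetical helper, and an empty selection raises ZeroDivisionError here.

```python
from typing import List, Optional


def window_mean(values: List[float], start: Optional[int] = None) -> float:
    """Hypothetical sketch: average the values selected by a `start` index.

    None selects every value; a negative start counts from the end (count + start).
    """
    count = len(values)
    if start is None:
        start = 0
    if start < 0:
        start = count + start
    window = values[start:]
    return sum(window) / len(window)  # ZeroDivisionError if nothing is selected


losses = [1.0, 0.5, 0.5, 0.25]
print(window_mean(losses))      # 0.5625
print(window_mean(losses, -2))  # 0.375
```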
wandb_log(epoch: Optional[int] = None) → None
Logs the reported values to Weights & Biases (wandb).
This method logs the epoch-level statistics recorded for the given epoch to the Weights & Biases dashboard. Metric names are prefixed with their category (e.g., ‘train/’, ‘valid/’).
- Parameters: epoch (int, optional) – The epoch whose statistics are logged. If None, the current epoch is used.
SubReporter offers a per-iteration counterpart, wandb_log(start=None), which aggregates the values registered in the current epoch and logs them together with the total count of iterations:
- start (int, optional) – The starting index for the statistics to log. If None, logging starts from the beginning. If negative, it starts from the end of the statistics.
- Raises: AssertionError – If the lengths of the statistics lists do not match the current count of reported values.
#### Examples
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     sub_reporter.register({'loss': 0.5})
...     sub_reporter.wandb_log()
>>> reporter.wandb_log()
#### NOTE
Ensure that wandb is properly initialized (e.g., with wandb.init()) before calling these methods.