espnet2.train.reporter.SubReporter
class espnet2.train.reporter.SubReporter(key: str, epoch: int, total_count: int)
Bases: object
This class is used in Reporter.
It is responsible for collecting and managing statistics during training iterations for a specific key (e.g., ‘train’ or ‘valid’).
key
The identifier for the current reporting key.
- Type: str
epoch
The current epoch number.
- Type: int
start_time
The time when the reporting for this key started.
- Type: float
stats
A dictionary to hold statistics collected during the epoch.
- Type: defaultdict
_finished
A flag indicating if the reporting for this key is finished.
- Type: bool
total_count
The total number of iterations over all epochs.
- Type: int
count
The number of iterations for the current epoch.
- Type: int
_seen_keys_in_the_step
A set to track keys seen in the current step.
- Type: set
Parameters:
- key (str) – The reporting key.
- epoch (int) – The current epoch number.
- total_count (int) – The number of iterations already completed in earlier epochs. The running count continues from this value.
get_total_count()
Returns the number of iterations over all epochs.
get_epoch()
Returns the current epoch number.
next()
Closes up the current step and resets the state for the next step.
register(stats, weight=None)
Registers statistics for the current step.
log_message(start=None, end=None)
Returns a message summarizing the statistics.
tensorboard_add_scalar(summary_writer, start=None)
Adds scalar values to TensorBoard.
wandb_log(start=None)
Logs statistics to Weights & Biases.
finished()
Marks the reporting for this key as finished.
measure_time(name)
A context manager to measure and register time.
measure_iter_time(iterable, name)
Measures the time taken to fetch each item from the given iterable.
######################### Examples
>>> sub_reporter = SubReporter(key='train', epoch=1, total_count=0)
>>> sub_reporter.register({'loss': 0.2})
>>> print(sub_reporter.log_message())
1epoch:train:1-1batch: loss=0.200
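The counters described above (count for the current epoch, total_count across epochs, and the per-step set of seen keys) can be illustrated with a minimal, self-contained sketch. MiniSubReporter is a hypothetical stand-in written for illustration, not ESPnet's class: it models only how the first register() of a step advances both counters, how a key may be registered once per step, and how next() re-arms that check. NaN padding, weights and aggregation are deliberately omitted.

```python
from typing import Dict, List, Set


class MiniSubReporter:
    """Hypothetical, simplified model of SubReporter's step counters."""

    def __init__(self, key: str, epoch: int, total_count: int) -> None:
        self.key = key
        self.epoch = epoch
        self.total_count = total_count  # iterations over all epochs
        self.count = 0  # iterations in this epoch
        self.stats: Dict[str, List[float]] = {}
        self._seen_keys_in_the_step: Set[str] = set()

    def register(self, stats: Dict[str, float]) -> None:
        if not self._seen_keys_in_the_step:
            # The first register() of a step opens a new step.
            self.total_count += 1
            self.count += 1
        for name, value in stats.items():
            if name in self._seen_keys_in_the_step:
                raise RuntimeError(f"{name} is registered twice.")
            self.stats.setdefault(name, []).append(value)
            self._seen_keys_in_the_step.add(name)

    def next(self) -> None:
        # Close the step so the same names may be registered again.
        self._seen_keys_in_the_step = set()


rep = MiniSubReporter("train", epoch=1, total_count=10)
rep.register({"loss": 0.5})
rep.register({"acc": 0.9})  # same step: the counters do not move
rep.next()
rep.register({"loss": 0.4})
print(rep.count, rep.total_count)  # 2 12
```

In this model, registering "loss" twice without calling next() in between raises RuntimeError, which is the situation the register() documentation below lists under Raises.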
finished()
Marks this sub-reporter as finished. It is normally called by Reporter rather than by user code; see the docstring of Reporter for the usage.
After this call, register() (and therefore measure_time() and measure_iter_time(), which register through it) and log_message() raise RuntimeError.
- Return type: None
######################### Examples
>>> sub_reporter = SubReporter("train", 1, 0)
>>> sub_reporter.register({"loss": 0.5}, weight=1.0)
>>> print(sub_reporter.log_message())
1epoch:train:1-1batch: loss=0.500
>>> sub_reporter.finished()
>>> sub_reporter.register({"loss": 0.4})
Traceback (most recent call last):
...
RuntimeError: Already finished
get_epoch()
Get the current epoch number.
- Returns: The current epoch number.
- Return type: int
######################### Examples
>>> sub_reporter = SubReporter("train", 5, 0)
>>> sub_reporter.get_epoch()
5
get_total_count()
Returns the number of iterations over all epochs.
The count starts from the total_count passed to the constructor (the iterations carried over from earlier epochs) and is incremented by one at the first register() call of each step. It is useful for tracking the number of completed iterations during training or evaluation, for example as the global step when logging.
- Returns: The total count of iterations.
- Return type: int
######################### Examples
>>> sub_reporter = SubReporter(key='train', epoch=1, total_count=5)
>>> sub_reporter.get_total_count()
5
log_message(start: int | None = None, end: int | None = None) → str
Generates a log message summarizing the statistics for the current epoch.
Each statistic is aggregated over the selected range of batches and formatted into a single string. The start and end arguments select that range.
- Parameters:
- start (int , optional) – The index of the first batch to include. If None, defaults to 0. If negative, it is counted back from the current count (for example, -10 selects the last 10 batches).
- end (int , optional) – The index one past the last batch to include. If None, defaults to the current count of batches.
- Returns: A formatted string containing the epoch, key, batch range, and aggregated statistics, or an empty string if no batch has been registered or the selected range is empty.
- Return type: str
- Raises:RuntimeError – If the sub-reporter has already finished.
######################### Examples
>>> sub_reporter = SubReporter("train", 1, 10)
>>> sub_reporter.register({"loss": 0.5})
>>> message = sub_reporter.log_message(0, 1)
>>> print(message)
1epoch:train:1-1batch: loss=0.500
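How start and end become a batch range, and how the header is laid out, can be sketched as follows. The helpers resolve_range, format_stat and make_message are hypothetical and assume the per-statistic averages over the range are already computed. The switch points between fixed-point and scientific notation are an assumption chosen to reproduce the "loss=0.500" style shown above, not something this page specifies.

```python
def resolve_range(start, end, count):
    """Turn optional/negative start and optional end into slice bounds."""
    if start is None:
        start = 0
    if start < 0:
        start = count + start  # negative start counts back from count
    if end is None:
        end = count
    return start, end


def format_stat(name: str, value: float) -> str:
    """Fixed-point for moderate magnitudes, scientific notation otherwise (assumed)."""
    if 1.0e-3 < abs(value) <= 1.0e3:
        return f"{name}={value:.3f}"
    return f"{name}={value:.3e}"


def make_message(epoch: int, key: str, start: int, end: int, means: dict) -> str:
    """Build '<epoch>epoch:<key>:<first>-<last>batch: ...' with 1-based batch numbers."""
    if start == end:
        return ""  # empty range
    body = ", ".join(format_stat(name, v) for name, v in means.items())
    return f"{epoch}epoch:{key}:{start + 1}-{end}batch: {body}"


start, end = resolve_range(-2, None, count=10)
print(start, end)  # 8 10
print(make_message(1, "train", start, end, {"loss": 0.25, "lr": 2e-4}))
# 1epoch:train:9-10batch: loss=0.250, lr=2.000e-04
```

The header uses start + 1 because slice indices are 0-based while the printed batch numbers are 1-based and inclusive, which is why log_message(0, 1) prints "1-1batch".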
measure_iter_time(iterable, name: str)
Measures how long it takes to fetch each item from the given iterable.
This generator yields each item from the iterable and registers, under name, the time spent obtaining that item (for example, the data-loading time of each batch). Time spent in the body of the caller's loop is not included.
- Parameters:
- iterable (iterable) – An iterable object to iterate over.
- name (str) – The name under which to register the time taken for each iteration.
- Yields: The next item from the iterable.
######################### Examples
>>> sub_reporter = SubReporter("example", 1, 0)
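>>> for item in sub_reporter.measure_iter_time(range(3), "iteration_time"):
...     print(item)
...     sub_reporter.next()  # close the step so "iteration_time" can be registered again
0
1
2
######### NOTE If the iterable is empty, nothing is yielded and nothing is registered. Because each fetch time is registered, call next() once per item; otherwise fetching the second item raises RuntimeError because the name is registered twice in one step.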
- Raises:RuntimeError – If the sub-reporter has already finished.
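The distinction between fetch time and loop-body time can be seen in a small generator sketch. timed_iter and slow_source are hypothetical helpers, and the record callback stands in for registering the value under name. This is an illustration of the idea, not ESPnet's implementation.

```python
import time
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def timed_iter(iterable: Iterable[T], record: Callable[[float], None]) -> Iterator[T]:
    """Yield items, recording only the time spent fetching each one."""
    iterator = iter(iterable)
    while True:
        start = time.perf_counter()
        try:
            item = next(iterator)
        except StopIteration:
            return  # exhausted: the final, empty fetch is not recorded
        record(time.perf_counter() - start)
        yield item


def slow_source() -> Iterator[int]:
    """Hypothetical data source that takes about 10 ms per item."""
    for i in range(3):
        time.sleep(0.01)
        yield i


fetch_times: List[float] = []
for item in timed_iter(slow_source(), fetch_times.append):
    time.sleep(0.02)  # loop-body work: not part of fetch_times
print(len(fetch_times))  # 3
```

The timer starts immediately before next() and stops immediately after it, so each recorded value covers only the producer's work for that item.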
measure_time(name: str)
Measures the time taken for a block of code to execute.
This context manager yields the start time and registers the elapsed time once the block is exited. It is particularly useful for tracking the duration of specific operations within the reporting framework.
- Parameters:name (str) – A descriptive name for the timing measurement.
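- Yields:float – The value of time.perf_counter() when the block is entered. This is a reference point for measuring intervals, not a wall-clock timestamp.
######################### Examples
>>> import time
>>> sub_reporter = SubReporter("train", 1, 0)
>>> with sub_reporter.measure_time("data_loading"):
...     time.sleep(0.1)  # stand-in for the code being timed
>>> "data_loading" in sub_reporter.stats
True
######### NOTE The elapsed time is registered through register() when the block exits, so the same name can be measured only once per step; call next() before measuring it again.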
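The pattern behind a timing context manager can be sketched with contextlib.contextmanager. This measure_time is a hypothetical standalone function that writes into a plain dictionary instead of calling register(). In this sketch, a block that raises is not recorded, because the code after yield never runs. Whether ESPnet's method behaves the same on exceptions is not stated on this page.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator, List


@contextmanager
def measure_time(records: Dict[str, List[float]], name: str) -> Iterator[float]:
    """Yield the perf_counter() start value; record elapsed seconds on normal exit."""
    start = time.perf_counter()
    yield start
    # Only reached when the with-block finishes without raising.
    records.setdefault(name, []).append(time.perf_counter() - start)


records: Dict[str, List[float]] = {}
with measure_time(records, "data_loading") as t0:
    time.sleep(0.02)  # stand-in for the code being timed
print(len(records["data_loading"]))  # 1

try:
    with measure_time(records, "failing_step"):
        raise ValueError("boom")
except ValueError:
    pass
print("failing_step" in records)  # False
```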
next()
Closes up the current step and resets the state for the next step.
For every statistic that was registered in an earlier step but not in the current one, a NaN placeholder is appended so that every list in stats keeps exactly count entries. The set of keys seen in the step is then cleared, so the next register() call starts a new step and the same keys may be registered again.
- Return type: None
######################### Examples
>>> sub_reporter = SubReporter("train", 1, 0)
>>> sub_reporter.register({"loss": 0.5, "acc": 0.9})
>>> sub_reporter.next()
>>> sub_reporter.register({"loss": 0.3})  # "acc" is not reported in step 2
>>> sub_reporter.next()  # appends NaN to "acc" for step 2
>>> sub_reporter.count, len(sub_reporter.stats["acc"])
(2, 2)
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
...     for batch in iterator:
...         stats = dict(loss=0.2)
...         sub_reporter.register(stats)
...         sub_reporter.next()
######### NOTE Call next() once per iteration, after all statistics of that iteration have been registered.
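The padding step described above can be sketched on plain lists of floats. close_step is a hypothetical helper: ESPnet stores wrapped reported values rather than floats, so this only models the bookkeeping, namely pad what was skipped, check that the lengths agree, and clear the seen set.

```python
import math
from typing import Dict, List, Set


def close_step(stats: Dict[str, List[float]], seen: Set[str], count: int) -> None:
    """Pad every statistic not reported in this step with NaN, then clear `seen`."""
    for name, values in stats.items():
        if name not in seen:
            values.append(math.nan)
        # Every list must now hold exactly one entry per step so far.
        assert len(values) == count, (name, len(values), count)
    seen.clear()


stats = {"loss": [0.5, 0.3], "acc": [0.9]}  # "acc" was skipped in step 2
seen = {"loss"}
close_step(stats, seen, count=2)
print(stats["acc"])  # [0.9, nan]
```

Keeping every list aligned to count is what lets later slicing by batch range (stats[name][start:end]) refer to the same steps for every statistic.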
register(stats: Dict[str, float | int | complex | Tensor | ndarray | Dict[str, float | int | complex | Tensor | ndarray] | None], weight: float | int | complex | Tensor | ndarray | None = None) → None
Register statistics for the current step.
The first register() call after next() (or after construction) starts a new step and increments both count and total_count. If a key is registered for the first time in this sub-reporter (that is, it was not reported in any earlier step), its list is back-filled with NaN for the earlier steps so that all statistics lists have the same length. A value of None is stored as NaN.
- Parameters:
- stats – A dictionary where the keys are statistic names and the values are either a numeric value or another dictionary containing additional metrics.
- weight – An optional weight applied to every value in stats. If provided, the values are aggregated as a weighted average; otherwise a plain average is used. Use weights consistently for a given key, because weighted and unweighted values of the same key cannot be aggregated together.
- Raises:
- RuntimeError – If registration is attempted after finished() has been called.
- RuntimeError – If a reserved key ("time" or "total_count") is used.
- RuntimeError – If a key is registered more than once in the same step.
######################### Examples
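>>> sub_reporter = SubReporter('train', 1, 0)
>>> sub_reporter.register({'loss': 0.2, 'accuracy': 0.95})
>>> sub_reporter.next()  # close step 1 before registering 'loss' again
>>> sub_reporter.register({'loss': 0.15, 'accuracy': 0.97})
>>> valid_reporter = SubReporter('valid', 1, 0)
>>> valid_reporter.register({'loss': 0.3}, weight=32)  # e.g. weight = batch size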
######### NOTE The stats dictionary can contain nested dictionaries, allowing for more complex logging structures.
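The NaN back-fill for a key that first appears in a later step can be sketched as follows. register_value is a hypothetical helper operating on plain float lists with a 1-based step number count; it is not ESPnet's code and ignores weights and the per-step duplicate check.

```python
import math
from typing import Dict, List, Optional


def register_value(stats: Dict[str, List[float]], name: str,
                   value: Optional[float], count: int) -> None:
    """Store `value` for step `count` (1-based), back-filling NaN for a new name."""
    if value is None:
        value = math.nan  # None is recorded as "no value"
    if name not in stats:
        # First appearance: NaN for the count - 1 earlier steps, then the value.
        stats[name] = [math.nan] * (count - 1) + [value]
    else:
        stats[name].append(value)


stats: Dict[str, List[float]] = {"loss": [0.4, 0.3]}
register_value(stats, "loss", 0.5, count=3)
register_value(stats, "grad_norm", 1.7, count=3)  # first seen at step 3
print(stats["grad_norm"])  # [nan, nan, 1.7]
```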
tensorboard_add_scalar(summary_writer, start: int | None = None)
Logs scalar values to TensorBoard.
This method aggregates each statistic over the steps from start to the current step and logs the result to TensorBoard through the provided summary writer. Each value is logged under the statistic name, with total_count as the global step.
- Parameters:
- summary_writer – The TensorBoard summary writer used to log the scalar values.
- start – Optional; the starting index for logging. If not provided, it defaults to 0. If negative, it counts backwards from the current count.
######################### Examples
>>> from tensorboardX import SummaryWriter
>>> writer = SummaryWriter()
>>> sub_reporter = SubReporter("train", 1, 0)
>>> sub_reporter.register({"loss": 0.5})
>>> sub_reporter.tensorboard_add_scalar(writer)
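- Raises:AssertionError – If the length of any statistics list does not match the current count (for example, when it is called before next() in a step where some statistics were not registered).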
######### NOTE This method assumes that the statistics have been registered prior to logging. Ensure that the register method is called to populate the statistics before invoking this method.
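The flow described above (resolve start, check list lengths, reduce each list, then call add_scalar with total_count as the step) can be sketched with a recording stand-in for the writer. RecordingWriter, nan_mean and add_scalars are hypothetical. The sketch assumes unweighted values are reduced with a mean that skips NaN padding, which is an assumption and not stated on this page.

```python
import math
from typing import Dict, List, Optional, Tuple


class RecordingWriter:
    """Hypothetical stand-in with the add_scalar(tag, value, global_step) shape."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, int]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        self.calls.append((tag, scalar_value, global_step))


def nan_mean(values: List[float]) -> float:
    """Mean that ignores NaN padding; NaN if nothing is left."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else math.nan


def add_scalars(writer, stats: Dict[str, List[float]], count: int,
                total_count: int, start: Optional[int] = None) -> None:
    if start is None:
        start = 0
    if start < 0:
        start = count + start  # e.g. -1 selects only the latest step
    for name, values in stats.items():
        assert len(values) == count, (name, len(values), count)
        writer.add_scalar(name, nan_mean(values[start:]), total_count)


writer = RecordingWriter()
stats = {"loss": [1.0, 0.5, 0.0], "acc": [math.nan, 0.5, 1.0]}
add_scalars(writer, stats, count=3, total_count=103)
print(writer.calls)  # [('loss', 0.5, 103), ('acc', 0.75, 103)]
```

Using total_count rather than count as the global step keeps the x-axis monotonic across epochs, since count restarts at 0 for every new sub-reporter.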
wandb_log(start: int | None = None)
Logs the current statistics to Weights & Biases (wandb).
This method aggregates the statistics collected over the selected range of steps and logs them to the Weights & Biases dashboard. Each statistic is logged under a name with a namespace prefix, and the total iteration count (total_count) is logged alongside them.
- Parameters:start (int , optional) – The starting index for the statistics to log. If None, logging starts from index 0. If negative, it counts backwards from the current count.
######################### Examples
>>> sub_reporter = SubReporter('train', epoch=1, total_count=10)
>>> # Register some statistics
>>> sub_reporter.register({'loss': 0.5, 'accuracy': 0.8})
>>> # Log statistics to wandb (wandb.log() requires an active run)
>>> import wandb
>>> run = wandb.init(project="example")
>>> sub_reporter.wandb_log()
- Raises:AssertionError – If the length of any statistics list does not match the current count.
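Both tensorboard_add_scalar() and wandb_log() reduce each statistic to a single number. For statistics registered with a weight (see register()), a natural reduction is the weighted mean sum(w * v) / sum(w). The sketch below assumes that formula and assumes that non-finite entries, such as the NaN padding added for missing steps, are skipped. weighted_mean is a hypothetical helper, not part of ESPnet.

```python
import math
from typing import List, Tuple


def weighted_mean(entries: List[Tuple[float, float]]) -> float:
    """Reduce (value, weight) pairs to sum(w * v) / sum(w), skipping non-finite pairs."""
    valid = [(v, w) for v, w in entries if math.isfinite(v) and math.isfinite(w)]
    total_weight = sum(w for _, w in valid)
    if not valid or total_weight == 0:
        return math.nan
    return sum(v * w for v, w in valid) / total_weight


# e.g. per-batch losses weighted by batch size; (nan, 0) is a padding entry
entries = [(0.5, 2), (1.0, 6), (math.nan, 0)]
print(weighted_mean(entries))  # 0.875
```

Weighting by batch size makes the reported loss equal to the per-sample average over the range, rather than an average of batch averages that would overweight small batches.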