espnet2.gan_codec.shared.quantizer.modules.distrib.average_metrics
espnet2.gan_codec.shared.quantizer.modules.distrib.average_metrics(metrics: Dict[str, float], count=1.0)
Average a dictionary of metrics across all workers, using the optional count as an unnormalized weight.
This function is intended for distributed training, where each worker computes its own metrics and the results must be combined. Each worker's values are weighted by the count it passes, so a worker that passes a larger count contributes proportionally more to the average.
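A count-weighted average across workers can be simulated in a single process. The sketch below is illustrative only and rests on an assumption: that each worker scales its values by its count, the scaled values and counts are summed across workers (as an all-reduce SUM would do), and every worker receives sum(count * value) / sum(count). The function name simulated_average_metrics is hypothetical and not part of ESPnet, and no torch.distributed calls are made.

```python
from typing import Dict, List, Sequence


def simulated_average_metrics(
    per_worker_metrics: Sequence[Dict[str, float]],
    counts: Sequence[float],
) -> List[Dict[str, float]]:
    """Hypothetical single-process stand-in for a count-weighted all-reduce.

    Assumption: each worker's values are multiplied by its count, summed over
    workers, and divided by the summed counts. Every worker gets the same result.
    """
    keys = list(per_worker_metrics[0].keys())
    # Running sum of count * value for each metric key.
    weighted_sums = {k: 0.0 for k in keys}
    total_count = 0.0
    for metrics, count in zip(per_worker_metrics, counts):
        for k in keys:
            weighted_sums[k] += count * metrics[k]
        total_count += count
    averaged = {k: weighted_sums[k] / total_count for k in keys}
    # After an all-reduce, every worker holds an identical copy of the result.
    return [dict(averaged) for _ in per_worker_metrics]


# Two simulated workers: worker 0 passes count=3.0, worker 1 passes count=1.0.
results = simulated_average_metrics(
    [{"loss": 0.25, "acc": 0.5}, {"loss": 0.75, "acc": 1.0}],
    counts=[3.0, 1.0],
)
print(results[0])  # {'loss': 0.375, 'acc': 0.625}
```

With equal counts on every worker, the result is a plain arithmetic mean. With a single worker, the metrics come back unchanged whatever its count is.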
- Parameters:
- metrics (Dict[str, float]) – A dictionary mapping metric names to their values on the current worker.
- count (float, optional) – An unnormalized weight applied to this worker's metrics, such as the number of samples they were computed on. Defaults to 1.0.
- Returns: A dictionary with the same keys as metrics, where each value is averaged across all workers and weighted by count.
- Return type: Dict[str, float]
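Reading count as an unnormalized weight, the returned value for each metric key k can be written as shown below. The notation is introduced here for illustration: W is the number of workers, c_w is the count passed on worker w, and m_{k,w} is that worker's local value for key k.

```latex
\bar{m}_k = \frac{\sum_{w=1}^{W} c_w \, m_{k,w}}{\sum_{w=1}^{W} c_w}
```

When every worker uses the default count of 1.0, this reduces to the arithmetic mean of the W local values.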
Examples
>>> metrics = {'accuracy': 0.8, 'loss': 0.2}
>>> averaged_metrics = average_metrics(metrics, count=2.0)
>>> print(averaged_metrics)
{'accuracy': 0.8, 'loss': 0.2}
NOTE
Averaging only takes place when the PyTorch distributed environment has been initialized. Otherwise the original metrics dictionary is returned unchanged, which is why the example above prints its input as-is.