espnet2.gan_codec.shared.loss.loss_balancer.EMA
class espnet2.gan_codec.shared.loss.loss_balancer.EMA(ema_decay)
Bases: object
Exponential Moving Average (EMA) for tracking and smoothing statistics.
This class maintains a running exponential moving average of input statistics over time. It helps stabilize training by smoothing out step-to-step noise in statistics collected during training, such as per-batch loss values.
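A minimal, hypothetical sketch of the behavior described above follows. It is written in plain Python so it works with floats as well as tensors. `EMASketch` is an illustrative name, not part of ESPnet. The update rule `avg = ema_decay * avg + (1 - ema_decay) * value`, seeding each average with its first observation, and the `ValueError` checks are all assumptions. The real class may initialize differently (for example from zero, or with bias correction), and it runs its update under PyTorch's no_grad context.

```python
from typing import Any, Dict


class EMASketch:
    """Hypothetical illustration of an EMA over named statistics (not ESPnet's code)."""

    def __init__(self, ema_decay: float) -> None:
        self.ema_decay = ema_decay
        # Maps statistic name -> current moving average.
        self.cache: Dict[str, Any] = {}

    def __call__(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        # Assumption: anything other than a dict is an invalid format.
        if not isinstance(stats, dict):
            raise ValueError("stats must be a dict mapping names to values")
        for key, value in stats.items():
            if key not in self.cache:
                # Assumption: seed the average with the first observation.
                self.cache[key] = value
                continue
            prev = self.cache[key]
            # Assumption: a new value must have the same shape as its average.
            # Floats have no .shape, so both sides fall back to ().
            if getattr(prev, "shape", ()) != getattr(value, "shape", ()):
                raise ValueError(f"shape mismatch for statistic {key!r}")
            # Standard EMA update: keep ema_decay of the old average.
            self.cache[key] = self.ema_decay * prev + (1.0 - self.ema_decay) * value
        # Return the updated averages for the statistics passed in this call.
        return {key: self.cache[key] for key in stats}
```

For example, with `ema_decay=0.9`, calling the sketch with `{"loss": 0.5}` and then `{"loss": 1.5}` returns 0.5 and then about 0.6 (0.9 * 0.5 + 0.1 * 1.5).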
ema_decay
The decay factor for the moving average, which determines the weight of the previous average relative to the new value.
Type: float
cache
A dictionary that stores the current moving averages for each statistic.
Type: dict
Parameters: ema_decay (float) – The decay rate for the moving average. Should be between 0 and 1; a value closer to 1 gives more weight to past averages.
Returns: When the instance is called on a dictionary of statistics, a dictionary containing the updated moving averages for those statistics.
Return type: dict
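To make the role of ema_decay concrete, here is the standard EMA recurrence. The source does not spell out the exact update, so this is an assumption. Write β for ema_decay, s_t for the value of a statistic at step t, and \bar{s}_t for its moving average:

```latex
\bar{s}_t = \beta\,\bar{s}_{t-1} + (1-\beta)\,s_t
\quad\Longrightarrow\quad
\bar{s}_t = \beta^{t}\,\bar{s}_0 + (1-\beta)\sum_{k=1}^{t}\beta^{t-k}\,s_k
```

A value observed j steps ago carries weight (1-β)β^j. The average therefore effectively spans roughly 1/(1-β) recent updates: about 10 for β = 0.9 and about 100 for β = 0.99.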
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.loss.loss_balancer import EMA
>>> ema = EMA(ema_decay=0.9)
>>> stats = {'loss': torch.tensor(0.5), 'accuracy': torch.tensor(0.8)}
>>> updated_stats = ema(stats)
>>> print(updated_stats)
{'loss': tensor(...), 'accuracy': tensor(...)}
NOTE
Calling the instance runs the update inside PyTorch’s no_grad context, so the moving-average computation is not recorded by autograd.
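The following small sketch illustrates what the no_grad note means in practice. The variable names are illustrative, and the inline update stands in for the class's actual code. It shows that an average computed under `torch.no_grad()` from a tensor that requires gradients does not itself require gradients, while the input tensor is unaffected.

```python
import torch

ema_decay = 0.9
running = torch.tensor(0.5)  # previous moving average (plain tensor)

# A statistic still attached to the autograd graph, e.g. a loss value.
loss = torch.tensor(2.0, requires_grad=True) * 0.25

with torch.no_grad():
    # Operations inside no_grad are not recorded by autograd.
    running = ema_decay * running + (1.0 - ema_decay) * loss

print(running.requires_grad)  # False: the average carries no graph
print(loss.requires_grad)     # True: the loss itself is unaffected
```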
Raises: ValueError – If the input statistics are not in a valid format or do not match the expected tensor shape.