espnet2.gan_codec.shared.quantizer.modules.core_vq.ema_inplace
espnet2.gan_codec.shared.quantizer.modules.core_vq.ema_inplace(moving_avg, new, decay: float)
Update the moving average in-place using exponential decay.
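This function updates a moving-average tensor in place: the existing average is scaled by decay, and the new value, scaled by (1 - decay), is added to it, so that moving_avg becomes decay * moving_avg + (1 - decay) * new. The function returns nothing; the result is written into moving_avg. It is useful for maintaining a running average that weights recent values more heavily than older ones.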
- Parameters:
- moving_avg (torch.Tensor) – The tensor containing the current moving average, which will be updated in-place.
- new (torch.Tensor) – The new value to be incorporated into the moving average.
- decay (float) – The decay factor used to compute the weighted average. It should be in the range [0, 1), where values closer to 1 give more weight to the previous average and values closer to 0 give more weight to the new value.
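To illustrate how decay trades history against new values, the following minimal sketch applies the update rule to plain Python floats rather than tensors. The helpers ema_step and run are hypothetical names that are not part of ESPnet, and the sketch assumes the update is decay * avg + (1 - decay) * new, which is consistent with the first example below.

```python
def ema_step(avg: float, new: float, decay: float) -> float:
    """Hypothetical scalar form of the update: decay * avg + (1 - decay) * new."""
    return decay * avg + (1.0 - decay) * new


def run(decay: float, steps: int = 10) -> float:
    """Start at 0.0 and feed the constant value 1.0 for `steps` updates."""
    avg = 0.0
    for _ in range(steps):
        avg = ema_step(avg, 1.0, decay)
    return avg


slow = run(decay=0.99)  # heavy weight on history: moves slowly toward 1.0
fast = run(decay=0.5)   # heavy weight on new values: moves quickly toward 1.0
print(f"decay=0.99 -> {slow:.4f}, decay=0.5 -> {fast:.4f}")
```

Under this assumption, starting from 0 and feeding a constant input x gives x * (1 - decay ** n) after n updates, so a decay near 1 needs many more updates to approach x than a decay near 0.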
Examples
>>> moving_avg = torch.tensor([0.0])
>>> new_value = torch.tensor([1.0])
>>> decay_factor = 0.9
>>> ema_inplace(moving_avg, new_value, decay_factor)
>>> print(moving_avg) # Output: tensor([0.1])
>>> moving_avg = torch.tensor([0.5])
>>> new_value = torch.tensor([2.0])
>>> decay_factor = 0.8
>>> ema_inplace(moving_avg, new_value, decay_factor)
>>> print(moving_avg) # Output: tensor([0.8])
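The in-place behaviour can also be checked through a second reference to the same tensor. The sketch below uses a stand-in function, ema_inplace_sketch, which is a hypothetical name written for illustration and not the ESPnet source; it assumes the update rule decay * moving_avg + (1 - decay) * new and that the tensor being updated does not require gradients.

```python
import torch


def ema_inplace_sketch(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    """Illustrative stand-in: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    # mul_ and add_ modify moving_avg's storage directly; nothing is returned.
    moving_avg.mul_(decay).add_(new, alpha=1.0 - decay)


avg = torch.zeros(3)
alias = avg  # second reference to the same tensor object
result = ema_inplace_sketch(avg, torch.tensor([1.0, 2.0, 3.0]), decay=0.9)

# The alias sees the update because the tensor was modified in place.
print(alias)
```

Because the update is written into the existing tensor, any other holder of that tensor (for example a registered buffer in a module) observes the new average without reassignment.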