espnet2.gan_codec.shared.loss.loss_balancer.Balancer
class espnet2.gan_codec.shared.loss.loss_balancer.Balancer(total_norm: float = 1.0, ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12)
Bases: object
Balancer normalizes losses to stabilize neural network training by balancing their gradients. It uses an Exponential Moving Average (EMA) to track the average gradient norm of each loss across batches.
per_batch_item
If True, calculates norms for each batch item separately. If False, computes a single norm for the entire batch.
- Type: bool
total_norm
The target norm for the total gradient.
- Type: float
averager
An instance of the EMA class for computing moving averages of norms.
- Type: EMA
epsilon
A small value to prevent division by zero.
- Type: float
Parameters:
- total_norm (float) – The target total norm for the gradients. Default is 1.0.
- ema_decay (float) – The decay factor for the EMA. Default is 0.999.
- per_batch_item (bool) – Flag to determine if norms are calculated per batch item. Default is True.
- epsilon (float) – A small constant added to the denominator for numerical stability. Default is 1e-12.
Returns: A tuple of two dictionaries: the first contains the rebalanced losses, and the second contains the computed statistics, including gradient norms.
Return type: Tuple[Dict[str, torch.Tensor], Dict[str, float]]
Raises: ValueError – If the input tensor is empty or the losses dictionary is empty.
####### Examples
>>> import torch
>>> balancer = Balancer(total_norm=1.0, ema_decay=0.99)
>>> input_tensor = torch.randn(10, 3, requires_grad=True)
>>> # Losses must depend on input_tensor so gradients can be computed.
>>> losses = {'loss1': input_tensor.pow(2).mean(),
...           'loss2': input_tensor.abs().mean()}
>>> new_losses, stats = balancer(losses, input_tensor)
>>> print(new_losses)
>>> print(stats)
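As a hedged continuation (assuming the caller backpropagates the sum of the rebalanced losses, which the return signature suggests but does not guarantee):

>>> total = sum(new_losses.values())
>>> total.backward()  # assumption: the autograd graph is still available here
>>> print(balancer.metrics)  # inspect the statistics tracked so far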
NOTE
Losses must be provided as a dictionary mapping names to scalar tensors. The balancer computes the gradient of each loss with respect to the input tensor.
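For intuition, below is a minimal, self-contained sketch of the balancing idea in plain PyTorch. It is not the ESPnet implementation: the equal per-loss share of total_norm, the helper name balance, and the EMA bookkeeping are illustrative assumptions.

import torch

total_norm, ema_decay, epsilon = 1.0, 0.999, 1e-12
ema_norms = {}  # running EMA of each loss's gradient norm

def balance(losses, input_tensor):
    # Hypothetical helper, not part of ESPnet: rescales each loss so the
    # EMA of its gradient norm w.r.t. the shared input receives an equal
    # share of the total_norm budget.
    scaled = {}
    for name, loss in losses.items():
        (grad,) = torch.autograd.grad(loss, input_tensor, retain_graph=True)
        norm = grad.norm().item()
        prev = ema_norms.get(name, norm)
        ema_norms[name] = ema_decay * prev + (1 - ema_decay) * norm
        scale = total_norm / (len(losses) * (ema_norms[name] + epsilon))
        scaled[name] = scale * loss
    return scaled

x = torch.randn(10, 3, requires_grad=True)
losses = {"l1": x.abs().mean(), "l2": x.pow(2).mean()}
balanced = balance(losses, x)  # e.g. sum(balanced.values()).backward()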
property metrics
A dictionary containing the metrics computed by the balancer.
- Type: dict