espnet2.gan_codec.shared.quantizer.modules.distrib.sync_buffer
espnet2.gan_codec.shared.quantizer.modules.distrib.sync_buffer(buffers, average=True)
Synchronize buffers across distributed workers.
This function synchronizes the provided buffers among all workers in a distributed training environment. If the average parameter is set to True, the buffer values are averaged across all workers; otherwise, the buffers are broadcast from the source worker.
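The averaging/broadcast behavior described above can be sketched with plain torch.distributed primitives. This is only an illustration, not the actual ESPnet implementation; the helper name _sync_buffer_sketch is hypothetical, and floating-point buffers are assumed.

    import torch
    import torch.distributed as dist

    def _sync_buffer_sketch(buffers, average=True):
        # No-op when the distributed backend has not been initialized.
        if not (dist.is_available() and dist.is_initialized()):
            return
        world_size = dist.get_world_size()
        for buffer in buffers:
            if average:
                # Sum the buffer over all workers, then divide to obtain the mean.
                dist.all_reduce(buffer.data, op=dist.ReduceOp.SUM)
                buffer.data /= world_size
            else:
                # Copy the buffer from the source worker (rank 0) to all others.
                dist.broadcast(buffer.data, src=0)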
Parameters:
- buffers (Iterable[torch.Tensor]) – A collection of tensors to synchronize.
- average (bool) – If True, average the buffer values across workers; if False, broadcast them from the source worker. Defaults to True.
Returns: The function operates in place on the provided buffers.
Return type: None
Raises: RuntimeError – If the operation is attempted in a non-distributed environment.
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.quantizer.modules.distrib import sync_buffer
>>> buffer1 = torch.tensor([1.0, 2.0])
>>> buffer2 = torch.tensor([3.0, 4.0])
>>> # Assumes torch.distributed has been initialized across the workers.
>>> sync_buffer([buffer1, buffer2], average=True)
>>> print(buffer1)  # Buffer values are averaged across workers
>>> sync_buffer([buffer1, buffer2], average=False)
>>> print(buffer1)  # Buffer values are broadcast from the source worker
NOTE
This function requires that PyTorch’s distributed package is properly initialized and that the number of buffers is consistent across all workers to avoid deadlocks.
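For context, the following is a minimal, self-contained sketch of a two-process setup under which calling sync_buffer is well defined. The gloo backend, the address/port values, and the spawn-based launcher are illustrative assumptions, not ESPnet requirements.

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from espnet2.gan_codec.shared.quantizer.modules.distrib import sync_buffer

    def _worker(rank, world_size):
        # Illustrative single-machine setup; address, port, and backend are assumptions.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Every worker must pass the same number of buffers, in the same order.
        buffers = [torch.full((2,), float(rank))]
        sync_buffer(buffers, average=True)
        # After the call, each worker holds the cross-worker mean (0.5 for world_size=2).

        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = 2
        mp.spawn(_worker, args=(world_size,), nprocs=world_size)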