espnet2.gan_codec.shared.quantizer.modules.distrib.sync_grad
espnet2.gan_codec.shared.quantizer.modules.distrib.sync_grad(params)
Synchronize gradients across all distributed processes.
This function is a simpler alternative to DistributedDataParallel: it synchronizes gradients directly, without wrapping the model in additional machinery. For simple models it can be as efficient as DistributedDataParallel. Call it on the model parameters after the backward pass to ensure that gradients are synchronized across all workers.
- Parameters:
- params (Iterable[torch.Tensor]) – An iterable of model parameters, e.g. from `model.parameters()`.
- Returns: None
- Raises:
- RuntimeError – If the distributed environment is initialized but there is a mismatch in the number of parameters across workers.
Examples
>>> model = MyModel()
>>> loss = compute_loss(model(input), target)
>>> loss.backward() # Compute gradients
>>> sync_grad(model.parameters()) # Synchronize gradients
NOTE
This function assumes that the distributed environment has been properly initialized using PyTorch’s distributed package.
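The exact implementation is not reproduced here, but gradient synchronization of this kind typically amounts to an all-reduce over each parameter's gradient. The following is a minimal sketch under that assumption, using `torch.distributed` directly; `sync_grad_sketch` is a hypothetical name used only for illustration and is not part of the ESPnet API.

```python
import torch.distributed as dist

def sync_grad_sketch(params):
    """Illustrative all-reduce gradient sync (not the actual ESPnet code)."""
    if not dist.is_available() or not dist.is_initialized():
        return  # single-process run: nothing to synchronize
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    for p in params:
        if p.grad is None:
            continue
        # Sum this parameter's gradient across all workers, then average it.
        dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
        p.grad.data.div_(world_size)
```

Called right after `loss.backward()`, such a sketch averages each gradient across workers so that every process takes the same optimizer step, which is the behavior described above.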