espnet2.gan_codec.shared.quantizer.modules.distrib.broadcast_tensors
espnet2.gan_codec.shared.quantizer.modules.distrib.broadcast_tensors(tensors: Iterable[Tensor], src: int = 0)
Broadcast the given tensors from the source rank to all workers.
This function ensures that all workers in a distributed setting start from identical model parameters. It checks whether the distributed environment is initialized, filters the tensors to include only floating-point or complex types, and verifies that all workers hold the same number of parameters, since a mismatch would deadlock the broadcast operation. A minimal sketch of this logic follows the parameter list below.
- Parameters:
- tensors (Iterable[torch.Tensor]) – An iterable of tensors to broadcast.
- src (int, optional) – The source rank from which to broadcast the tensors. Defaults to 0.
- Returns: This function does not return a value.
- Return type: None
- Raises: RuntimeError – If the number of parameters differs across workers.
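For orientation, the following is a minimal sketch of the logic described above, not the actual ESPnet implementation; the helper `_check_number_of_params` is hypothetical here.

```python
import torch
import torch.distributed


def _check_number_of_params(tensors):
    # Hypothetical helper: broadcast the source rank's parameter count and
    # verify every worker holds the same number, so the broadcast loop below
    # cannot deadlock on mismatched ranks.
    count = torch.tensor([len(tensors)])  # a real implementation would place
    # this on a device compatible with the active backend
    torch.distributed.broadcast(count, src=0)
    if count.item() != len(tensors):
        raise RuntimeError(
            f"Mismatch in number of params: this worker has {len(tensors)}, "
            "the source rank has a different count."
        )


def broadcast_tensors_sketch(tensors, src=0):
    # No-op outside an initialized distributed environment.
    if not torch.distributed.is_initialized():
        return
    # Keep only floating-point or complex tensors.
    tensors = [t for t in tensors if t.is_floating_point() or t.is_complex()]
    _check_number_of_params(tensors)
    # Launch asynchronous broadcasts from `src`, then wait for all of them.
    handles = [
        torch.distributed.broadcast(t.data, src=src, async_op=True)
        for t in tensors
    ]
    for handle in handles:
        handle.wait()
```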
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.quantizer.modules.distrib import (
...     broadcast_tensors,
...     is_distributed,
... )
>>> tensor1 = torch.tensor([1.0], requires_grad=True)
>>> tensor2 = torch.tensor([2.0], requires_grad=True)
>>> if is_distributed():
...     broadcast_tensors([tensor1, tensor2], src=0)
>>> # All workers now hold tensor1 and tensor2 with the same values.
NOTE
This function is intended for use in a PyTorch distributed environment and requires the torch.distributed process group to be initialized.
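For context, one common way to initialize the process group before calling this function is sketched below; the backend choice and environment variables are illustrative, not prescribed by this module.

```python
import torch.distributed

# Illustrative setup; assumes a launcher such as torchrun has set
# RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT in the environment.
# Pick "nccl" for GPU training or "gloo" for CPU-only runs.
torch.distributed.init_process_group(backend="nccl")
```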