espnet2.train.distributed_utils.resolve_distributed_mode
espnet2.train.distributed_utils.resolve_distributed_mode(args)
Resolve the distributed training mode based on the provided arguments.
This function sets the args.distributed attribute according to how the distributed training environment is configured. Whether training runs in distributed mode depends on the number of nodes, the number of GPUs, and the launcher in use.
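A simplified sketch of that decision follows. It is not ESPnet's code: `sketch_resolve_distributed_mode` is a hypothetical name, it takes `dist_world_size` at face value instead of deriving it from launcher environment variables, and the single-GPU handling (clearing `multiprocessing_distributed`, pinning `local_rank` to 0) is an assumption.

```python
from types import SimpleNamespace


def sketch_resolve_distributed_mode(args) -> None:
    """Hypothetical, simplified stand-in for resolve_distributed_mode.

    Assumption: args.dist_world_size is already known; the launcher's
    environment variables are never consulted here.
    """
    world_size = args.dist_world_size or 1
    if args.multiprocessing_distributed:
        # Distributed if the world has several members or several GPUs are used
        args.distributed = world_size > 1 or args.ngpu > 1
        # Assumption: with at most one GPU there is nothing to spawn
        if args.ngpu <= 1:
            args.multiprocessing_distributed = False
        # Assumption: a single GPU always gets local rank 0
        if args.ngpu == 1:
            args.local_rank = 0
    else:
        # Processes started externally (e.g. by torchrun), one per rank
        args.distributed = world_size > 1


args = SimpleNamespace(
    multiprocessing_distributed=True,
    dist_world_size=1,
    ngpu=2,
    dist_rank=None,
    local_rank=None,
    dist_launcher=None,
)
sketch_resolve_distributed_mode(args)
print(args.distributed)  # True: world size 1, but two GPUs
```

Keeping the two branches separate mirrors the two ways processes can be started in this sketch: spawned by the trainer itself when `multiprocessing_distributed` is set, or started externally with one process per rank.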
Parameters: args –
An object containing the training arguments, which should include:
- args.multiprocessing_distributed: A boolean indicating whether multiprocessing distributed mode is enabled.
- args.dist_world_size: An integer specifying the total number of processes participating in the job.
- args.ngpu: An integer giving the number of GPUs to use (0 means CPU mode).
- args.dist_rank: An optional integer giving the rank of the current process.
- args.local_rank: An optional integer giving the local rank of the current process.
- args.dist_launcher: A string naming the launcher in use (e.g., "slurm", "mpi").
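The function only reads and writes attributes on `args`, so any attribute container can stand in for the hand-written class in the Examples section; in a normal training run it is typically the `argparse.Namespace` returned by the trainer's parser. A minimal sketch with a hypothetical helper, `make_dist_args`, whose default values are illustrative assumptions rather than ESPnet's defaults:

```python
import argparse
from typing import Optional


def make_dist_args(
    multiprocessing_distributed: bool = False,
    dist_world_size: Optional[int] = None,
    ngpu: int = 0,
    dist_rank: Optional[int] = None,
    local_rank: Optional[int] = None,
    dist_launcher: Optional[str] = None,
) -> argparse.Namespace:
    """Hypothetical helper bundling the attributes listed under Parameters."""
    return argparse.Namespace(
        multiprocessing_distributed=multiprocessing_distributed,
        dist_world_size=dist_world_size,
        ngpu=ngpu,
        dist_rank=dist_rank,
        local_rank=local_rank,
        dist_launcher=dist_launcher,
    )


# Example values; dist_rank is given explicitly rather than read from a launcher
args = make_dist_args(
    multiprocessing_distributed=True, dist_world_size=2, ngpu=4, dist_rank=0
)
print(args)
```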
Raises:
- RuntimeError – If the conditions for distributed training are not met, such as missing required arguments or misconfiguration.
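The misconfiguration cases boil down to a distributed run that cannot determine its ranks. The sketch below is hypothetical (`check_rank_args` is not an ESPnet function) and assumes ranks come only from the explicit attributes; the real function may also consult environment variables such as `RANK` and `LOCAL_RANK` or the launcher's own variables.

```python
from types import SimpleNamespace


def check_rank_args(args) -> None:
    """Hypothetical sketch of rank checks that raise RuntimeError.

    Assumption: only explicit attributes are consulted; environment
    variables and launcher-specific lookups are ignored.
    """
    if not args.distributed:
        return  # a single-process run needs no ranks
    world_size = args.dist_world_size or 1
    if args.multiprocessing_distributed:
        # Assumption: a rank is only required when the world size exceeds 1
        if world_size > 1 and args.dist_rank is None:
            raise RuntimeError("dist_rank must be set when dist_world_size > 1")
    else:
        # Each externally launched GPU process must know its local rank
        if args.ngpu > 0 and args.local_rank is None:
            raise RuntimeError("local_rank must be set for multi-process GPU runs")
        if args.dist_rank is None:
            raise RuntimeError("dist_rank must be set for multi-process runs")


args = SimpleNamespace(
    distributed=True,
    multiprocessing_distributed=False,
    dist_world_size=4,
    ngpu=1,
    dist_rank=0,
    local_rank=None,
)
try:
    check_rank_args(args)
except RuntimeError as err:
    # prints "misconfigured: local_rank must be set for multi-process GPU runs"
    print(f"misconfigured: {err}")
```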
Examples
>>> from espnet2.train.distributed_utils import resolve_distributed_mode
>>> class Args:
...     def __init__(self):
...         self.multiprocessing_distributed = True
...         self.dist_world_size = 4
...         self.ngpu = 2
...         self.dist_rank = None
...         self.local_rank = None
...         self.dist_launcher = "slurm"
...
>>> args = Args()
>>> resolve_distributed_mode(args)
>>> print(args.distributed)
True
NOTE
This function modifies the args object in place. After calling this function, args.distributed, args.local_rank, and other related attributes will be set according to the determined mode.