espnet2.torch_utils.device_funcs.force_gatherable
espnet2.torch_utils.device_funcs.force_gatherable(data, device)
Recursively convert an object into a form that torch.nn.DataParallel can gather.
This function walks through the input data structure and converts it so that DataParallel can gather it. Python int and float values are turned into 1-dimensional torch.Tensor objects, NumPy arrays are converted to tensors, and every tensor is moved to the specified device. DataParallel expects gathered outputs to be CUDA tensors with at least one dimension (or lists, tuples, and dicts of such tensors), so 0-dimensional tensors are expanded to shape (1,).
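The "at least one dimension" requirement comes from how outputs are gathered: per-replica results are concatenated along dimension 0, and in DataParallel a 0-dimensional tensor triggers a warning at that step. The CPU-only sketch below uses `torch.cat` as a stand-in for the gather step (a simplification for illustration; the real gather also moves results to the output device) to show why scalars need to be promoted first.

```python
import torch

# Per-replica outputs as 1-dim tensors of shape (1,): concatenation works.
replica_losses = [torch.tensor([0.25]), torch.tensor([0.75])]
gathered = torch.cat(replica_losses, dim=0)
print(gathered.shape)  # torch.Size([2])

# The same values as 0-dim tensors cannot be concatenated along dim 0.
scalar_losses = [torch.tensor(0.25), torch.tensor(0.75)]
try:
    torch.cat(scalar_losses, dim=0)
except RuntimeError as err:
    print("cat failed:", err)
```

Promoting each scalar to shape (1,) before returning it from a replica avoids this failure mode entirely.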
- Parameters:
- data – The input data: a tensor, NumPy array, int, float, None, or a list, tuple, set, or dictionary whose elements are any of these.
- device – The target device (e.g., ‘cuda:0’ or ‘cpu’) to which the tensors should be moved.
- Returns: An object with the same container structure as data, in which tensors, NumPy arrays, and numbers have been converted to tensors with at least one dimension on the specified device. None values are returned unchanged.
- Warns:
- UserWarning – If an element has an unsupported type. Such elements are returned unchanged and may not be gatherable by DataParallel.
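The conversion rules above can be summarized as a small recursive function. This is a simplified sketch written for illustration, not the library's source: the name `force_gatherable_sketch` is hypothetical, and it handles plain tuples only (named tuples are not special-cased).

```python
import warnings

import numpy as np
import torch


def force_gatherable_sketch(data, device):
    """Hypothetical, simplified illustration of the rules described above."""
    if isinstance(data, dict):
        # Recurse into dict values, keeping the keys.
        return {k: force_gatherable_sketch(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple, set)):
        # Rebuild the same container type from the converted elements.
        return type(data)(force_gatherable_sketch(v, device) for v in data)
    if isinstance(data, np.ndarray):
        # NumPy arrays become tensors, then follow the tensor branch.
        return force_gatherable_sketch(torch.from_numpy(data), device)
    if isinstance(data, torch.Tensor):
        # Promote 0-dim tensors to shape (1,) so they can be gathered.
        if data.dim() == 0:
            data = data[None]
        return data.to(device)
    if isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    if isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    if data is None:
        return None
    # Unknown types pass through unchanged, with a warning.
    warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
    return data


result = force_gatherable_sketch({"loss": torch.tensor(0.5), "n": 3}, "cpu")
print(result)
```

Note that a list of floats stays a list (of 1-element tensors) rather than being merged into a single tensor; only the leaves are converted.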
Examples
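>>> import torch
>>> from espnet2.torch_utils.device_funcs import force_gatherable
>>> data = [1.0, 2.0, 3.0]
>>> device = 'cuda:0'  # Example device; requires a CUDA GPU
>>> gatherable_data = force_gatherable(data, device)
>>> print(gatherable_data)
[tensor([1.], device='cuda:0'), tensor([2.], device='cuda:0'), tensor([3.], device='cuda:0')]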
>>> data_dict = {'a': 1, 'b': 2.0}
>>> gatherable_data_dict = force_gatherable(data_dict, device)
>>> print(gatherable_data_dict)
{'a': tensor([1], device='cuda:0'), 'b': tensor([2.], device='cuda:0')}