espnet2.torch_utils.device_funcs.to_device
espnet2.torch_utils.device_funcs.to_device(data, device=None, dtype=None, non_blocking=False, copy=False)
Change the device of object recursively.
This function recursively moves a data structure (dictionaries, lists, tuples, dataclasses, NumPy arrays, and PyTorch tensors) to the specified device, such as a CPU or GPU. It can optionally change the data type of the tensors as well.
- Parameters:
- data – The input data, which can be a tensor, numpy array, list, tuple, dictionary, or a dataclass.
- device – The target device to move the data to. This can be a string (e.g., 'cpu', 'cuda:0') or a torch.device object.
- dtype – The desired data type to convert the tensors to (e.g., torch.float, torch.int). If None, the dtype will remain unchanged.
- non_blocking – If True and the source is in pinned memory, the copy will be non-blocking. Defaults to False.
- copy – If True, a new tensor is always created, even when the tensor already has the requested device and dtype. Defaults to False.
- Returns: The input data, moved to the specified device with the desired dtype, if applicable.
Examples
>>> import torch
>>> from espnet2.torch_utils.device_funcs import to_device
>>> data = torch.tensor([1, 2, 3])
>>> to_device(data, device='cuda:0')
tensor([1, 2, 3], device='cuda:0')
>>> data = {'a': torch.tensor([1]), 'b': [torch.tensor([2]),
... torch.tensor([3])]}
>>> to_device(data, device='cuda:0')
{'a': tensor([1], device='cuda:0'),
'b': [tensor([2], device='cuda:0'), tensor([3], device='cuda:0')]}
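The recursive traversal described above can be sketched without PyTorch. This is an illustrative sketch only, not ESPnet's implementation: `map_leaves`, `FakeTensor`, and `Extra` are hypothetical names introduced here, and passing non-tensor leaves through unchanged is a choice made for this sketch.

```python
import dataclasses
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, values, device="cpu"):
        self.values = values
        self.device = device

    def to(self, device):
        # Return a new object on the target device; the original is untouched.
        return FakeTensor(self.values, device)


def map_leaves(data: Any, fn: Callable[[Any], Any]) -> Any:
    """Apply fn to every leaf of a nested container, rebuilding each container.

    Hypothetical helper for illustration; it is not part of ESPnet.
    """
    if isinstance(data, dict):
        return {k: map_leaves(v, fn) for k, v in data.items()}
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Rebuild the dataclass instance field by field.
        values = {
            f.name: map_leaves(getattr(data, f.name), fn)
            for f in dataclasses.fields(data)
        }
        return type(data)(**values)
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple constructors take their fields positionally.
        return type(data)(*(map_leaves(v, fn) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(map_leaves(v, fn) for v in data)
    return fn(data)


@dataclasses.dataclass
class Extra:
    feats: list


def move(leaf):
    # Only tensor-like leaves are moved; everything else is returned as-is.
    return leaf.to("cuda:0") if isinstance(leaf, FakeTensor) else leaf


batch = {
    "speech": FakeTensor([0.1, 0.2]),
    "pair": (FakeTensor([1]), FakeTensor([2])),
    "extra": Extra(feats=[FakeTensor([3])]),
    "utt_id": "utt1",
}
moved = map_leaves(batch, move)
```

The container types are preserved (the tuple stays a tuple, the dataclass stays an `Extra`), and the input `batch` is left as it was because each container is rebuilt instead of being mutated.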
NOTE
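If dtype is specified, conversion between int and float types is avoided to prevent unexpected behavior. For example, a request to convert an integer tensor to a floating-point dtype is skipped, and the tensor keeps its original dtype.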
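One way such a guard could be written is sketched below. It assumes dtypes are compared by family using their names (for example "float32" or "int64"); `resolve_dtype` is a hypothetical helper for illustration and is not taken from the library.

```python
from typing import Optional


def resolve_dtype(current: str, requested: Optional[str]) -> Optional[str]:
    """Return the dtype to convert to, or None to keep the current dtype.

    Hypothetical helper: dtypes are plain names such as "float32" or "int64".
    """
    if requested is None:
        return None
    # Convert only when both dtypes belong to the same family.
    same_family = ("int" in current and "int" in requested) or (
        "float" in current and "float" in requested
    )
    return requested if same_family else None


float_to_half = resolve_dtype("float32", "float16")  # same family: convert
int_to_float = resolve_dtype("int64", "float32")     # crosses families: skip
```

Under this assumption, a mixed batch of integer length tensors and floating-point feature tensors can be passed with a single floating-point dtype, and only the feature tensors would change precision.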
- Raises: ValueError – If an unsupported data type is provided.