espnet2.torch_utils.recursive_op.recursive_sum
espnet2.torch_utils.recursive_op.recursive_sum(obj, weight: Tensor, distributed: bool = False)
Recursively computes the weighted sum of elements in a nested structure.
This function handles nested tuples, lists, and dictionaries whose leaves are PyTorch tensors. It multiplies each tensor by the weight tensor and sums the result. When distributed is True, the summation is also performed in a distributed setting.
espnet2.torch_utils.recursive_op.obj
The input object which can be a nested structure containing tensors.
- Type: Union[tuple, list, dict, torch.Tensor, None]
espnet2.torch_utils.recursive_op.weight
A 1D tensor of weights to apply to the elements in obj.
- Type: torch.Tensor
espnet2.torch_utils.recursive_op.distributed
If True, perform distributed summation.
- Type: bool
Parameters:
- obj – The input object (tuple, list, dict, or tensor).
- weight – A 1D tensor of weights.
- distributed – A boolean indicating whether to perform distributed operations.
Returns: The weighted sum of each tensor in the input object, returned in the same nested structure as the input.
Return type: Union[tuple, list, dict, torch.Tensor, None]
Raises: ValueError – If the input object is of an unsupported type or if the dimensions of the tensors do not match.
Examples
>>> import torch
>>> from espnet2.torch_utils.recursive_op import recursive_sum
>>> weights = torch.tensor([0.1, 0.2, 0.3])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> recursive_sum(values, weights)
tensor(1.4000)
>>> weights = torch.tensor([1.0, 1.0])
>>> data = {'a': torch.tensor([1.0, 2.0]), 'b': torch.tensor([3.0, 4.0])}
>>> recursive_sum(data, weights)
{'a': tensor(3.), 'b': tensor(7.)}
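The recursion described on this page can be illustrated with a small stand-alone sketch. This is not the ESPnet source. `weighted_sum_sketch` is a hypothetical name, the sketch assumes every leaf tensor has the same shape as the 1D weight, and it leaves out the distributed branch entirely.

```python
import torch


def weighted_sum_sketch(obj, weight):
    """Hypothetical single-process stand-in for recursive_sum."""
    if weight.dim() != 1:
        raise ValueError(f"weight must be 1D, got shape {tuple(weight.shape)}")
    if isinstance(obj, (tuple, list)):
        # Rebuild the same container type around the reduced children.
        return type(obj)(weighted_sum_sketch(v, weight) for v in obj)
    if isinstance(obj, dict):
        return {k: weighted_sum_sketch(v, weight) for k, v in obj.items()}
    if isinstance(obj, torch.Tensor):
        # Assumption: each leaf has exactly the shape of the weight tensor.
        if obj.shape != weight.shape:
            raise ValueError(
                f"shape mismatch: {tuple(obj.shape)} vs {tuple(weight.shape)}"
            )
        # Elementwise product, then reduce to a scalar tensor.
        return (obj * weight.to(obj.dtype)).sum()
    if obj is None:
        return None
    raise ValueError(f"unsupported type: {type(obj)}")


# Illustrative data: a dict holding a tensor and a list with a None entry.
stats = {
    "loss": torch.tensor([1.0, 3.0]),
    "aux": [torch.tensor([2.0, 4.0]), None],
}
weight = torch.tensor([0.5, 0.5])
result = weighted_sum_sketch(stats, weight)
print(result)
```

The container types and the `None` entry pass through unchanged, so the result mirrors the shape of the input while each tensor leaf collapses to a scalar.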