espnet2.torch_utils.recursive_op.recursive_average
espnet2.torch_utils.recursive_op.recursive_average(obj, weight: Tensor, distributed: bool = False)
Calculates the weighted average of a nested structure of tensors.
For each tensor in the given object, this function multiplies the tensor elementwise by weight, sums the result, and divides that weighted sum by the sum of the weights. It can optionally aggregate the sums across processes in a distributed setting. It works recursively on nested lists, tuples, and dictionaries of PyTorch tensors and returns a result with the same nested structure.
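The computation described above can be sketched in plain PyTorch for the single-process case. This is a minimal illustration under stated assumptions, not the ESPnet source. The helpers `_weighted_sum`, `_divide` and `weighted_average_sketch` are hypothetical names. Every leaf tensor is assumed to have the same 1D shape as `weight`, `None` leaves are passed through, and the distributed path is omitted. This sketch raises `ValueError` on a shape mismatch or an unsupported leaf type.

```python
import torch

# Hypothetical helper names: a single-process sketch, not ESPnet's source.


def _weighted_sum(obj, weight: torch.Tensor):
    """Replace every leaf tensor x with sum_i x[i] * weight[i], keeping the structure."""
    if isinstance(obj, (list, tuple)):
        # Rebuild a container of the same type (list or tuple).
        return type(obj)(_weighted_sum(v, weight) for v in obj)
    if isinstance(obj, dict):
        return {k: _weighted_sum(v, weight) for k, v in obj.items()}
    if isinstance(obj, torch.Tensor):
        if obj.shape != weight.shape:
            raise ValueError(
                f"leaf shape {tuple(obj.shape)} != weight shape {tuple(weight.shape)}"
            )
        # Cast the weights to the leaf's dtype before the elementwise product.
        return (obj * weight.to(obj.dtype)).sum()
    if obj is None:
        return None  # None leaves are passed through untouched
    raise ValueError(f"unsupported type: {type(obj).__name__}")


def _divide(obj, denom: torch.Tensor):
    """Divide every (scalar) leaf tensor by denom, keeping the structure."""
    if isinstance(obj, (list, tuple)):
        return type(obj)(_divide(v, denom) for v in obj)
    if isinstance(obj, dict):
        return {k: _divide(v, denom) for k, v in obj.items()}
    if obj is None:
        return None
    return obj / denom.to(obj.dtype)


def weighted_average_sketch(obj, weight: torch.Tensor):
    """Each leaf becomes (sum_i x_i * w_i) / (sum_i w_i); also returns sum_i w_i."""
    if weight.dim() != 1:
        raise ValueError("weight must be a 1D tensor")
    total = weight.sum()
    return _divide(_weighted_sum(obj, weight), total), total


# Usage: a nested structure whose leaves all have the same shape as the weights.
weights = torch.tensor([1.0, 1.0, 2.0])
stats = {
    "loss": torch.tensor([1.0, 2.0, 3.0]),
    "extra": [torch.tensor([0.0, 4.0, 8.0]), None],
}
avg, total = weighted_average_sketch(stats, weights)
print(avg["loss"].item())      # 2.25 = (1*1 + 2*1 + 3*2) / 4
print(avg["extra"][0].item())  # 5.0  = (0*1 + 4*1 + 8*2) / 4
print(total.item())            # 4.0
```

Splitting the work into a weighted sum followed by a single division keeps the total weight available as a separate return value. That is what makes partial results combinable across processes.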
- Parameters:
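- obj (Union[torch.Tensor, List, Tuple, Dict]) – The input object to average. Containers may be nested; every leaf must be a tensor with the same shape as weight, or None (which is passed through unchanged).
- weight (torch.Tensor) – A 1D tensor of weights. Its shape must equal the shape of every leaf tensor in obj.
- distributed (bool, optional) – If True, the weighted sums and the total weight are summed across all processes with torch.distributed.all_reduce before dividing, so every process obtains the global average. This requires an initialized process group. Defaults to False.
- Returns: A tuple (average, total_weight). average has the same nested structure as obj, with each leaf tensor replaced by a scalar tensor holding its weighted average; total_weight is the sum of weight (across all processes when distributed is True).
- Return type: Tuple[Union[torch.Tensor, List, Tuple, Dict], torch.Tensor]
- Raises: ValueError – If obj contains a leaf that is not a tensor, list, tuple, dict, or None. A leaf tensor whose shape differs from weight, or a weight that is not 1D, fails an assertion (AssertionError) instead.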
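To see why a distributed average should combine sums rather than per-process averages, the following dependency-free sketch simulates two processes in plain Python. It assumes that, with distributed=True, the per-process weighted sums and total weights are each summed across processes before the division. The rank data and the helper names `local_sums`, `all_reduce_average` and `mean_of_local_averages` are hypothetical, and no `torch.distributed` call is made.

```python
from typing import List, Tuple

# Hypothetical per-rank data: (values, weights) as each process would see them.
RANKS: List[Tuple[List[float], List[float]]] = [
    ([1.0, 2.0], [1.0, 1.0]),            # rank 0
    ([3.0, 5.0, 7.0], [2.0, 2.0, 4.0]),  # rank 1
]


def local_sums(values: List[float], weights: List[float]) -> Tuple[float, float]:
    """Per-rank numerator sum_i x_i * w_i and denominator sum_i w_i."""
    return sum(x * w for x, w in zip(values, weights)), sum(weights)


def all_reduce_average(ranks) -> Tuple[float, float]:
    """Simulate a SUM all-reduce of numerator and denominator, then divide once."""
    partials = [local_sums(v, w) for v, w in ranks]
    num = sum(p[0] for p in partials)  # summed weighted sums across ranks
    den = sum(p[1] for p in partials)  # summed total weights across ranks
    return num / den, den


def mean_of_local_averages(ranks) -> float:
    """Naive alternative: average on each rank, then take an unweighted mean."""
    local_avgs = [n / d for n, d in (local_sums(v, w) for v, w in ranks)]
    return sum(local_avgs) / len(local_avgs)


avg, total = all_reduce_average(RANKS)
print(avg, total)                      # 4.7 10.0
print(mean_of_local_averages(RANKS))   # 3.5
```

Pooling numerators and denominators gives 47 / 10 = 4.7, the same value as a weighted average of all five samples on one machine. Averaging the two per-rank results (1.5 and 5.5) gives 3.5 instead, because it ignores that rank 1 carries more total weight.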
Examples
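>>> import torch
>>> from espnet2.torch_utils.recursive_op import recursive_average
>>> weights = torch.tensor([1.0, 1.0, 2.0])
>>> # Every leaf tensor must have the same shape as weights
>>> stats = {"loss": torch.tensor([1.0, 2.0, 3.0]), "acc": torch.tensor([0.0, 0.5, 1.0])}
>>> average, total_weight = recursive_average(stats, weights)
>>> print(average["loss"])  # Output: tensor(2.2500), i.e. (1*1 + 2*1 + 3*2) / 4
>>> print(average["acc"])  # Output: tensor(0.6250)
>>> print(total_weight)  # Output: tensor(4.)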
NOTE
This function is primarily intended for use in machine learning contexts, where weighted averages of loss values or predictions are common.