espnet2.enh.layers.complex_utils.einsum
espnet2.enh.layers.complex_utils.einsum(equation, *operands)
Perform Einstein summation convention on the input tensors.
This function computes the Einstein summation of the given operands according to the equation string. The operands can be either torch.Tensor or ComplexTensor, but the two types must not be mixed. Both real and complex tensor operations are supported.
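Einstein summation is linear in each operand, so a two-operand complex contraction can be assembled from real-valued einsums on the real and imaginary parts. The following is a minimal sketch of that idea using only torch.einsum. The helper name complex_einsum_from_parts is hypothetical, and this is not necessarily how ESPnet or torch_complex implement it internally.

```python
import torch


def complex_einsum_from_parts(equation, a_real, a_imag, b_real, b_imag):
    """Hypothetical helper: two-operand complex einsum built from real einsums.

    Uses (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br),
    which carries over to einsum because einsum is linear in each operand.
    """
    real = torch.einsum(equation, a_real, b_real) - torch.einsum(equation, a_imag, b_imag)
    imag = torch.einsum(equation, a_real, b_imag) + torch.einsum(equation, a_imag, b_real)
    return real, imag


torch.manual_seed(0)
ar, ai = torch.randn(2, 3), torch.randn(2, 3)
br, bi = torch.randn(3, 4), torch.randn(3, 4)
real, imag = complex_einsum_from_parts("ij,jk->ik", ar, ai, br, bi)

# Reference result computed with PyTorch's native complex dtype
reference = torch.einsum("ij,jk->ik", torch.complex(ar, ai), torch.complex(br, bi))
print(torch.allclose(real, reference.real, atol=1e-5))
print(torch.allclose(imag, reference.imag, atol=1e-5))
```

Splitting into four real einsums keeps every call on real tensors, which sidesteps any restriction on complex inputs at the cost of extra contractions.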
- Parameters:
- equation (str) – A string in Einstein summation notation. For example, ‘ij,jk->ik’ computes the matrix product of two 2-D tensors.
- *operands (Union[torch.Tensor, ComplexTensor]) – One or two tensors to operate on according to the equation. Alternatively, a single tuple or list of tensors can be passed.
- Returns: The result of the Einstein summation operation, which can be either a real tensor or a complex tensor depending on the input types.
- Return type: Union[torch.Tensor, ComplexTensor]
- Raises:
- ValueError – If the number of operands is not 1 or 2, or if complex and real tensor types are mixed.
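The operand rules listed above (a single tuple or list is unpacked, only one or two operands are accepted, and complex and real inputs are not mixed) can be sketched as a small validation step. This sketch checks native complex dtypes with torch.is_complex instead of ComplexTensor, and the helper name normalize_operands is hypothetical; it illustrates the documented contract rather than reproducing ESPnet's code.

```python
import torch


def normalize_operands(*operands):
    """Hypothetical helper mirroring the documented operand rules."""
    # A single tuple/list argument is unpacked into its tensors
    if len(operands) == 1 and isinstance(operands[0], (tuple, list)):
        operands = tuple(operands[0])
    if len(operands) not in (1, 2):
        raise ValueError(f"Expected 1 or 2 operands, got {len(operands)}")
    # Reject a mix of complex and real tensors
    kinds = {torch.is_complex(op) for op in operands}
    if len(kinds) > 1:
        raise ValueError("Mixing complex and real tensors is not supported")
    return operands


a = torch.randn(2, 3)
b = torch.randn(3, 4)
ops = normalize_operands([a, b])  # list form is unpacked into (a, b)
print(torch.einsum("ij,jk->ik", *ops).shape)  # torch.Size([2, 4])

try:
    normalize_operands(a, torch.randn(3, 4, dtype=torch.cfloat))
except ValueError as err:
    print(err)
```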
NOTE
Do not mix ComplexTensor and native complex tensors (torch.complex dtypes) in the input! Before PyTorch 1.9.0, torch.einsum did not support mixed complex and real inputs.
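One way to respect this note is to cast every operand to the same native complex dtype before contracting, so torch.einsum never receives a mixed input. A minimal sketch, assuming native complex tensors; the helper name to_common_complex is hypothetical. A ComplexTensor c could first be turned into a native tensor with torch.complex(c.real, c.imag).

```python
import torch


def to_common_complex(*tensors, dtype=torch.complex64):
    """Hypothetical helper: cast real and complex tensors to one complex dtype."""
    # Real tensors gain a zero imaginary part; complex tensors keep their values
    return [t.to(dtype) for t in tensors]


real_part = torch.randn(2, 3)                         # float32
complex_part = torch.randn(3, 4, dtype=torch.cfloat)  # complex64

x, y = to_common_complex(real_part, complex_part)
out = torch.einsum("ij,jk->ik", x, y)
print(out.dtype, out.shape)  # torch.complex64 torch.Size([2, 4])
```

Promoting to complex rather than demoting to real avoids silently discarding imaginary parts.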
Examples
>>> import torch
>>> from torch_complex.tensor import ComplexTensor
>>> from espnet2.enh.layers.complex_utils import einsum
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> einsum('ij,jk->ik', a, b).shape
torch.Size([2, 4])
>>> c = ComplexTensor(torch.randn(2, 3), torch.randn(2, 3))
>>> d = ComplexTensor(torch.randn(3, 4), torch.randn(3, 4))
>>> e = einsum('ij,jk->ik', c, d)
>>> isinstance(e, ComplexTensor), e.real.shape
(True, torch.Size([2, 4]))