espnet2.enh.layers.complex_utils.trace
espnet2.enh.layers.complex_utils.trace(a: Tensor | ComplexTensor)
Compute the trace of a tensor.
The trace of a matrix is the sum of the elements on its main diagonal. This function accepts both standard PyTorch tensors and ComplexTensor inputs. Because torch.trace() did not support batch processing before PyTorch 1.9.0, the function relies on FC.trace() (from torch_complex.functional) as a fallback.
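In symbols, the definition reads as follows for a square matrix $A$ of size $N \times N$. The trace is linear, so a complex matrix stored as separate real and imaginary parts $A_r$ and $A_i$ (as ComplexTensor does) can be traced part by part. The worked line applies this to the matrices used in the ComplexTensor example below.

```latex
\operatorname{tr}(A) = \sum_{i=1}^{N} A_{ii}

\operatorname{tr}(A_r + jA_i) = \operatorname{tr}(A_r) + j\,\operatorname{tr}(A_i)

\operatorname{tr}\!\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
 + j\,\operatorname{tr}\!\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}
 = (1 + 4) + j\,(5 + 8) = 5 + 13j
```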
- Parameters: a – A tensor or a ComplexTensor whose trace is to be computed.
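- Returns: The trace of the input. For a 2D input this is a scalar tensor; for a batched input of shape (..., N, N) it has shape (...). A ComplexTensor input yields a ComplexTensor.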
- Raises: TypeError – If the input is not a tensor or ComplexTensor.
Examples
>>> import torch
>>> from torch_complex.tensor import ComplexTensor
>>> from espnet2.enh.layers.complex_utils import trace
>>> a = torch.tensor([[1, 2], [3, 4]])
>>> trace(a)
tensor(5)
>>> c = ComplexTensor(torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 8]]))
>>> trace(c)
ComplexTensor(real=tensor(5), imag=tensor(13))
NOTE
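This function operates on the last two dimensions, which are expected to form a square matrix. A 2D input returns a scalar; any leading dimensions are treated as batch dimensions, with one trace computed per matrix.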
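The description mentions batch processing. The following is a minimal sketch of a batched trace in plain PyTorch, assuming an input of shape (..., N, N) whose leading dimensions are batch dimensions. It uses torch.diagonal rather than FC.trace, so it illustrates the technique and is not ESPnet's or torch_complex's implementation; batched_trace and complex_trace are hypothetical helper names, and the ValueError check is an added assumption.

```python
import torch


def batched_trace(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: trace over the last two dimensions of `a`.

    An input of shape (..., N, N) yields an output of shape (...).
    Illustration only, not ESPnet's implementation.
    """
    if a.dim() < 2 or a.size(-1) != a.size(-2):
        raise ValueError("expected a tensor of shape (..., N, N)")
    # torch.diagonal with dim1=-2, dim2=-1 returns each matrix's main
    # diagonal as a tensor of shape (..., N); summing the last axis
    # gives one trace per matrix.
    return torch.diagonal(a, dim1=-2, dim2=-1).sum(-1)


def complex_trace(real: torch.Tensor, imag: torch.Tensor):
    """Hypothetical helper: the trace is linear, so a complex matrix
    stored as (real, imag) parts has trace (tr(real), tr(imag))."""
    return batched_trace(real), batched_trace(imag)


if __name__ == "__main__":
    a = torch.tensor([[1, 2], [3, 4]])
    print(batched_trace(a))  # tensor(5)

    re, im = complex_trace(a, torch.tensor([[5, 6], [7, 8]]))
    print(re, im)  # tensor(5) tensor(13)

    # A batch of 3 random 4x4 matrices: compare with torch.trace,
    # which only accepts a single 2D matrix at a time.
    batch = torch.randn(3, 4, 4)
    per_matrix = torch.stack([torch.trace(m) for m in batch])
    print(torch.allclose(batched_trace(batch), per_matrix))  # True
```

The per-matrix loop over torch.trace serves only as a reference check; the diagonal-and-sum form handles any number of leading batch dimensions in a single call.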