espnet2.enh.layers.complex_utils.matmul
espnet2.enh.layers.complex_utils.matmul(a: Tensor | ComplexTensor, b: Tensor | ComplexTensor) → Tensor | ComplexTensor
Perform matrix multiplication on two tensors.
This function computes the matrix product of two input tensors, a and b. It supports both real and complex tensors. If either a or b is a ComplexTensor, the multiplication is delegated to the torch_complex library. Otherwise (for example, when both inputs are real torch.Tensor objects), the standard PyTorch torch.matmul function is used.
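For ComplexTensor inputs, a complex matrix product can be assembled from real matrix products on the real and imaginary parts: (A_r + iA_i)(B_r + iB_i) = (A_r B_r - A_i B_i) + i(A_r B_i + A_i B_r). The following is a minimal pure-Python sketch of that textbook decomposition. The helpers `real_matmul`, `add`, `sub`, and `split_complex_matmul` are hypothetical; the sketch illustrates the arithmetic and is not torch_complex's actual implementation.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def real_matmul(a: Matrix, b: Matrix) -> Matrix:
    """Ordinary matrix product of two real matrices stored as nested lists."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(x: Matrix, y: Matrix) -> Matrix:
    """Element-wise x + y."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def sub(x: Matrix, y: Matrix) -> Matrix:
    """Element-wise x - y."""
    return [[p - q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def split_complex_matmul(
    a_real: Matrix, a_imag: Matrix, b_real: Matrix, b_imag: Matrix
) -> Tuple[Matrix, Matrix]:
    """Multiply (a_real + i*a_imag) by (b_real + i*b_imag) via four real products."""
    real = sub(real_matmul(a_real, b_real), real_matmul(a_imag, b_imag))
    imag = add(real_matmul(a_real, b_imag), real_matmul(a_imag, b_real))
    return real, imag


# Same inputs as the complex example in the Examples section.
ar, ai = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
br, bi = [[1, 0], [0, 1]], [[0, 1], [1, 0]]
real, imag = split_complex_matmul(ar, ai, br, bi)
print(real)  # [[-5, -3], [-5, -3]]
print(imag)  # [[7, 7], [11, 11]]
```

Because the inputs match the complex example, `real` and `imag` are the values that `result_complex.real` and `result_complex.imag` should hold there.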
NOTE
Do not mix ComplexTensor and native PyTorch complex tensors (torch.complex dtypes) in the input tensors. Before PyTorch 1.9.0, torch.matmul did not support multiplication between complex and real tensors.
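On PyTorch versions where torch.matmul rejects a complex/real mix, one common workaround (our assumption, not documented behavior of this function) is to promote the real operand to complex with a zero imaginary part before multiplying, e.g. `torch.complex(b, torch.zeros_like(b))` for a floating-point tensor `b`. The pure-Python sketch below, using the hypothetical helpers `matmul_lists` and `promote`, checks why the promotion is harmless: it leaves the product unchanged.

```python
from typing import List, Union

Number = Union[int, float, complex]
Matrix = List[List[Number]]


def matmul_lists(a: Matrix, b: Matrix) -> Matrix:
    """Matrix product on nested lists; Python mixes int, float and complex freely."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def promote(m: Matrix) -> Matrix:
    """Give every entry of a real matrix an explicit zero imaginary part."""
    return [[complex(x, 0) for x in row] for row in m]


a = [[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]  # complex operand
b = [[5, 6], [7, 8]]                      # real operand

mixed = matmul_lists(a, b)               # complex x real
promoted = matmul_lists(a, promote(b))   # complex x complex only
print(mixed == promoted)  # True
```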
- Parameters:
- a – A tensor of type torch.Tensor or ComplexTensor.
- b – A tensor of type torch.Tensor or ComplexTensor.
- Returns: The result of the matrix multiplication: a ComplexTensor if either input is a ComplexTensor, otherwise a torch.Tensor.
- Raises:
- ValueError – If the inputs are not of type torch.Tensor or ComplexTensor.
Examples
>>> import torch
>>> from torch_complex.tensor import ComplexTensor
>>> from espnet2.enh.layers.complex_utils import matmul
>>> a = torch.tensor([[1, 2], [3, 4]])
>>> b = torch.tensor([[5, 6], [7, 8]])
>>> result = matmul(a, b)
>>> print(result)
tensor([[19, 22],
        [43, 50]])
>>> a_complex = ComplexTensor(torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 8]]))
>>> b_complex = ComplexTensor(torch.tensor([[1, 0], [0, 1]]),
... torch.tensor([[0, 1], [1, 0]]))
>>> result_complex = matmul(a_complex, b_complex)
>>> print(result_complex.real)
tensor([[-5, -3],
        [-5, -3]])
>>> print(result_complex.imag)
tensor([[ 7,  7],
        [11, 11]])