espnet2.enh.layers.complex_utils.solve
espnet2.enh.layers.complex_utils.solve(b: Tensor | ComplexTensor, a: Tensor | ComplexTensor)
Solve the linear equation ax = b.
This function computes the solution of the linear matrix equation ax = b, where a is a matrix and b is a vector or matrix. It handles both torch.Tensor and ComplexTensor types.
Note that mixing ComplexTensor and torch.complex is not allowed: pass either ComplexTensor inputs or native complex torch.Tensor inputs, not one of each. As of PyTorch 1.9.0, torch.solve does not support mixed inputs with complex and real tensors.
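One way to sidestep the mixed real/complex restriction noted above is to promote the real operand to the complex dtype before solving. The following is a minimal sketch using plain PyTorch (`torch.linalg.solve`) rather than ESPnet's `solve` itself; the variable names (`a_c`, `residual`) are illustrative, and the example assumes a PyTorch version that provides `torch.linalg.solve`.

```python
import torch

# Real coefficient matrix and complex right-hand side (illustrative values).
a = torch.tensor([[2.0, 1.0], [1.0, 3.0]])      # float32
b = torch.tensor([1.0 + 1.0j, 2.0 - 1.0j])      # complex64

# Promote the real matrix to the dtype of b so both operands are complex.
a_c = a.to(b.dtype)

# torch.linalg.solve takes the matrix first, then the right-hand side.
x = torch.linalg.solve(a_c, b)

# Sanity check: the residual a x - b should be close to zero.
residual = a_c @ x - b
print(x)
print(residual.abs().max())
```

Note the argument order: `torch.linalg.solve` expects the matrix first, whereas the `solve` function documented here takes the right-hand side `b` first.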
- Parameters:
  - b – The right-hand side of the equation (torch.Tensor or ComplexTensor).
  - a – The left-hand side of the equation (torch.Tensor or ComplexTensor).
- Returns: The solution x to the equation ax = b, which has the same type as b (torch.Tensor or ComplexTensor).
- Raises: ValueError – If the input tensors are not compatible for solving.
Examples
>>> import torch
>>> from espnet2.enh.layers.complex_utils import solve
>>> a = torch.tensor([[2, 1], [1, 3]], dtype=torch.float32)
>>> b = torch.tensor([1, 2], dtype=torch.float32)
>>> x = solve(b, a)
>>> print(x)
tensor([0.2000, 0.6000])
>>> from torch_complex.tensor import ComplexTensor
>>> a_complex = ComplexTensor(torch.tensor([[2., 1.], [1., 3.]]),
...                           torch.tensor([[0., 0.], [0., 0.]]))
>>> b_complex = ComplexTensor(torch.tensor([1., 2.]),
...                           torch.tensor([0., 0.]))
>>> x_complex = solve(b_complex, a_complex)
>>> print(x_complex)
ComplexTensor(real=tensor([0.2000, 0.6000]), imag=tensor([0., 0.]))