espnet2.enh.layers.complex_utils.complex_norm
espnet2.enh.layers.complex_utils.complex_norm(c: Tensor | ComplexTensor, dim=-1, keepdim=False) → Tensor
Compute the norm of a complex tensor.
This function computes the L2 (Euclidean) norm of a complex tensor along the given dimension. It accepts both PyTorch's native complex tensors and ComplexTensor objects from the torch_complex package. Any other input raises a TypeError.
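The computation can be sketched in plain PyTorch as the square root of the summed squared real and imaginary parts. The function name `reference_complex_norm` and its `eps` argument are hypothetical illustrations, not part of ESPnet. `torch.linalg.vector_norm` is used only as a cross-check.

```python
import torch


def reference_complex_norm(c: torch.Tensor, dim: int = -1, keepdim: bool = False,
                           eps: float = 0.0) -> torch.Tensor:
    """Hypothetical reference: L2 norm of a complex tensor from its real/imag parts."""
    if not torch.is_complex(c):
        # Mirror the documented behaviour: only complex inputs are accepted.
        raise TypeError("Input is not a complex tensor.")
    # |z|^2 = Re(z)^2 + Im(z)^2, summed along `dim`, then square-rooted.
    power = (c.real ** 2 + c.imag ** 2).sum(dim=dim, keepdim=keepdim)
    return torch.sqrt(power + eps)


c = torch.tensor([[1 + 3j, 2 + 4j], [3 + 0j, 4 + 0j]], dtype=torch.complex64)
print(reference_complex_norm(c))  # tensor([5.4772, 5.0000])
# Cross-check against PyTorch's built-in vector norm over the same dimension.
print(torch.allclose(reference_complex_norm(c), torch.linalg.vector_norm(c, dim=-1)))  # True
```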
- Parameters:
- c (Union[torch.Tensor, ComplexTensor]) – The input tensor. Must be a complex-valued torch.Tensor (e.g. of dtype torch.complex64) or a torch_complex ComplexTensor.
- dim (int, optional) – The dimension along which to compute the norm. Default: -1 (the last dimension).
- keepdim (bool, optional) – Whether to keep the reduced dimension in the output with size 1. Default: False.
- Returns: A real-valued tensor containing the norm of c along dim.
- Return type: torch.Tensor
- Raises: TypeError – If c is not a complex tensor.
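To illustrate how dim and keepdim determine the output shape, the sketch below uses `torch.linalg.vector_norm` as a stand-in for complex_norm (an assumption for illustration only; it follows the usual PyTorch reduction rules that these parameters describe). The 2 x 3 input of ones is arbitrary.

```python
import torch

# Arbitrary 2 x 3 complex input (e.g. 2 channels x 3 frequency bins).
c = torch.ones(2, 3, dtype=torch.complex64)

# Record the output shape for each combination of `dim` and `keepdim`.
shapes = {}
for dim in (-1, 0):
    for keepdim in (False, True):
        out = torch.linalg.vector_norm(c, dim=dim, keepdim=keepdim)
        shapes[(dim, keepdim)] = tuple(out.shape)

print(shapes)
# {(-1, False): (2,), (-1, True): (2, 1), (0, False): (3,), (0, True): (1, 3)}
```

With keepdim=True the reduced axis survives with size 1, which keeps the result broadcastable against the input (for example, when normalizing c by its own norm).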
Examples
>>> import torch
>>> from torch_complex import ComplexTensor
>>> from espnet2.enh.layers.complex_utils import complex_norm
>>> c_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.complex64)
>>> complex_norm(c_tensor)  # reduces over the last dimension by default
tensor([2.2361, 5.0000])
>>> c_tensor_custom = ComplexTensor(torch.tensor([1.0, 2.0]),
...                                 torch.tensor([3.0, 4.0]))
>>> complex_norm(c_tensor_custom, dim=0)
tensor(5.4772)
NOTE
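For ComplexTensor inputs, this function adds a small epsilon under the square root, so that computing the norm (and in particular its gradient) does not divide by zero when all elements along dim are zero.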
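The epsilon matters mainly for the backward pass: the derivative of sqrt(s) is 1 / (2 sqrt(s)), which is infinite at s = 0. The sketch below uses a hypothetical helper `norm_from_parts` and a hypothetical `EPS` value (the constant actually used in complex_utils may differ) to show the gradient becoming NaN without the epsilon and staying finite with it.

```python
import torch

EPS = 1e-8  # hypothetical value for illustration; ESPnet's constant may differ


def norm_from_parts(real: torch.Tensor, imag: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    """Hypothetical helper: sqrt(sum(re^2 + im^2)) with an optional epsilon under the root."""
    return torch.sqrt((real ** 2 + imag ** 2).sum() + eps)


grads = {}
for eps in (0.0, EPS):
    # An all-zero complex signal, split into real and imaginary parts.
    real = torch.zeros(3, requires_grad=True)
    imag = torch.zeros(3, requires_grad=True)
    norm_from_parts(real, imag, eps).backward()
    # Chain rule: (1 / (2 * sqrt(s))) * 2 * x. Without eps this is inf * 0 = NaN;
    # with eps the first factor is finite, so the gradient is 0.
    grads[eps] = real.grad
    print(f"eps={eps:g}: grad={real.grad.tolist()}")
```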