espnet2.enh.layers.complexnn.complex_cat
espnet2.enh.layers.complexnn.complex_cat(inputs, axis)
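Concatenate packed complex-valued tensors along a specified axis.
Each input is a real-valued tensor whose first half along axis holds the real part and whose second half holds the imaginary part. The function splits every input into these two halves along axis, concatenates all real halves, then all imaginary halves, and joins the two results along the same axis. The output therefore keeps the same [real | imaginary] layout as the inputs.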
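The behaviour described above can be sketched in plain PyTorch. This is a minimal sketch, assuming each input stores its real half followed by its imaginary half along `axis`; the name `complex_cat_sketch` is hypothetical and the code is an illustration, not the ESPnet source.

```python
import torch


def complex_cat_sketch(inputs, axis):
    """Hypothetical re-implementation of packed complex concatenation."""
    reals, imags = [], []
    for data in inputs:
        # Split each packed tensor into its real and imaginary halves along `axis`.
        r, i = torch.chunk(data, 2, dim=axis)
        reals.append(r)
        imags.append(i)
    # Concatenate all real halves, then all imaginary halves, and pack them again.
    real = torch.cat(reals, dim=axis)
    imag = torch.cat(imags, dim=axis)
    return torch.cat([real, imag], dim=axis)


x = torch.tensor([[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]])
out = complex_cat_sketch([x, x], axis=-1)
print(out.tolist())  # [[1.0, 2.0, 1.0, 2.0, 5.0, 6.0, 5.0, 6.0], [3.0, 4.0, 3.0, 4.0, 7.0, 8.0, 7.0, 8.0]]
```

Note that the split happens along the same axis as the concatenation, so choosing a different `axis` changes which entries are treated as real and imaginary parts.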
Parameters:
- inputs (list of torch.Tensor) – Packed complex tensors to concatenate. Each tensor must have an even size along axis (real half followed by imaginary half).
- axis (int) – The axis along which each input is split into real and imaginary halves and along which the halves are concatenated. Negative values count from the last dimension.
Returns: A single packed tensor whose first half along axis contains the concatenated real parts of all inputs and whose second half contains the concatenated imaginary parts.
Return type: torch.Tensor
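Because the returned tensor is real-valued, code that needs PyTorch's native complex dtype has to split it again. A minimal sketch, assuming the real half precedes the imaginary half along `axis`; `unpack_complex` is a hypothetical helper, not part of `espnet2.enh.layers.complexnn`.

```python
import torch


def unpack_complex(packed, axis):
    """Hypothetical helper: turn a packed [real | imag] tensor into a complex tensor."""
    real, imag = torch.chunk(packed, 2, dim=axis)
    # torch.complex expects both parts to share a floating-point dtype.
    return torch.complex(real, imag)


packed = torch.tensor([[1.0, 2.0, 1.0, 2.0, 5.0, 6.0, 5.0, 6.0],
                       [3.0, 4.0, 3.0, 4.0, 7.0, 8.0, 7.0, 8.0]])
z = unpack_complex(packed, axis=-1)
print(z.shape)  # torch.Size([2, 4])
print(z.dtype)  # torch.complex64
```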
Examples
>>> import torch
>>> from espnet2.enh.layers.complexnn import complex_cat
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # Real part
>>> b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])  # Imaginary part
>>> complex_tensor = torch.cat([a, b], dim=-1)  # Pack as [real | imag] along the last axis
>>> result = complex_cat([complex_tensor, complex_tensor], axis=-1)
>>> print(result)
tensor([[1., 2., 1., 2., 5., 6., 5., 6.],
        [3., 4., 3., 4., 7., 8., 7., 8.]])
NOTE
The input tensors must have the same shape along all dimensions except for the specified axis.
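These shape requirements, together with the even size along axis, can be checked before calling the function. A minimal sketch; `packed_cat_shape` is a hypothetical helper that validates the inputs and predicts the output shape. It does not call ESPnet.

```python
import torch


def packed_cat_shape(inputs, axis):
    """Hypothetical helper: validate packed inputs and predict the output shape."""
    ref = inputs[0].shape
    ndim = len(ref)
    axis = axis % ndim  # normalise negative axes such as -1
    total = 0
    for t in inputs:
        if t.dim() != ndim:
            raise ValueError("all inputs must have the same number of dimensions")
        if t.shape[axis] % 2 != 0:
            raise ValueError(f"size {t.shape[axis]} along axis {axis} is not even")
        for d in range(ndim):
            if d != axis and t.shape[d] != ref[d]:
                raise ValueError(f"size mismatch in dimension {d}")
        total += t.shape[axis]
    out = list(ref)
    out[axis] = total
    return torch.Size(out)


# Typical (batch, channels, time) feature maps with different channel counts.
a = torch.randn(2, 4, 100)
b = torch.randn(2, 6, 100)
print(packed_cat_shape([a, b], axis=1))  # torch.Size([2, 10, 100])
```

In the packed layout, the first 5 of the 10 output channels hold real parts (2 from `a`, then 3 from `b`) and the last 5 hold the matching imaginary parts.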