espnet2.enh.layers.complex_utils.cat
espnet2.enh.layers.complex_utils.cat(seq: Sequence[ComplexTensor | Tensor], *args, **kwargs)
Concatenate a sequence of tensors along a specified dimension.
This function concatenates a sequence of tensors (either ComplexTensor or torch.Tensor) into a single tensor. The input sequence must be a list or tuple. The backend is chosen from the type of the first element only: if seq[0] is a ComplexTensor, the call is delegated to the cat function of the torch_complex library (torch_complex.functional); otherwise, it falls back to PyTorch's torch.cat.
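The dispatch rule can be illustrated without PyTorch. In the minimal sketch below, PairComplex is a hypothetical stand-in for ComplexTensor (separate real and imaginary parts), plain 1-D Python lists stand in for torch.Tensor, and cat_sketch with its helpers are illustrative names, not part of ESPnet:

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class PairComplex:
    """Hypothetical stand-in for ComplexTensor: real and imaginary parts kept separately."""
    real: List[float]
    imag: List[float]


def _cat_plain(seq: Sequence[List[float]]) -> List[float]:
    # Stand-in for torch.cat on 1-D data: join the lists end to end.
    out: List[float] = []
    for item in seq:
        out.extend(item)
    return out


def _cat_pair(seq: Sequence[PairComplex]) -> PairComplex:
    # Stand-in for the torch_complex path: concatenate real and imaginary parts separately.
    return PairComplex(
        real=_cat_plain([x.real for x in seq]),
        imag=_cat_plain([x.imag for x in seq]),
    )


def cat_sketch(seq):
    # Same guard as the documented function: only lists and tuples are accepted.
    if not isinstance(seq, (list, tuple)):
        raise TypeError("seq must be a list or tuple, not " + type(seq).__name__)
    # Dispatch on the type of the FIRST element only.
    if isinstance(seq[0], PairComplex):
        return _cat_pair(seq)
    return _cat_plain(seq)


print(cat_sketch([[1, 2], [3]]))  # [1, 2, 3]
z = cat_sketch((PairComplex([1.0], [3.0]), PairComplex([5.0], [7.0])))
print(z)  # PairComplex(real=[1.0, 5.0], imag=[3.0, 7.0])
```

Because only the first element is inspected, this sketch does not validate the remaining elements, and an empty list fails with IndexError (from seq[0]) rather than TypeError.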
- Parameters:
- seq (Sequence[Union[ComplexTensor, torch.Tensor]]) – A sequence of tensors to concatenate. Must be a list or tuple.
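- *args – Additional positional arguments forwarded unchanged to the underlying concatenation function (for example, the dimension, as in cat(seq, 1)).
- **kwargs – Additional keyword arguments forwarded unchanged to the underlying concatenation function (for example, dim=1). If no dimension is given, the backend's default of 0 applies.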
- Returns: A tensor resulting from concatenating the input sequence.
- Return type: Union[ComplexTensor, torch.Tensor]
- Raises: TypeError – If seq is not a list or tuple.
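Since the dimension is not a named parameter, it reaches the backend only through *args or **kwargs. The sketch below uses hypothetical pure-Python helpers (cat_2d and forwarding_cat, not ESPnet functions) on nested lists instead of real tensors, to show that passing the dimension positionally or by keyword produces the same backend call:

```python
from typing import List, Sequence

Matrix = List[List[int]]


def cat_2d(seq: Sequence[Matrix], dim: int = 0) -> Matrix:
    # Hypothetical backend standing in for torch.cat on 2-D data.
    if dim == 0:
        # dim=0: stack the rows of every matrix one after another.
        return [list(row) for m in seq for row in m]
    if dim == 1:
        # dim=1: join matching rows side by side (row counts must agree).
        n_rows = len(seq[0])
        if any(len(m) != n_rows for m in seq):
            raise ValueError("all inputs need the same number of rows for dim=1")
        return [[x for m in seq for x in m[r]] for r in range(n_rows)]
    raise ValueError("this sketch only supports dim=0 or dim=1")


def forwarding_cat(seq: Sequence[Matrix], *args, **kwargs) -> Matrix:
    # Hypothetical wrapper: extra arguments are passed through untouched,
    # mirroring how cat hands *args/**kwargs to its backend.
    return cat_2d(seq, *args, **kwargs)


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(forwarding_cat([a, b]))         # [[1, 2], [3, 4], [5, 6], [7, 8]]
print(forwarding_cat([a, b], 1))      # [[1, 2, 5, 6], [3, 4, 7, 8]]
print(forwarding_cat([a, b], dim=1))  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```

With the real function the same pattern should hold: cat([a, b], 1) and cat([a, b], dim=1) both forward dim=1 to whichever backend the first element selects.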
Examples
>>> import torch
>>> from torch_complex import ComplexTensor
>>> from espnet2.enh.layers.complex_utils import cat
>>> a = torch.tensor([[1, 2], [3, 4]])
>>> b = torch.tensor([[5, 6], [7, 8]])
>>> result = cat([a, b])
>>> print(result)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
>>> c = ComplexTensor(torch.tensor([[1, 2]]), torch.tensor([[3, 4]]))
>>> d = ComplexTensor(torch.tensor([[5, 6]]), torch.tensor([[7, 8]]))
>>> result_complex = cat([c, d])
>>> print(result_complex.real)
tensor([[1, 2],
        [5, 6]])
>>> print(result_complex.imag)
tensor([[3, 4],
        [7, 8]])