espnet2.enh.layers.complex_utils.stack
espnet2.enh.layers.complex_utils.stack(seq: Sequence[ComplexTensor | Tensor], *args, **kwargs)
Stack tensors along a new dimension.
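This function takes a sequence of tensors of the same shape and stacks them along a new dimension. The sequence may contain either PyTorch tensors or ComplexTensor objects: a sequence of ComplexTensor objects yields a ComplexTensor, and a sequence of torch.Tensor objects yields a torch.Tensor.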
- Parameters:
- seq (Sequence[Union[ComplexTensor, torch.Tensor]]) – A sequence of tensors to be stacked. Must be either a list or a tuple.
- *args – Additional positional arguments (for example, `dim`) forwarded to the underlying stacking function.
- **kwargs – Additional keyword arguments (for example, `dim=1`) forwarded to the underlying stacking function.
- Returns: A new tensor formed by stacking the input tensors along a new dimension.
- Return type: Union[ComplexTensor, torch.Tensor]
- Raises: TypeError – If `seq` is not a list or tuple.
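To make the behaviour described above concrete without requiring PyTorch, here is a minimal, dependency-free sketch. It is not espnet2's implementation: `ComplexPair` and `stack_sketch` are hypothetical names, nested Python lists stand in for tensors, and only non-negative `dim` values are supported. It assumes the documented contract: reject anything that is not a list or tuple, stack a complex sequence by stacking its real and imaginary parts separately, and otherwise stack the inputs directly.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence, Union


# Hypothetical stand-in for a complex tensor: real and imaginary parts
# stored as equally shaped nested lists.
@dataclass
class ComplexPair:
    real: Any
    imag: Any


def _stack_nested(seq: Sequence[Any], dim: int) -> List[Any]:
    """Stack equally shaped nested lists along a new dimension `dim` (dim >= 0)."""
    if dim == 0:
        # The new axis goes first: the outer list simply holds each input.
        return list(seq)
    # Otherwise recurse into each index of the current outer dimension.
    return [_stack_nested([item[i] for item in seq], dim - 1) for i in range(len(seq[0]))]


def stack_sketch(seq: Sequence[Union[ComplexPair, list]], dim: int = 0):
    """Hypothetical mimic of the documented behaviour (not espnet2 code)."""
    if not isinstance(seq, (list, tuple)):
        raise TypeError("seq must be a list or tuple")
    if dim < 0:
        raise ValueError("this sketch only supports dim >= 0")
    if isinstance(seq[0], ComplexPair):
        # Stacking only rearranges elements, so a complex sequence can be
        # stacked by stacking real and imaginary parts independently.
        return ComplexPair(
            real=_stack_nested([c.real for c in seq], dim),
            imag=_stack_nested([c.imag for c in seq], dim),
        )
    return _stack_nested(seq, dim)
```

With `a = [[1, 2], [3, 4]]` and `b = [[5, 6], [7, 8]]`, `stack_sketch([a, b])` has the same nesting as the first tensor example, while `stack_sketch([a, b], dim=1)` pairs up rows instead: `[[[1, 2], [5, 6]], [[3, 4], [7, 8]]]`. Passing a generator rather than a list or tuple raises `TypeError`, mirroring the documented check.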
Examples
>>> import torch
>>> from torch_complex import ComplexTensor
>>> from espnet2.enh.layers.complex_utils import stack
>>> a = torch.tensor([[1, 2], [3, 4]])
>>> b = torch.tensor([[5, 6], [7, 8]])
>>> stacked_tensor = stack([a, b])
>>> print(stacked_tensor)
tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])
>>> c = ComplexTensor(torch.tensor([[1, 2]]), torch.tensor([[3, 4]]))
>>> d = ComplexTensor(torch.tensor([[5, 6]]), torch.tensor([[7, 8]]))
>>> stacked_complex = stack([c, d])
>>> print(stacked_complex)
ComplexTensor(real=tensor([[[1, 2]],
                           [[5, 6]]]),
              imag=tensor([[[3, 4]],
                           [[7, 8]]]))
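The shapes in the examples, (2, 2, 2) for two 2x2 tensors and (2, 1, 2) for two 1x2 complex tensors, follow the usual `torch.stack` rule: a new axis of length `len(seq)` is inserted at position `dim`. The helper below is a hypothetical sketch (`stacked_shape` is not part of espnet2) that computes that shape, assuming `stack` forwards `dim` unchanged to `torch.stack` or its complex counterpart. Negative `dim` values are normalized the way `torch.stack` does, so `dim=-1` appends the new axis at the end.

```python
from typing import Tuple


def stacked_shape(shape: Tuple[int, ...], n: int, dim: int = 0) -> Tuple[int, ...]:
    """Hypothetical helper: shape from stacking n tensors of `shape` along new axis `dim`."""
    ndim = len(shape)
    # The new axis can go anywhere from 0 to ndim, so valid dims are [-(ndim + 1), ndim].
    if not -(ndim + 1) <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-d inputs")
    if dim < 0:
        dim += ndim + 1  # e.g. dim=-1 on 2-d inputs becomes 2
    # Insert the stacking axis of length n at position dim.
    return shape[:dim] + (n,) + shape[dim:]
```

For instance, stacking five 3x4 tensors with `dim=1` would give shape `(3, 5, 4)` under this rule.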