espnet2.asr.transducer.rnnt_multi_blank.utils.rnnt_helper.flatten_tensor
espnet2.asr.transducer.rnnt_multi_blank.utils.rnnt_helper.flatten_tensor(x: Tensor)
Flatten a multi-dimensional tensor into a one-dimensional tensor.
This function reshapes the input into a one-dimensional tensor and also returns the input's original shape, so the result can later be restored with reshape(original_shape). This is useful when an operation expects 1-D input: flatten the tensor, process it, then restore its original layout.
- Parameters: x (torch.Tensor) – The input tensor to be flattened. It can have any shape.
- Returns: A tuple (flattened_tensor, original_shape), where flattened_tensor is the 1-D tensor and original_shape holds the dimensions of the input.
- Return type: Tuple[torch.Tensor, Tuple[int, …]]
Examples
>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.utils.rnnt_helper import flatten_tensor
>>> tensor = torch.tensor([[1, 2], [3, 4]])
>>> flattened_tensor, original_shape = flatten_tensor(tensor)
>>> print(flattened_tensor)
tensor([1, 2, 3, 4])
>>> print(tuple(original_shape))
(2, 2)
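The flatten, process, restore pattern described above can be sketched as follows. To keep the snippet runnable without ESPnet installed, it uses a hypothetical local stand-in, flatten_tensor_sketch, which assumes the documented function is equivalent to returning x.view(-1) together with x.shape. The 1-D operation (a running sum over all elements) is likewise only an illustration.

```python
import torch


def flatten_tensor_sketch(x: torch.Tensor):
    """Hypothetical stand-in for flatten_tensor.

    Assumption: the real function is equivalent to (x.view(-1), x.shape).
    """
    original_shape = x.shape
    return x.view(-1), original_shape


def running_sum_1d(v: torch.Tensor) -> torch.Tensor:
    """Illustrative operation that only accepts 1-D input."""
    if v.dim() != 1:
        raise ValueError("expected a 1-D tensor")
    return torch.cumsum(v, dim=0)


# A 2 x 3 x 4 tensor that we want to process as one flat buffer.
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

flat, shape = flatten_tensor_sketch(x)
processed = running_sum_1d(flat)     # running total over all 24 elements, row-major order
restored = processed.reshape(shape)  # back to the original 2 x 3 x 4 layout

print(flat.shape)      # torch.Size([24])
print(restored.shape)  # torch.Size([2, 3, 4])
```

Returning the original shape alongside the flat tensor is what makes the final reshape possible without the caller having to remember the dimensions separately.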
NOTE
The input must be a PyTorch tensor. The function changes only how the data is viewed: the flattened result shares storage with the input instead of copying it, and the underlying data is not altered.
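The view semantics in the note above can be checked with plain PyTorch. This sketch assumes flatten_tensor is built on Tensor.view(-1) (an assumption, not confirmed by this page) and therefore calls view directly. It shows that the flat tensor shares storage with the input, and that a non-contiguous input such as a transpose cannot be viewed as 1-D, while reshape falls back to a copy.

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])

# Assumption: flatten_tensor(x) behaves like x.view(-1).
flat = x.view(-1)

# No copy is made: both tensors start at the same storage address.
same_storage = flat.data_ptr() == x.data_ptr()

# Writing through the flat view is visible in the original tensor.
flat[0] = 100
print(x[0, 0].item())  # 100

# A transposed tensor is not contiguous, so view(-1) raises a RuntimeError.
xt = x.t()
try:
    xt.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape copies when a view is impossible and follows the transposed order.
flat_t = xt.reshape(-1)
print(flat_t.tolist())  # [100, 3, 2, 4]
```

If non-contiguous inputs are possible, calling .contiguous() before flattening (or using reshape) avoids the error at the cost of a copy.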