espnet2.gan_tts.utils.get_random_segments.get_random_segments
espnet2.gan_tts.utils.get_random_segments.get_random_segments(x: Tensor, x_lengths: Tensor, segment_size: int) → Tuple[Tensor, Tensor]
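Extract a random fixed-length segment from each item in a batch.
For each item, a start index is drawn at random so that the segment of segment_size time steps lies entirely within that item's valid length, as given by x_lengths. Padding beyond the valid length is therefore never sampled, which requires every valid length to be at least segment_size.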
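To make the sampling rule concrete, here is a minimal sketch in plain PyTorch. It is not the ESPnet source: the name `sketch_random_segments` is hypothetical, start indices are assumed to be drawn uniformly from {0, ..., x_lengths[i] - segment_size}, and every valid length is assumed to be at least `segment_size`.

```python
import torch
from typing import Tuple


def sketch_random_segments(
    x: torch.Tensor, x_lengths: torch.Tensor, segment_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical re-implementation for illustration only.

    Assumes x has shape (B, C, T) and every entry of x_lengths >= segment_size.
    """
    b, c, _ = x.size()
    # Largest start index that keeps the segment inside each item's valid region.
    max_start_idx = (x_lengths - segment_size).to(x.device)
    # Draw an integer uniformly from {0, ..., max_start_idx}:
    # floor of U[0, 1) * (max_start_idx + 1).
    start_idxs = (torch.rand(b, device=x.device) * (max_start_idx + 1)).long()
    # Copy each item's window into a preallocated (B, C, segment_size) tensor.
    segments = x.new_zeros(b, c, segment_size)
    for i, start in enumerate(start_idxs.tolist()):
        segments[i] = x[i, :, start : start + segment_size]
    return segments, start_idxs
```

The per-item loop keeps the sketch readable; only the time axis is sliced, so all channels of an item share the same window.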
- Parameters:
- x (torch.Tensor) – Input tensor of shape (B, C, T), where B is the batch size, C is the number of channels, and T is the number of time steps.
- x_lengths (torch.Tensor) – Length tensor of shape (B,) giving the valid (unpadded) length of each item in the batch.
- segment_size (int) – Number of time steps in each extracted segment.
- Returns: A tuple containing:
- Tensor: Segmented tensor of shape (B, C, segment_size).
- Tensor: Start index tensor of shape (B,), indicating the starting indices of the segments in the input tensor.
- Return type: Tuple[torch.Tensor, torch.Tensor]
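Because the start indices are returned, the same windows can be cut again from any other tensor that is aligned frame-by-frame with x. The sketch below is a vectorized way to do that with `torch.gather`; the helper name `slice_with_start_idxs` is hypothetical and not part of ESPnet, and it assumes every start index leaves room for `segment_size` steps.

```python
import torch


def slice_with_start_idxs(
    y: torch.Tensor, start_idxs: torch.Tensor, segment_size: int
) -> torch.Tensor:
    """Hypothetical helper: return y[i, :, s_i : s_i + segment_size] for every i.

    y has shape (B, C, T); start_idxs has shape (B,) and dtype long.
    """
    b, c, _ = y.size()
    # (B, segment_size) matrix of absolute time indices for each item.
    offsets = torch.arange(segment_size, device=y.device)
    idx = start_idxs.to(y.device).unsqueeze(1) + offsets.unsqueeze(0)
    # Broadcast the same indices over all channels, then gather along time.
    idx = idx.unsqueeze(1).expand(b, c, segment_size)
    return torch.gather(y, dim=2, index=idx)
```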
Examples
>>> import torch
>>> from espnet2.gan_tts.utils.get_random_segments import get_random_segments
>>> x = torch.randn(4, 2, 10)  # Example input tensor (B=4, C=2, T=10)
>>> x_lengths = torch.tensor([10, 9, 8, 7])  # Valid length of each item
>>> segment_size = 5
>>> segments, start_idxs = get_random_segments(x, x_lengths, segment_size)
>>> segments.shape
torch.Size([4, 2, 5])
>>> start_idxs.shape
torch.Size([4])