espnet2.enh.layers.dprnn.split_feature
espnet2.enh.layers.dprnn.split_feature(input, segment_size)
Split the input features into overlapping chunks of a specified segment size.
This function takes a (B, N, T) feature tensor and splits it along the time axis into segments of length `segment_size` with 50% overlap, i.e. consecutive segments start `segment_size // 2` frames apart (for an even `segment_size`). Before splitting, the input is zero-padded so that the segments tile the whole sequence and no input frame is dropped.
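To make the segmentation concrete, here is a minimal pure-Python sketch for a single channel (one batch item, one feature row). `split_1d` is a hypothetical name, not part of ESPnet. It assumes an even `segment_size`, a hop of `segment_size // 2`, and a padding rule assumed to mirror the library's internal one: `segment_size // 2` zeros at the front, and `rest` plus `segment_size // 2` zeros at the back.

```python
from typing import List, Sequence, Tuple


def split_1d(frames: Sequence[float], segment_size: int) -> Tuple[List[List[float]], int]:
    """Hypothetical sketch of 50%-overlap segmentation for one channel.

    Assumes an even segment_size; all padding is zeros.
    """
    stride = segment_size // 2  # hop between consecutive segment starts
    # Number of zeros appended at the end (assumed to match the library's rule).
    rest = segment_size - (stride + len(frames) % segment_size) % segment_size
    # Zero-pad: `stride` frames in front, `rest` + `stride` frames at the back.
    padded = [0.0] * stride + list(frames) + [0.0] * (rest + stride)
    # Windows start every `stride` frames; the last one ends exactly at the padded end.
    starts = range(0, len(padded) - segment_size + 1, stride)
    segments = [padded[s : s + segment_size] for s in starts]
    return segments, rest


if __name__ == "__main__":
    segs, rest = split_1d([1, 2, 3, 4, 5, 6, 7, 8, 9], segment_size=4)
    print(rest)  # 1
    for seg in segs:
        print(seg)
```

With nine frames and `segment_size=4`, every input frame lands in exactly two of the six windows. Under the same assumed rule, when `T % segment_size` equals `segment_size // 2` (for example T=10, `segment_size=4`), `rest` becomes a full `segment_size` and the trailing segments contain only padding.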
- Parameters:
- input (torch.Tensor) – The input tensor of shape (B, N, T), where B is the batch size, N is the number of features, and T is the sequence length.
- segment_size (int) – The size of each segment to split the input into.
- Returns: A tuple containing:
  - A tensor of shape (B, N, segment_size, K), where K is the number of segments created from the input.
  - An integer (`rest`) giving the number of zero frames appended to the end of the input along the time axis.
- Return type: Tuple[torch.Tensor, int]
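The returned `rest` is what allows the padding to be removed once the segments have been processed and put back together. The following is a minimal pure-Python sketch of that inverse step for a single channel. `merge_1d` is a hypothetical name; it assumes an even `segment_size`, windows starting `segment_size // 2` frames apart, `segment_size // 2` zeros of front padding, and `rest + segment_size // 2` zeros of back padding.

```python
from typing import List, Sequence


def merge_1d(segments: Sequence[Sequence[float]], rest: int, segment_size: int) -> List[float]:
    """Hypothetical inverse of 50%-overlap segmentation: overlap-add, then trim.

    Assumes an even segment_size and windows starting segment_size // 2 frames apart.
    """
    stride = segment_size // 2
    total = stride * (len(segments) - 1) + segment_size  # padded length
    out = [0.0] * total
    for k, seg in enumerate(segments):
        for i, value in enumerate(seg):
            out[k * stride + i] += value  # overlapping frames are summed
    # Drop the front pad, then the `rest` zeros and the back pad.
    return out[stride : total - stride - rest]


if __name__ == "__main__":
    # Segments of the frames 1..9 with segment_size=4 (rest=1).
    segs = [[0, 0, 1, 2], [1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8], [7, 8, 9, 0], [9, 0, 0, 0]]
    print(merge_1d(segs, rest=1, segment_size=4))
```

Because each frame appears in two windows, overlap-adding unmodified segments returns every frame doubled; whether to sum or average the overlaps is a design choice. The same module also provides a `merge_feature` function for the tensor case; this sketch is not its code.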
Examples
>>> import torch
>>> from espnet2.enh.layers.dprnn import split_feature
>>> input_tensor = torch.randn(2, 3, 10)  # (B=2, N=3, T=10)
>>> segments, rest = split_feature(input_tensor, segment_size=4)
>>> segments.shape  # (B, N, segment_size, K) with K=8
torch.Size([2, 3, 4, 8])
>>> rest
4
NOTE
The function zero-pads the time axis before splitting: it appends `rest` zeros at the end and `segment_size // 2` zeros at both ends, so that the overlapping segments cover the whole sequence exactly.