espnet2.diar.layers.tcn_nomask.Chomp1d
class espnet2.diar.layers.tcn_nomask.Chomp1d(chomp_size)
Bases: Module
Trims the padded input so that the output length matches the original input length.
This module removes a specified number of elements from the end of the input tensor, which is useful in convolutional architectures to maintain the desired output size after applying causal convolutions.
chomp_size
The number of elements to remove from the end of the input.
Type: int
Parameters:chomp_size (int) – The size of the chomp, i.e., the number of elements to discard from the end of the input tensor.
Returns: The output tensor with the last chomp_size elements removed.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.diar.layers.tcn_nomask import Chomp1d
>>> chomp = Chomp1d(chomp_size=2)
>>> input_tensor = torch.randn(5, 10, 20)  # [M, H, Kpad]
>>> output_tensor = chomp(input_tensor)
>>> output_tensor.shape  # [M, H, K]
torch.Size([5, 10, 18])
NOTE
This module is particularly useful when used in conjunction with depthwise separable convolutions to maintain the appropriate sequence length for subsequent layers.
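The length bookkeeping behind this note can be checked with plain arithmetic. The sketch below is illustrative only: the helper names `conv1d_output_length` and `causal_padding` are hypothetical and not part of ESPnet, and it assumes a stride-1 convolution padded symmetrically by `(kernel_size - 1) * dilation`, using the output-length formula given in the `torch.nn.Conv1d` documentation.

```python
def conv1d_output_length(length, kernel_size, dilation=1, padding=0, stride=1):
    """Output length of a 1-D convolution (formula from the torch.nn.Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def causal_padding(kernel_size, dilation=1):
    """Hypothetical helper: padding that a chomp of the same size undoes (stride 1)."""
    return (kernel_size - 1) * dilation


K = 20                         # unpadded input length
kernel_size, dilation = 3, 2   # example values, chosen for illustration
pad = causal_padding(kernel_size, dilation)

# The convolution pads both ends, so its output is longer than the input by `pad`.
k_pad = conv1d_output_length(K, kernel_size, dilation, padding=pad)

# Chomping `pad` trailing elements restores the original length.
chomp_size = pad
assert k_pad - chomp_size == K
```

Under these assumptions, setting `chomp_size` equal to the padding is what keeps the sequence length unchanged from layer to layer.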
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Removes the trailing padding so that the output length matches the unpadded input length.
This method removes chomp_size elements from the end of the input tensor along the last dimension, so the output length equals the input length minus chomp_size.
chomp_size
The number of elements to remove from the end of the input tensor.
Type: int
Parameters:x (torch.Tensor) – The input tensor of shape [M, H, Kpad], from which the last chomp_size elements along the final dimension are removed.
Returns: A tensor of shape [M, H, K] where K is the original length minus chomp_size.
####### Examples
>>> import torch
>>> from espnet2.diar.layers.tcn_nomask import Chomp1d
>>> chomp = Chomp1d(chomp_size=2)
>>> input_tensor = torch.randn(10, 5, 20)  # [M, H, Kpad]
>>> output_tensor = chomp(input_tensor)
>>> output_tensor.shape  # [M, H, K]
torch.Size([10, 5, 18])
NOTE
This module is typically used in conjunction with causal convolutions to ensure that the output length matches the expected dimensions for further processing in a network.
Raises:ValueError – If chomp_size is negative or greater than the input length.
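The behaviour described for forward(x) can be sketched as a small standalone module. This is an illustrative stand-in, not the ESPnet source: the class name `Chomp1dSketch` is hypothetical, and the explicit range check is an assumption that mirrors the Raises entry above rather than a confirmed detail of the implementation.

```python
import torch
from torch import nn


class Chomp1dSketch(nn.Module):
    """Illustrative stand-in for Chomp1d (hypothetical, not the ESPnet code)."""

    def __init__(self, chomp_size: int):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        k_pad = x.size(-1)
        # Assumed validation, written to match the documented Raises entry.
        if self.chomp_size < 0 or self.chomp_size > k_pad:
            raise ValueError(
                f"chomp_size must be in [0, {k_pad}], got {self.chomp_size}"
            )
        # Use an explicit end index: x[..., :-0] would select nothing,
        # whereas chomp_size == 0 should keep the whole tensor.
        return x[..., : k_pad - self.chomp_size].contiguous()


x = torch.randn(10, 5, 20)            # [M, H, Kpad]
y = Chomp1dSketch(chomp_size=2)(x)    # [M, H, K]
print(y.shape)  # torch.Size([10, 5, 18])
```

Slicing up to an explicit end index, rather than with a negative index, is a design choice of this sketch so that a chomp size of zero leaves the input unchanged.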