espnet2.enh.layers.tcn.Chomp1d
class espnet2.enh.layers.tcn.Chomp1d(chomp_size)
Bases: Module
To ensure the output length is the same as the input.
This module removes chomp_size elements from the end of the input tensor along the last (time) dimension. In a temporal convolutional block, the causal convolution is padded with chomp_size extra frames, so chomping the output restores the original, unpadded sequence length.
chomp_size
The number of elements to remove from the end of the input tensor.
Type: int
Parameters:chomp_size (int) – The number of elements to be removed from the end of the input tensor.
Returns: A tensor of shape [M, H, K], where K is the original length minus chomp_size.
Return type: torch.Tensor
####### Examples
>>> chomp = Chomp1d(chomp_size=2)
>>> input_tensor = torch.randn(1, 3, 10) # [M, H, Kpad]
>>> output_tensor = chomp(input_tensor)
>>> output_tensor.shape
torch.Size([1, 3, 8]) # [M, H, K] after chomp
NOTE
The input tensor is expected to have more than chomp_size elements in the last dimension. If the input tensor’s last dimension is not larger than chomp_size, the slicing produces an empty (zero-length) output rather than a useful one.
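The behavior described above can be sketched as a small nn.Module that slices off the trailing chomp_size frames. This is a minimal sketch consistent with the shapes documented here; the actual espnet2 implementation may differ in detail.

```python
import torch
import torch.nn as nn


class Chomp1d(nn.Module):
    """Remove the last `chomp_size` steps along the time axis."""

    def __init__(self, chomp_size: int):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [M, H, Kpad] -> [M, H, Kpad - chomp_size]
        return x[:, :, : -self.chomp_size].contiguous()


chomp = Chomp1d(chomp_size=2)
out = chomp(torch.randn(1, 3, 10))  # [M, H, Kpad] = [1, 3, 10]
print(tuple(out.shape))             # (1, 3, 8)
```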
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Keep this API same with TasNet.
This method removes the last chomp_size elements of the input along the time dimension, discarding the extra frames introduced by causal padding.
- Parameters:x – A tensor of shape [M, H, Kpad], where: M (int): Batch size. H (int): Number of channels. Kpad (int): Padded length of the input sequence.
- Returns: A tensor of shape [M, H, K], where K = Kpad - chomp_size.
- Return type: torch.Tensor
####### Examples
>>> chomp = Chomp1d(chomp_size=2)
>>> x = torch.randn(8, 16, 102)  # [M, H, Kpad]
>>> chomp(x).shape
torch.Size([8, 16, 100])  # [M, H, K]
NOTE
Chomp1d is used inside the temporal convolutional blocks so that the output length of a padded dilated convolution matches its input length, keeping the convolution causal.
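The pairing of a padded dilated convolution with Chomp1d can be illustrated as follows. This is a hypothetical usage sketch of the standard TCN pattern, not espnet2 code: the convolution is padded by (kernel_size - 1) * dilation, and chomping that many frames restores the original length.

```python
import torch
import torch.nn as nn


class Chomp1d(nn.Module):
    """Remove the last `chomp_size` steps along the time axis (as above)."""

    def __init__(self, chomp_size: int):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, :, : -self.chomp_size].contiguous()


kernel_size, dilation = 3, 2
pad = (kernel_size - 1) * dilation  # extra frames added by padding

conv = nn.Conv1d(16, 16, kernel_size, padding=pad, dilation=dilation)

x = torch.randn(8, 16, 100)   # [M, H, K]
y = Chomp1d(pad)(conv(x))     # conv output is [8, 16, 104]; chomp restores K
print(tuple(y.shape))         # (8, 16, 100)
```

Without the chomp, the symmetric padding of nn.Conv1d would leave the output longer than the input and let frames depend on future samples; slicing the tail keeps the convolution causal and length-preserving.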