espnet2.enh.layers.tcn.DepthwiseSeparableConv
class espnet2.enh.layers.tcn.DepthwiseSeparableConv(in_channels, out_channels, skip_channels, kernel_size, stride, padding, dilation, norm_type='gLN', causal=False)
Bases: Module
Depthwise Separable Convolution Layer.
This module implements depthwise separable convolution, which factorizes a standard convolution into a depthwise convolution followed by a pointwise (1x1) convolution. The depthwise convolution applies a single filter to each input channel independently, while the pointwise convolution mixes the per-channel outputs across channels.
skip_conv
A convolutional layer for skip connections if skip_channels is not None.
Type: nn.Conv1d or None
Parameters:
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- skip_channels (int or None) – Number of channels for skip connection.
- kernel_size (int) – Size of the convolution kernel.
- stride (int) – Stride of the convolution.
- padding (int) – Padding added to both sides of the input (see the padding note after this list).
- dilation (int) – Dilation factor for the convolution.
- norm_type (str) – Type of normalization to apply (e.g., ‘gLN’, ‘cLN’).
- causal (bool) – Whether to use causal convolution.
Returns: A tuple (res_out, skip_out):
- res_out (torch.Tensor) – Output tensor after pointwise convolution.
- skip_out (torch.Tensor or None) – Output tensor for skip connection if skip_channels is not None, otherwise None.
Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
####### Examples
>>> import torch
>>> from espnet2.enh.layers.tcn import DepthwiseSeparableConv
>>> conv = DepthwiseSeparableConv(
... in_channels=16,
... out_channels=32,
... skip_channels=8,
... kernel_size=3,
... stride=1,
... padding=1,
... dilation=1,
... norm_type='gLN',
... causal=False
... )
>>> x = torch.randn(10, 16, 50) # Batch size of 10, 16 channels, length 50
>>> res_out, skip_out = conv(x)
>>> res_out.shape
torch.Size([10, 32, 50])
>>> skip_out.shape
torch.Size([10, 8, 50])
forward(x)
Perform the forward pass of the depthwise separable convolution.
This method processes the input through the depthwise convolution (with causal trimming when causal=True; see the sketch after this list), the configured normalization, and the pointwise convolution. When skip_channels is not None, an additional pointwise convolution produces the skip-connection output.
- Parameters: x (torch.Tensor) – Input tensor of shape [M, N, K], where: M (int): The batch size. N (int): The number of input channels (in_channels). K (int): The length of the input sequences.
- Returns: A tuple (res_out, skip_out), where res_out (torch.Tensor) is the output tensor after the pointwise convolution, and skip_out (torch.Tensor or None) is the skip-connection output, or None if skip_channels is None.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
####### Examples
>>> conv = DepthwiseSeparableConv(16, 32, 8, kernel_size=3, stride=1,
...     padding=2, dilation=2, norm_type='gLN', causal=False)
>>> x = torch.randn(10, 16, 50)  # Batch size of 10, 16 channels, length 50
>>> res_out, skip_out = conv(x)
>>> res_out.shape, skip_out.shape
(torch.Size([10, 32, 50]), torch.Size([10, 8, 50]))