espnet2.spk.encoder.ska_tdnn_encoder.SKAttentionModule
class espnet2.spk.encoder.ska_tdnn_encoder.SKAttentionModule(channel=128, reduction=4, L=16, num_kernels=2)
Bases: Module
Selective Kernel Attention Module.
This module implements the Selective Kernel Attention mechanism, which allows the model to adaptively select kernel responses based on the importance of features. It enhances the representation capability by dynamically weighting the outputs from multiple convolutional kernels.
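The selection works in a squeeze-and-excite style: the branch outputs are fused, pooled to a channel descriptor, compressed to dimension D, expanded back to one weight vector per kernel, and normalized with a softmax across kernels before the weighted sum. A minimal PyTorch sketch of that flow (the class name SKAttentionSketch, the exact tensor flow, and D = max(L, channel // reduction) are illustrative assumptions, not a copy of the ESPnet source):

```python
import torch
import torch.nn as nn


class SKAttentionSketch(nn.Module):
    """Illustrative selective-kernel weighting; not the ESPnet implementation."""

    def __init__(self, channel=128, reduction=4, L=16, num_kernels=2):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        d = max(L, channel // reduction)      # intermediate dimension D (assumed)
        self.fc = nn.Linear(channel, d)       # squeeze to compact descriptor
        self.relu = nn.ReLU()
        self.fcs = nn.ModuleList(
            [nn.Linear(d, channel) for _ in range(num_kernels)]
        )                                     # one excitation head per kernel
        self.softmax = nn.Softmax(dim=0)      # normalize across kernels

    def forward(self, x, convs):
        feats = torch.stack([conv(x) for conv in convs], dim=0)  # (K, B, C, T)
        u = feats.sum(dim=0)                                     # fuse branches
        s = self.avg_pool(u).squeeze(-1)                         # (B, C)
        z = self.relu(self.fc(s))                                # (B, D)
        attn = torch.stack([fc(z) for fc in self.fcs], dim=0)    # (K, B, C)
        attn = self.softmax(attn).unsqueeze(-1)                  # (K, B, C, 1)
        return (feats * attn).sum(dim=0) + x                     # weighted sum + residual
```

The softmax over the kernel axis is what makes the weighting a selection: for each channel, the branch weights sum to one.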
avg_pool
Adaptive average pooling layer.
- Type: nn.AdaptiveAvgPool1d
D
Dimension of the intermediate representation.
- Type: int
fc
Fully connected layer for dimensionality reduction.
- Type: nn.Linear
relu
ReLU activation function.
- Type: nn.ReLU
fcs
List of fully connected layers for attention weights.
- Type: nn.ModuleList
softmax
Softmax layer for normalizing attention weights.
- Type: nn.Softmax
Parameters:
- channel (int) – Number of input channels.
- reduction (int) – Reduction ratio for dimensionality.
- L (int) – Minimum dimension of the intermediate representation (lower bound for D).
- num_kernels (int) – Number of convolutional kernels to use.
####### Examples
>>> sk_attention = SKAttentionModule(channel=128, reduction=4, L=16, num_kernels=2)
>>> input_tensor = torch.randn(32, 128, 50) # (Batch, Channels, Time)
>>> convs = [nn.Conv1d(128, 128, kernel_size=3, padding=1) for _ in range(2)]
>>> output = sk_attention(input_tensor, convs)
>>> output.shape
torch.Size([32, 128, 50])
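In practice the entries of convs usually differ in receptive field (e.g. kernel size or dilation) so that the attention has distinct branches to weight; a variant of the example above with kernel sizes 3 and 5 (the sizes are illustrative, not prescribed by the module):

>>> convs = [
...     nn.Conv1d(128, 128, kernel_size=3, padding=1),
...     nn.Conv1d(128, 128, kernel_size=5, padding=2),
... ]
>>> output = sk_attention(input_tensor, convs)
>>> output.shape
torch.Size([32, 128, 50])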
NOTE
This module requires a list of convolutional layers to be passed during the forward call for the attention mechanism to work.
- Raises: ValueError – If the input tensor does not match the expected dimensions.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x, convs)
Forward function.
The forward method passes the input through the supplied convolutional layers, weights their outputs with the selective kernel attention, and adds a residual connection to the input.
- Parameters:
  - x (torch.Tensor) – Input tensor of shape (B, C, T), where B is the batch size, C is the number of channels, and T is the length of the sequence.
  - convs (list of torch.nn.Module) – Convolutional layers whose outputs are fused by the attention mechanism.
- Returns: Output tensor of shape (B, C, T) after applying the SK attention mechanism and residual connection.
- Return type: torch.Tensor
####### Examples
>>> model = SKAttentionModule(channel=128, reduction=4)
>>> convs = [nn.Conv1d(128, 128, kernel_size=3, padding=1) for _ in range(2)]
>>> input_tensor = torch.randn(10, 128, 50)  # Example input
>>> output_tensor = model(input_tensor, convs)
>>> print(output_tensor.shape)  # Output shape will be (10, 128, 50)