espnet2.spk.encoder.ska_tdnn_encoder.fwSKAttention
class espnet2.spk.encoder.ska_tdnn_encoder.fwSKAttention(freq=40, channel=128, kernels=[3, 5], receptive=[3, 5], dilations=[1, 1], reduction=8, groups=1, L=16)
Bases: Module
Frequency-wise Selective Kernel Attention (fwSKAttention) module.
This module applies frequency-wise selective kernel attention to the input tensor, enabling the model to focus on different frequency bands through multiple convolutional kernels. It combines the outputs from these kernels using learned attention weights.
- Parameters:
- freq (int) – The number of frequency bins in the input tensor.
- channel (int) – The number of channels in the input tensor.
- kernels (list of int) – List of kernel sizes for the convolutions.
- receptive (list of int) – List of receptive field sizes for the convolutions.
- dilations (list of int) – List of dilation rates for the convolutions.
- reduction (int) – Reduction ratio for the attention mechanism.
- groups (int) – Number of groups for group convolution.
- L (int) – Maximum value for the hidden dimension in the attention mechanism.
- Returns: The output tensor V after applying the selective kernel attention.
- Return type: torch.Tensor
####### Examples
>>> model = fwSKAttention(freq=40, channel=128)
>>> input_tensor = torch.randn(8, 128, 40, 100) # (B, C, F, T)
>>> output = model(input_tensor)
>>> output.shape
torch.Size([8, 128, 40, 100])
NOTE
The input tensor should have a shape of (B, C, F, T), where B is the batch size, C is the number of channels, F is the number of frequency bins, and T is the number of time steps.
- Raises: ValueError – If the input tensor does not have the expected shape.
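The fusion performed by this module follows the selective-kernel pattern: sum the branch outputs, pool them frequency-wise, project through a reduced hidden dimension, and softmax across branches to obtain per-frequency attention weights. A minimal NumPy sketch of that arithmetic is shown below; the random arrays stand in for the convolutional branch outputs, and the weight matrices (`W_reduce`, `W_expand`) are hypothetical illustrations, not the module's actual parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

B, C, F, T = 2, 8, 40, 50   # batch, channels, frequency bins, time steps
K = 2                        # number of kernel branches (len(kernels))
reduction, L = 8, 16
d = max(F // reduction, L)   # hidden dim, floored at L

# Stand-ins for the K convolutional branch outputs (same shape as the input).
branches = [rng.standard_normal((B, C, F, T)) for _ in range(K)]

# 1) Fuse the branches by summation.
U = sum(branches)                      # (B, C, F, T)

# 2) Frequency-wise descriptor: average over channels and time.
s = U.mean(axis=(1, 3))                # (B, F)

# 3) Reduce to the hidden dimension (hypothetical weights, ReLU).
W_reduce = rng.standard_normal((F, d))
z = np.maximum(s @ W_reduce, 0.0)      # (B, d)

# 4) One expansion per branch back to F frequency bins.
W_expand = [rng.standard_normal((d, F)) for _ in range(K)]
logits = np.stack([z @ W for W in W_expand], axis=0)   # (K, B, F)

# 5) Softmax across branches, independently for each frequency bin.
e = np.exp(logits - logits.max(axis=0, keepdims=True))
attn = e / e.sum(axis=0)               # (K, B, F); sums to 1 over K

# 6) Weighted sum of branch outputs, broadcasting attention over C and T.
V = sum(a[:, None, :, None] * u for a, u in zip(attn, branches))
print(V.shape)  # (2, 8, 40, 50)
```

Note how the attention weights have shape (B, F) per branch, so each frequency bin selects its own mix of kernel branches, while channels and time steps share that choice.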
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Implements frequency-wise selective kernel attention.
This module applies a series of convolutional layers with varying kernel sizes to the input tensor, followed by a frequency-wise attention mechanism. The attention mechanism helps the model to focus on the most relevant frequency components in the input data, which can be particularly useful in tasks like speaker verification.
- Parameters:
  - x (torch.Tensor) – Input tensor of shape (B, C, F, T), where B is the batch size, C the number of channels, F the number of frequency bins, and T the number of time steps.
- Returns: The output tensor after applying the attention mechanism.
- Return type: torch.Tensor
####### Examples
>>> attention = fwSKAttention(freq=40, channel=128)
>>> input_tensor = torch.randn(8, 128, 40, 100) # (B, C, F, T)
>>> output_tensor = attention(input_tensor)
>>> output_tensor.size()
torch.Size([8, 128, 40, 100])
NOTE
The input tensor is expected to have 4 dimensions: batch size, channels, frequency, and time.
- Raises: ValueError – If the input tensor does not have 4 dimensions.