espnet2.spk.encoder.ska_tdnn_encoder.SEModule
class espnet2.spk.encoder.ska_tdnn_encoder.SEModule(channels, bottleneck=128)
Bases: Module
Squeeze-and-Excitation (SE) Module.
This module implements the Squeeze-and-Excitation block, which adaptively recalibrates channel-wise feature responses by explicitly modeling the interdependencies between channels. It uses global information to enhance the representational power of the network.
se
The sequential container that includes:
- AdaptiveAvgPool1d: Applies adaptive average pooling.
- Conv1d: 1D convolution layer to reduce channel dimensions.
- ReLU: Activation function.
- BatchNorm1d: Batch normalization layer.
- Conv1d: 1D convolution layer to restore original channel dimensions.
- Sigmoid: Activation function to produce channel weights.
Type: nn.Sequential
Parameters:
- channels (int) – Number of input channels.
- bottleneck (int, optional) – Number of channels in the bottleneck layer. Default is 128.
Returns: The output tensor after applying the squeeze-and-excitation operation (produced by forward()).
Return type: Tensor
####### Examples
>>> se_module = SEModule(channels=64, bottleneck=16)
>>> input_tensor = torch.randn(32, 64, 100) # (batch_size, channels, time)
>>> output_tensor = se_module(input_tensor)
>>> output_tensor.shape
torch.Size([32, 64, 100])
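The layer sequence listed above (pool, reduce, ReLU, batch norm, restore, sigmoid) can be sketched as a minimal standalone module. This is an illustrative reconstruction from the documented structure, not the espnet2 source; only the attribute name `se`, the argument names, and the layer order come from this page.

```python
import torch
import torch.nn as nn


class SEModule(nn.Module):
    """Sketch of the documented Squeeze-and-Excitation block."""

    def __init__(self, channels: int, bottleneck: int = 128):
        super().__init__()
        # Layer order follows the documented nn.Sequential container.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),                          # squeeze: global average over time
            nn.Conv1d(channels, bottleneck, kernel_size=1),   # reduce channel dimension
            nn.ReLU(),
            nn.BatchNorm1d(bottleneck),
            nn.Conv1d(bottleneck, channels, kernel_size=1),   # restore channel dimension
            nn.Sigmoid(),                                     # per-channel weights in (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Excite: rescale each channel of x by its learned weight,
        # broadcasting the (B, C, 1) weights over the time axis.
        return x * self.se(x)


se = SEModule(channels=64, bottleneck=16)
x = torch.randn(8, 64, 100)
y = se(x)
print(tuple(y.shape))  # (8, 64, 100)
```

Because the sigmoid weights lie in (0, 1), the output is an element-wise damped copy of the input; channels the network deems unimportant are scaled toward zero while important ones pass through nearly unchanged.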
forward(input)
Forward pass of the SEModule.
This method applies the squeeze-and-excitation (SE) operation to the input tensor. The SE module computes a channel-wise weighting of the input tensor to enhance the representation of important features and suppress less important ones.
- Parameters: input (torch.Tensor) – Input tensor of shape (B, C, T), where B is the batch size, C is the number of channels, and T is the length of the sequence.
- Returns: Output tensor of the same shape as the input, after applying the SE operation.
- Return type: torch.Tensor
####### Examples
>>> se_module = SEModule(channels=64)
>>> input_tensor = torch.randn(8, 64, 100) # Batch size 8, 64 channels, length 100
>>> output_tensor = se_module(input_tensor)
>>> print(output_tensor.shape)
torch.Size([8, 64, 100])