espnet2.spk.encoder.ska_tdnn_encoder.ResBlock
class espnet2.spk.encoder.ska_tdnn_encoder.ResBlock(inplanes: int, planes: int, stride: int = 1, reduction: int = 8, skfwse_freq: int = 40, skcwse_channel: int = 128)
Bases: Module
Residual Block with Selective Kernel Attention.
This class implements a residual block that applies selective kernel attention to refine its feature maps. It consists of a convolutional layer followed by batch normalization, ReLU activation, and two selective kernel attention modules: frequency-wise (fwSKAttention) and channel-wise (cwSKAttention).
- Parameters:
- inplanes (int) – Number of input channels.
- planes (int) – Number of output channels.
- stride (int , optional) – Stride of the convolution. Default is 1.
- reduction (int , optional) – Reduction ratio for the attention mechanism. Default is 8.
- skfwse_freq (int , optional) – Frequency dimension for frequency-wise selective kernel attention. Default is 40.
- skcwse_channel (int , optional) – Number of channels for channel-wise selective kernel attention. Default is 128.
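The effect of `stride` on the spatial size of a convolution output can be checked with the standard output-size formula. The sketch below is only illustrative: the helper name `conv_out_size` is hypothetical, and the kernel size of 3 and padding of 1 are assumptions, since this page does not state the values used by `conv1`.

```python
def conv_out_size(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Output length of a convolution along one axis (dilation assumed to be 1)."""
    # Standard formula: floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


# Assumed 3x3 kernel with padding 1: stride 1 keeps the size, stride 2 halves it.
for stride in (1, 2):
    print(stride, conv_out_size(40, stride=stride))
```

Under these assumptions, a 40-bin axis stays at 40 with `stride=1` and shrinks to 20 with `stride=2`.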
conv1
Convolutional layer.
- Type: nn.Conv2d
bn1
Batch normalization layer.
- Type: nn.BatchNorm2d
relu
ReLU activation function.
- Type: nn.ReLU
skfwse
Frequency-wise selective kernel attention module.
- Type: fwSKAttention
skcwse
Channel-wise selective kernel attention module.
- Type: cwSKAttention
stride
Stride of the convolution.
- Type: int
Returns: The output of the forward pass.
Return type: Tensor
####### Examples
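>>> import torch
>>> from espnet2.spk.encoder.ska_tdnn_encoder import ResBlock
>>> # The residual addition needs matching input and output channels.
>>> block = ResBlock(inplanes=128, planes=128, skfwse_freq=40, skcwse_channel=128)
>>> input_tensor = torch.randn(1, 128, 40, 100)  # Batch size 1, 128 channels, 40 frequency bins
>>> output_tensor = block(input_tensor)
>>> print(output_tensor.shape)
torch.Size([1, 128, 40, 100])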
forward(x)
Forward function for the ResBlock module.
This method applies convolution, batch normalization, ReLU activation, and the two selective kernel attention modules to the input tensor x, then adds the input back to the result as a residual connection.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, C, H, W), where B is the batch size, C is the number of channels, H is the height, and W is the width.
- Returns: Output tensor of the same shape as input x.
- Return type: torch.Tensor
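The residual data flow described above can be sketched with plain PyTorch layers. This is a minimal illustration and not the ESPnet implementation: the class name `ResidualSketch` is hypothetical, the 3x3 kernel with padding 1 is an assumption, and `nn.Identity` stands in for the two attention modules.

```python
import torch
from torch import nn


class ResidualSketch(nn.Module):
    """Hypothetical stand-in for a residual block; not the ESPnet code."""

    def __init__(self, channels: int):
        super().__init__()
        # Assumed 3x3 kernel with padding 1 so the spatial size is preserved.
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        # Placeholders for the frequency-wise and channel-wise attention modules.
        self.attn_freq = nn.Identity()
        self.attn_chan = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x  # keep the input for the skip connection
        out = self.relu(self.bn(self.conv(x)))
        out = self.attn_chan(self.attn_freq(out))
        # The addition only works when out and residual have the same shape.
        return out + residual


x = torch.randn(2, 8, 10, 12)
y = ResidualSketch(channels=8)(x)
print(y.shape == x.shape)
```

Because the skip connection adds the input element-wise, the sketch keeps the channel count and spatial size unchanged, which matches the statement that the output has the same shape as x.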
####### Examples
>>> res_block = ResBlock(inplanes=128, planes=128)
>>> input_tensor = torch.randn(32, 128, 40, 16)  # Batch of 32, 40 frequency bins
>>> output_tensor = res_block(input_tensor)
>>> output_tensor.shape
torch.Size([32, 128, 40, 16])