espnet2.enh.layers.ncsnpp_utils.layers.AttnBlock
class espnet2.enh.layers.ncsnpp_utils.layers.AttnBlock(channels)
Bases: Module
Channel-wise self-attention block.
This module implements a channel-wise self-attention mechanism using normalization and learned linear transformations. It applies Group Normalization followed by multiple NIN (Network in Network) layers to compute the attention weights and the output.
GroupNorm_0
Group normalization layer for input.
- Type: nn.GroupNorm
NIN_0
First NIN layer for query generation.
- Type: NIN
NIN_1
Second NIN layer for key generation.
- Type: NIN
NIN_2
Third NIN layer for value generation.
- Type: NIN
NIN_3
Fourth NIN layer for output transformation.
- Type: NIN
Parameters: channels (int) – Number of input channels.
Returns: The output tensor after applying self-attention.
Return type: torch.Tensor
####### Examples
>>> attn_block = AttnBlock(channels=64)
>>> input_tensor = torch.randn(1, 64, 32, 32) # Batch size 1, 64 channels
>>> output_tensor = attn_block(input_tensor)
>>> output_tensor.shape
torch.Size([1, 64, 32, 32]) # Output has the same shape as input
NOTE
This block assumes that the input tensor has the shape (B, C, H, W) where B is the batch size, C is the number of channels, and H and W are the height and width of the feature map.
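For orientation, the following is a minimal, self-contained sketch of the computation this block performs, under stated assumptions: the NIN layers are treated as learned 1x1 (pointwise) channel transforms and are replaced here by 1x1 convolutions, and the class name SimpleAttnBlock, the group count, and the residual connection are illustrative choices, not a verbatim copy of the module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleAttnBlock(nn.Module):
    """Illustrative stand-in for AttnBlock; 1x1 convs play the role of NIN."""

    def __init__(self, channels: int):
        super().__init__()
        # Group count is an assumption; any divisor of `channels` works.
        self.norm = nn.GroupNorm(num_groups=min(channels // 4, 32),
                                 num_channels=channels, eps=1e-6)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)    # query projection
        self.k = nn.Conv2d(channels, channels, kernel_size=1)    # key projection
        self.v = nn.Conv2d(channels, channels, kernel_size=1)    # value projection
        self.out = nn.Conv2d(channels, channels, kernel_size=1)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h = self.norm(x)
        q, k, v = self.q(h), self.k(h), self.v(h)

        # Scaled dot-product between every pair of spatial positions.
        w = torch.einsum("bchw,bcij->bhwij", q, k) * (C ** -0.5)
        w = F.softmax(w.reshape(B, H, W, H * W), dim=-1).reshape(B, H, W, H, W)
        h = torch.einsum("bhwij,bcij->bchw", w, v)

        # Residual connection: output keeps the input's (B, C, H, W) shape.
        return x + self.out(h)


attn = SimpleAttnBlock(channels=64)
print(attn(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```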
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Apply channel-wise self-attention to an input feature map.

The input is first normalized by GroupNorm_0. NIN_0, NIN_1, and NIN_2 then project the normalized features into queries, keys, and values, attention weights are computed between all spatial positions, and NIN_3 transforms the attended values into the output.

- Parameters: x (torch.Tensor) – Input tensor of shape (B, C, H, W).
- Returns: The output tensor after applying the attention mechanism.
- Return type: Tensor
####### Examples
>>> attn_block = AttnBlock(channels=64)
>>> input_tensor = torch.randn(8, 64, 32, 32) # (batch_size, channels, height, width)
>>> output_tensor = attn_block(input_tensor)
>>> output_tensor.shape
torch.Size([8, 64, 32, 32])
NOTE
The attention weights are computed using the scaled dot-product attention mechanism, followed by a softmax normalization.
- Raises: ValueError – If the input tensor does not have the expected shape.
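As noted above, the attention weights come from a scaled dot-product followed by a softmax. The snippet below illustrates only that normalization step on dummy tensors; the einsum index layout (queries indexed by (h, w), keys by (i, j)) is an assumed convention for illustration, not code taken from the module.

```python
import torch

B, C, H, W = 2, 8, 4, 4
q = torch.randn(B, C, H, W)  # queries
k = torch.randn(B, C, H, W)  # keys

# Pairwise scores between query positions (h, w) and key positions (i, j),
# scaled by 1 / sqrt(C).
w = torch.einsum("bchw,bcij->bhwij", q, k) * (C ** -0.5)

# Softmax normalizes over the H * W key positions for each query position.
w = torch.softmax(w.reshape(B, H, W, H * W), dim=-1).reshape(B, H, W, H, W)

print(w.sum(dim=(-2, -1)))  # ~1.0 everywhere: each query's weights sum to one
```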