espnet2.enh.layers.ncsnpp_utils.layerspp.AttnBlockpp
class espnet2.enh.layers.ncsnpp_utils.layerspp.AttnBlockpp(channels, skip_rescale=False, init_scale=0.0)
Bases: Module
Channel-wise self-attention block. Modified from DDPM.
This class implements a channel-wise self-attention mechanism that allows the model to focus on different parts of the input features adaptively. It is designed to work within the NCSN++ architecture, enhancing the representation capabilities of the model.
GroupNorm_0
Group normalization layer for input.
- Type: nn.GroupNorm
NIN_0
NIN (1×1 channel projection) producing the attention query.
- Type: NIN
NIN_1
NIN projection producing the attention key.
- Type: NIN
NIN_2
NIN projection producing the attention value.
- Type: NIN
NIN_3
NIN projection applied to the attention output.
- Type: NIN
skip_rescale
If True, the residual sum (input plus attention output) is rescaled by 1/sqrt(2).
- Type: bool
Parameters:
- channels (int) – Number of input channels.
- skip_rescale (bool, optional) – If True, rescales the residual sum by 1/sqrt(2) to maintain stability. Defaults to False.
- init_scale (float, optional) – Scale used when initializing the output projection. Defaults to 0.0.
Returns: The output tensor after applying self-attention.
Return type: Tensor
Raises: ValueError – If the input tensor shape is not valid.
####### Examples
>>> attn_block = AttnBlockpp(channels=64)
>>> input_tensor = torch.randn(8, 64, 32, 32) # (batch, channels, height, width)
>>> output_tensor = attn_block(input_tensor)
>>> output_tensor.shape
torch.Size([8, 64, 32, 32]) # Output shape matches input shape
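When skip_rescale is True, the residual sum is divided by sqrt(2), a convention in NCSN++-style blocks that keeps the output variance close to the input variance. A minimal sketch of the variance arithmetic, assuming the input and the attention output are independent with unit variance:

```python
# Hypothetical illustration of why skip_rescale divides by sqrt(2):
# for independent unit-variance x and h, Var(x + h) = 2, while
# Var((x + h) / sqrt(2)) = 1, so activations do not grow across blocks.
x_var, h_var = 1.0, 1.0
plain = x_var + h_var              # variance of x + h
rescaled = (x_var + h_var) / 2.0   # variance of (x + h) / sqrt(2)
```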
forward(x)
Forward pass for the AttnBlockpp module.
This method performs channel-wise self-attention on the input tensor: query, key, and value representations are derived from the (group-normalized) input, combined into attention weights over all spatial positions, and used to produce an attention output. The result is either the input plus the attention output, or that sum rescaled by 1/sqrt(2) when skip_rescale is True.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, C, H, W), where B is the batch size, C is the number of channels, H is the height, and W is the width.
- Returns: Output tensor of the same shape as the input, representing the result of the self-attention operation.
- Return type: torch.Tensor
- Raises: ValueError – If the input tensor does not have the expected shape or if the number of channels does not match the initialized parameters.
####### Examples
>>> model = AttnBlockpp(channels=64)
>>> input_tensor = torch.randn(8, 64, 32, 32) # Batch of 8 images
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape)
torch.Size([8, 64, 32, 32])
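The forward computation described above can be sketched in NumPy. This is an illustrative port, not the ESPnet implementation: GroupNorm is omitted and the learned NIN layers are replaced with random 1×1 channel projections, but the attention pattern (weights over all spatial positions, scaled by 1/sqrt(C), with an optional 1/sqrt(2) residual rescale) follows the block's description.

```python
import numpy as np

def attn_forward(x, skip_rescale=False):
    """Sketch of an AttnBlockpp-style forward pass (hypothetical NumPy port).

    GroupNorm is omitted and the NIN projections use fixed random weights
    for illustration; the real block learns these parameters.
    """
    B, C, H, W = x.shape
    rng = np.random.default_rng(0)
    # NIN layers act as 1x1 projections over the channel dimension.
    Wq, Wk, Wv, Wo = (rng.standard_normal((C, C)) * 0.1 for _ in range(4))
    proj = lambda h, Wm: np.einsum('bchw,cd->bdhw', h, Wm)
    q, k, v = proj(x, Wq), proj(x, Wk), proj(x, Wv)
    # Attention weights between every pair of spatial positions, scaled by 1/sqrt(C).
    w = np.einsum('bchw,bcij->bhwij', q, k) / np.sqrt(C)
    w = w.reshape(B, H, W, H * W)
    w = np.exp(w - w.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)          # softmax over key positions
    w = w.reshape(B, H, W, H, W)
    h = np.einsum('bhwij,bcij->bchw', w, v)
    h = proj(h, Wo)                        # output projection (NIN_3 analogue)
    return (x + h) / np.sqrt(2.0) if skip_rescale else x + h
```

The output shape matches the input shape, and the skip_rescale branch differs from the plain residual only by the constant 1/sqrt(2) factor.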