espnet2.enh.layers.ncsnpp_utils.layerspp.ResnetBlockDDPMpp
class espnet2.enh.layers.ncsnpp_utils.layerspp.ResnetBlockDDPMpp(act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, skip_rescale=False, init_scale=0.0)
Bases: Module
ResBlock adapted from DDPM.
This class implements a residual block that is adapted from Denoising Diffusion Probabilistic Models (DDPM). It includes normalization, convolutional layers, and the option to incorporate temporal embeddings for conditional processing. The block supports skip connections and can rescale the output.
act
Activation function to be applied in the block.
out_ch
Number of output channels for the convolutional layers.
conv_shortcut
Whether to use a convolutional shortcut.
skip_rescale
Whether to rescale the skip connection output.
GroupNorm_0
First group normalization layer.
Conv_0
First convolutional layer.
Dense_0
Linear layer for processing temporal embeddings (if provided).
GroupNorm_1
Second group normalization layer.
Dropout_0
Dropout layer.
Conv_1
Second convolutional layer.
NIN_0
NIN (network-in-network, i.e. 1×1) layer that matches the shortcut's channel dimension when in_ch differs from out_ch.
Conv_2
Convolutional layer for the shortcut connection (used when conv_shortcut is True and the channel counts differ).
- Parameters:
- act (callable) – Activation function to be used (e.g., ReLU).
- in_ch (int) – Number of input channels.
- out_ch (int, optional) – Number of output channels. Defaults to in_ch.
- temb_dim (int, optional) – Dimension of the temporal embedding. Defaults to None.
- conv_shortcut (bool, optional) – If True, use convolution for the shortcut. Defaults to False.
- dropout (float, optional) – Dropout probability. Defaults to 0.1.
- skip_rescale (bool, optional) – If True, apply rescaling to the output. Defaults to False.
- init_scale (float, optional) – Scale for weight initialization. Defaults to 0.0.
- Returns: The output of the residual block after applying the operations.
- Return type: Tensor
####### Examples
>>> import torch
>>> import torch.nn.functional as F
>>> block = ResnetBlockDDPMpp(act=F.relu, in_ch=64, out_ch=128)
>>> x = torch.randn(1, 64, 32, 32)
>>> output = block(x)
>>> output.shape
torch.Size([1, 128, 32, 32])
NOTE
This block is designed to be used within diffusion models where residual learning can enhance performance.
- Raises: ValueError – If the skip-connection configuration is not recognized.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
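The constructor wires the layers listed above into the standard DDPM residual pipeline: GroupNorm → activation → conv, an optional projected time embedding, then GroupNorm → activation → dropout → conv, finished by a (possibly rescaled) skip connection. Below is a minimal, self-contained PyTorch sketch of that pipeline; the class name `MiniResnetBlockDDPM` and the group counts are illustrative and may differ from the espnet2 implementation:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MiniResnetBlockDDPM(nn.Module):
    """Hedged sketch of a DDPM-style residual block (not the espnet2 code)."""

    def __init__(self, in_ch, out_ch=None, temb_dim=None,
                 dropout=0.1, skip_rescale=False):
        super().__init__()
        out_ch = out_ch or in_ch
        self.skip_rescale = skip_rescale
        # Group counts here are illustrative; they must divide the channel count.
        self.norm0 = nn.GroupNorm(min(in_ch // 4, 32), in_ch)
        self.conv0 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        # Projects the time embedding to out_ch channels (plays the role of Dense_0).
        self.dense0 = nn.Linear(temb_dim, out_ch) if temb_dim else None
        self.norm1 = nn.GroupNorm(min(out_ch // 4, 32), out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv1 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # A 1x1 conv stands in for the NIN_0 channel-matching shortcut.
        self.shortcut = (nn.Conv2d(in_ch, out_ch, 1)
                         if in_ch != out_ch else nn.Identity())

    def forward(self, x, temb=None):
        h = self.conv0(F.silu(self.norm0(x)))
        if temb is not None and self.dense0 is not None:
            # Broadcast the (B, out_ch) embedding over the spatial dimensions.
            h = h + self.dense0(F.silu(temb))[:, :, None, None]
        h = self.conv1(self.dropout(F.silu(self.norm1(h))))
        x = self.shortcut(x)
        return (x + h) / math.sqrt(2.0) if self.skip_rescale else x + h
```

When skip_rescale is True, the sum of the two branches is divided by sqrt(2), which keeps the activation variance roughly constant when many such blocks are stacked.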
forward(x, temb=None)
Performs a forward pass through the ResNet block.
This method applies a series of operations including normalization, convolution, and optional time embedding addition to the input tensor. The output can be either a residual connection or a rescaled output based on the configuration.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (B, C, H, W), where B is the batch size, C is the number of channels, H is the height, and W is the width.
- temb (torch.Tensor, optional) – Time embedding tensor of shape (B, temb_dim). If provided, it is projected and added to the output feature maps.
- Returns: Output tensor of shape (B, out_ch, H, W), where out_ch is the number of output channels. The output is either the sum of the shortcut and residual branches, or that sum rescaled, depending on the skip_rescale attribute.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> import torch.nn as nn
>>> model = ResnetBlockDDPMpp(act=nn.ReLU(), in_ch=64, out_ch=128)
>>> input_tensor = torch.randn(8, 64, 32, 32) # Batch of 8
>>> output_tensor = model(input_tensor)
>>> output_tensor.shape
torch.Size([8, 128, 32, 32])
NOTE
If temb is provided, the block must have been constructed with a matching temb_dim; the linear layer Dense_0 projects the embedding to out_ch channels before it is added to the feature maps.
- Raises: ValueError – If the input tensor's channel dimension does not match the expected dimensions based on the model's configuration.
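The time-embedding addition described in the note above is a plain broadcast: the (B, out_ch) projection gains two trailing singleton dimensions and is added at every spatial position. A small illustration (shapes chosen arbitrarily):

```python
import torch

B, C, H, W = 2, 128, 8, 8
h = torch.zeros(B, C, H, W)                                   # feature maps after Conv_0
emb = torch.arange(B * C, dtype=torch.float32).reshape(B, C)  # stand-in for the projected temb
h = h + emb[:, :, None, None]                                 # broadcast over H and W
# every spatial position of channel c in batch b now holds emb[b, c]
```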