espnet2.gan_codec.shared.encoder.seanet_2d.SEANetResnetBlock2d
class espnet2.gan_codec.shared.encoder.seanet_2d.SEANetResnetBlock2d(dim: int, kernel_sizes: List[Tuple[int, int]] = [(3, 3), (1, 1)], dilations: List[Tuple[int, int]] = [(1, 1), (1, 1)], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'weight_norm', norm_params: Dict[str, Any] = {}, causal: bool = False, pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True, conv_group_ratio: int = -1)
Bases: Module
Residual block from SEANet model.
This class implements a residual block that consists of convolutional layers, activation functions, and normalization techniques. The design allows for customizable parameters including kernel sizes, dilations, and whether to use causal convolutions or true skip connections.
block
The residual branch: a sequence of alternating activation and convolution layers.
- Type: nn.Sequential
shortcut
The shortcut connection: an identity mapping when true_skip is True, otherwise a convolutional layer.
- Type: nn.Module
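The sketch below shows the channel layout the residual branch is assumed to follow. It is modeled on the original SEANet/EnCodec residual block, in which the first convolution reduces the width to dim // compress and the last convolution restores it to dim. `plan_resnet_block` is a hypothetical helper, not part of ESPnet; it only describes the layers rather than building them.

```python
from typing import List, Tuple


def plan_resnet_block(
    dim: int,
    kernel_sizes: List[Tuple[int, int]],
    dilations: List[Tuple[int, int]],
    compress: int = 2,
    true_skip: bool = True,
):
    """Describe each conv as (in_chs, out_chs, kernel, dilation) plus the shortcut type.

    Hypothetical helper. Assumes the SEANet/EnCodec layout: first conv maps
    dim -> dim // compress, last conv maps back to dim.
    """
    # Mirrors the documented AssertionError for mismatched lists.
    assert len(kernel_sizes) == len(dilations), "kernel_sizes and dilations differ in length"
    hidden = dim // compress  # reduced width inside the residual branch
    layers = []
    for i, (kernel, dilation) in enumerate(zip(kernel_sizes, dilations)):
        in_chs = dim if i == 0 else hidden
        out_chs = dim if i == len(kernel_sizes) - 1 else hidden
        # Each conv is assumed to be preceded by the activation (ELU by default).
        layers.append((in_chs, out_chs, kernel, dilation))
    # true_skip -> identity shortcut; otherwise a 1x1 conv from dim to dim channels.
    shortcut = "identity" if true_skip else ("conv1x1", dim, dim)
    return layers, shortcut


layers, shortcut = plan_resnet_block(128, [(3, 3), (1, 1)], [(1, 1), (1, 1)])
for layer in layers:
    print(layer)
# (128, 64, (3, 3), (1, 1))
# (64, 128, (1, 1), (1, 1))
print(shortcut)
# identity
```

Because the last convolution always returns dim channels, the branch output can be added to the shortcut output regardless of compress.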
Parameters:
- dim (int) – Number of channels of the input and output.
- kernel_sizes (list) – List of 2-D kernel sizes, one per convolution in the residual branch.
- dilations (list) – List of 2-D dilations, one per convolution; must have the same length as kernel_sizes.
- activation (str) – Activation function to use.
- activation_params (dict) – Parameters for the activation function.
- norm (str) – Normalization method to apply.
- norm_params (dict) – Parameters for the underlying normalization used with the convolution.
- causal (bool) – Whether to use fully causal convolution.
- pad_mode (str) – Padding mode for the convolutions.
- compress (int) – Channel reduction factor for the residual branch; its hidden width is dim // compress.
- true_skip (bool) – Whether to use true skip connection or a simple convolution as the skip connection.
- conv_group_ratio (int) – Ratio for grouping convolutions.
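How causal and pad_mode interact with kernel_sizes and dilations can be sketched per axis. This sketch assumes the EnCodec `SConv1d` convention for stride-1 convolutions: total padding equals the effective (dilated) kernel size minus one, split evenly for non-causal convolutions and placed entirely before the signal for causal ones. ESPnet's 2-D implementation may treat the two axes differently. `conv_padding` is hypothetical and computes only padding amounts; pad_mode (e.g. 'reflect') decides how the padded values are filled.

```python
from typing import Tuple


def conv_padding(kernel: int, dilation: int = 1, causal: bool = False) -> Tuple[int, int]:
    """Return (left, right) padding for one axis so a stride-1 conv keeps its length.

    Hypothetical helper modeled on the EnCodec SConv1d padding rule.
    """
    effective_kernel = (kernel - 1) * dilation + 1  # receptive field of the dilated kernel
    padding_total = effective_kernel - 1
    if causal:
        # All padding goes first, so no output position sees future inputs.
        return padding_total, 0
    right = padding_total // 2
    left = padding_total - right  # any odd remainder goes on the left
    return left, right


print([conv_padding(k) for k in (3, 3)])  # default (3, 3) kernel, one entry per axis
# [(1, 1), (1, 1)]
print(conv_padding(3, causal=True))
# (2, 0)
print(conv_padding(3, dilation=2))
# (2, 2)
```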
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.encoder.seanet_2d import SEANetResnetBlock2d
>>> block = SEANetResnetBlock2d(
... dim=128,
... kernel_sizes=[(3, 3), (1, 1)],
... dilations=[(1, 1), (1, 1)],
... activation='ELU',
... activation_params={'alpha': 1.0},
... norm='weight_norm',
... norm_params={},
... causal=False,
... pad_mode='reflect',
... compress=2,
... true_skip=True
... )
>>> input_tensor = torch.randn(8, 128, 32, 32) # Batch size of 8
>>> output_tensor = block(input_tensor)  # output shape matches input shape
>>> output_tensor.shape
torch.Size([8, 128, 32, 32])
NOTE
The block assumes that the input tensor has the shape (batch_size, channels, freq, time).
- Raises: AssertionError – If the number of kernel sizes does not match the number of dilations.
forward(x)
Performs the forward pass of the SEANetResnetBlock2d.
The method computes the output of the residual block by applying the convolutional layers and the shortcut connection. It adds the output of the convolutional block to the output of the shortcut to form the residual connection.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, C, F, T), where:
- B is the batch size,
- C is the number of channels,
- F is the frequency dimension,
- T is the time dimension.
- Returns: Output tensor of the same shape as input tensor x, which is the sum of the shortcut output and the block output.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.encoder.seanet_2d import SEANetResnetBlock2d
>>> block = SEANetResnetBlock2d(dim=128)
>>> input_tensor = torch.randn(10, 128, 32, 32) # Batch of 10
>>> output_tensor = block(input_tensor)
>>> output_tensor.shape
torch.Size([10, 128, 32, 32])
NOTE
The input tensor must have exactly 4 dimensions; otherwise an AssertionError is raised.
- Raises: AssertionError – If the input tensor does not have 4 dimensions.
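Because forward adds shortcut(x) to block(x), the residual branch must return a tensor with exactly the input's shape. A minimal sketch of that shape bookkeeping follows, assuming stride-1 convolutions padded by (effective kernel minus one) per axis and a last convolution that maps back to dim channels. `conv_out_len` and `expected_output_shape` are hypothetical helpers, not ESPnet functions.

```python
from typing import List, Sequence, Tuple


def conv_out_len(length: int, kernel: int, dilation: int, pad_total: int, stride: int = 1) -> int:
    """Standard convolution output length along one axis."""
    effective_kernel = (kernel - 1) * dilation + 1
    return (length + pad_total - effective_kernel) // stride + 1


def expected_output_shape(
    shape: Sequence[int],
    dim: int,
    kernel_sizes: List[Tuple[int, int]],
    dilations: List[Tuple[int, int]],
) -> Tuple[int, int, int, int]:
    """Hypothetical shape check mirroring the documented 4-D input requirement."""
    assert len(shape) == 4, "expected an input of shape (B, C, F, T)"
    batch, channels, freq, time = shape
    assert channels == dim, "the channel axis must equal dim"
    for (k_f, k_t), (d_f, d_t) in zip(kernel_sizes, dilations):
        # Padding each axis by (effective kernel - 1) preserves its length.
        freq = conv_out_len(freq, k_f, d_f, pad_total=(k_f - 1) * d_f)
        time = conv_out_len(time, k_t, d_t, pad_total=(k_t - 1) * d_t)
    # The last conv returns dim channels, so block(x) matches shortcut(x).
    return (batch, dim, freq, time)


print(expected_output_shape((10, 128, 32, 32), 128, [(3, 3), (1, 1)], [(1, 1), (1, 1)]))
# (10, 128, 32, 32)
print(conv_out_len(32, 3, 1, pad_total=0))  # without padding the axis would shrink
# 30
```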