espnet2.gan_codec.shared.encoder.seanet.SEANetResnetBlock
class espnet2.gan_codec.shared.encoder.seanet.SEANetResnetBlock(dim: int, kernel_sizes: List[int] = [3, 1], dilations: List[int] = [1, 1], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'weight_norm', norm_params: Dict[str, Any] = {}, causal: bool = False, pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True)
Bases: Module
Residual block from SEANet model.
This class implements a residual block for the SEANet architecture: a stack of activation functions and normalized convolutional layers whose output is combined with a skip connection. The skip connection preserves the input features and keeps deep stacks of these blocks trainable.
- Parameters:
- dim (int) – Number of input and output channels.
- kernel_sizes (list) – List of kernel sizes for the convolutions.
- dilations (list) – List of dilations for the convolutions.
- activation (str) – Activation function.
- activation_params (dict) – Parameters to provide to the activation function.
- norm (str) – Normalization method.
- norm_params (dict) – Parameters to provide to the underlying normalization used along with the convolution.
- causal (bool) – Whether to use fully causal convolution.
- pad_mode (str) – Padding mode for the convolutions.
- compress (int) – Factor by which the channel dimension is reduced inside the residual branch (from Demucs v3).
- true_skip (bool) – If True, use an identity mapping as the skip connection; otherwise use a convolutional layer as the skip connection.
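To make the interaction of kernel_sizes, dilations, and causal concrete, the sketch below computes how much padding a stride-1 dilated convolution needs to keep the time length unchanged. It is plain Python under stated assumptions: the helper names `same_length_padding` and `output_length` are hypothetical, and the left/right split shown (all padding on the left when causal, an even split otherwise) is a common convention for such layers, not a statement about this class's internals.

```python
from typing import Tuple


def same_length_padding(kernel_size: int, dilation: int, causal: bool) -> Tuple[int, int]:
    """Return (left, right) padding that keeps a stride-1 convolution length-preserving.

    Hypothetical helper for illustration; not part of ESPnet.
    """
    # A dilated kernel spans (kernel_size - 1) * dilation + 1 samples, so the
    # unpadded output is shorter than the input by (kernel_size - 1) * dilation.
    total = (kernel_size - 1) * dilation
    if causal:
        # Assumed convention: a causal layer only sees past samples.
        return total, 0
    # Assumed convention: split evenly, with any extra sample on the left.
    right = total // 2
    return total - right, right


def output_length(length: int, kernel_size: int, dilation: int, causal: bool) -> int:
    """Length after padding followed by a stride-1 dilated convolution."""
    left, right = same_length_padding(kernel_size, dilation, causal)
    return length + left + right - (kernel_size - 1) * dilation


# Default block configuration: kernel_sizes=[3, 1], dilations=[1, 1].
for k, d in zip([3, 1], [1, 1]):
    print(k, d, same_length_padding(k, d, causal=False), output_length(50, k, d, causal=False))
```

Under these assumptions the time dimension is unchanged for any kernel size and dilation, which is consistent with the input and output shapes matching in the examples on this page.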
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.encoder.seanet import SEANetResnetBlock
>>> block = SEANetResnetBlock(dim=128, kernel_sizes=[3, 1],
...                           dilations=[1, 1], activation='ELU',
...                           activation_params={'alpha': 1.0},
...                           norm='weight_norm',
...                           norm_params={}, causal=False,
...                           pad_mode='reflect', compress=2,
...                           true_skip=True)
>>> x = torch.randn(10, 128, 50) # Batch of 10, 128 channels, 50 time steps
>>> output = block(x)
>>> output.shape
torch.Size([10, 128, 50])
NOTE
The block utilizes skip connections to improve gradient flow and reduce the risk of vanishing gradients in deeper networks.
- Raises:
- AssertionError – If the number of kernel sizes does not match the number of dilations.
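The length check can be sketched in isolation. The helper below is hypothetical (it is not an ESPnet function) and only mirrors the documented condition of one dilation per kernel size.

```python
from typing import List


def check_block_config(kernel_sizes: List[int], dilations: List[int]) -> None:
    """Hypothetical stand-in for the documented constructor check."""
    assert len(kernel_sizes) == len(dilations), (
        f"got {len(kernel_sizes)} kernel sizes but {len(dilations)} dilations"
    )


check_block_config([3, 1], [1, 1])  # matching lengths: passes silently

try:
    check_block_config([3, 1], [1])  # one dilation is missing
except AssertionError as err:
    print(f"rejected: {err}")
```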
forward(x)
Residual block from SEANet model.
This class implements a residual block used in the SEANet model, which consists of convolutional layers with normalization and activation functions. It allows for causal convolutions and provides options for different normalization methods.
block
A sequential container of activation and convolutional layers.
- Type: nn.Sequential
shortcut
A skip connection that can be either an identity mapping or a convolutional layer.
- Type: nn.Module
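How block and shortcut combine can be sketched in plain PyTorch. This is a simplified stand-in under stated assumptions, not the ESPnet implementation: the class name `TinyResnetBlock` is hypothetical, the hidden width is assumed to be dim // compress, the output is assumed to be the sum of the two paths, and normalization, causal padding, and configurable activations are omitted.

```python
import torch
from torch import nn


class TinyResnetBlock(nn.Module):
    """Simplified residual block for illustration (hypothetical, not ESPnet code)."""

    def __init__(self, dim: int, compress: int = 2, true_skip: bool = True):
        super().__init__()
        hidden = dim // compress  # assumption: narrower residual branch
        self.block = nn.Sequential(
            nn.ELU(alpha=1.0),
            # kernel_size=3 with padding=1 keeps the time length unchanged
            nn.Conv1d(dim, hidden, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.ELU(alpha=1.0),
            nn.Conv1d(hidden, dim, kernel_size=1),
        )
        # Identity when true_skip is set, otherwise a convolution
        # (kernel size 1 is an assumption made for this sketch)
        self.shortcut = nn.Identity() if true_skip else nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: the skip path and the residual branch are summed
        return self.shortcut(x) + self.block(x)


x = torch.randn(1, 128, 64)  # (batch, channels, time)
y = TinyResnetBlock(dim=128)(x)
print(y.shape)
```

Because both paths map `dim` channels to `dim` channels without changing the time length, the sum is well defined and the output shape equals the input shape.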
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.encoder.seanet import SEANetResnetBlock
>>> block = SEANetResnetBlock(dim=128)
>>> input_tensor = torch.randn(1, 128, 64) # (batch, channels, time)
>>> output_tensor = block(input_tensor)
>>> output_tensor.shape
torch.Size([1, 128, 64])
- Raises:
- AssertionError – If the number of kernel sizes does not match the number of dilations.
NOTE
This block is designed to work with various normalization methods, including weight normalization and layer normalization. It can be easily adapted to different activation functions.