espnet2.gan_codec.hificodec.module.Generator
class espnet2.gan_codec.hificodec.module.Generator(upsample_rates, upsample_kernel_sizes, upsample_initial_channel, resblock_num, resblock_kernel_sizes, resblock_dilation_sizes, out_dim)
Bases: Module
Generator module for HiFi-GAN based audio synthesis.
This class implements a generator that upsamples input feature sequences (e.g., quantized code embeddings) into audio waveforms using a series of convolutional layers and residual blocks. It is designed for use in GAN-based audio codec applications.
num_kernels
Number of kernel sizes used in residual blocks.
- Type: int
num_upsamples
Number of upsampling layers.
- Type: int
conv_pre
Initial convolutional layer.
- Type: nn.Module
ups
List of upsampling layers.
- Type: nn.ModuleList
resblocks
List of residual blocks.
- Type: nn.ModuleList
conv_post
Final convolutional layer.
- Type: nn.Module
Parameters:
- upsample_rates (List[int]) – List of upsampling rates.
- upsample_kernel_sizes (List[int]) – List of kernel sizes for the upsampling convolutions.
- upsample_initial_channel (int) – Number of channels in the initial convolutional layer.
- resblock_num (str) – Type of residual block to use (‘1’ or ‘2’).
- resblock_kernel_sizes (List[int]) – List of kernel sizes for residual blocks.
- resblock_dilation_sizes (List[List[int]]) – List of dilation sizes for residual blocks.
- out_dim (int) – Dimension of the output signal.
Returns: Synthesized audio signal.
Return type: torch.Tensor
######### Examples
>>> generator = Generator(
... upsample_rates=[8, 8, 2],
... upsample_kernel_sizes=[16, 16, 4],
... upsample_initial_channel=256,
... resblock_num='2',
... resblock_kernel_sizes=[3, 5, 7],
... resblock_dilation_sizes=[[1, 2], [1, 3]],
... out_dim=1
... )
>>> input_signal = torch.randn(1, 1, 256)  # channel count must match what conv_pre expects
>>> output_signal = generator(input_signal)
>>> output_signal.shape  # length grows by prod(upsample_rates) = 8 * 8 * 2 = 128
torch.Size([1, 1, 32768])
- Raises: ValueError – If resblock_num is not ‘1’ or ‘2’.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
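Since each transposed convolution upsamples the time axis by its rate, the total time-scale factor is the product of upsample_rates. A quick way to compute the expected output length (a sketch assuming the standard HiFi-GAN padding scheme, where every stage upsamples by exactly its rate):

>>> import math
>>> math.prod([8, 8, 2])  # total upsampling factor for the config above
128
>>> 256 * 128  # expected output length for a 256-frame input
32768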
forward(x)
Passes the input tensor through the generator network.
This method applies an initial convolution to the input tensor x, which is expected to have shape (B, C, T), where B is the batch size, C is the number of input channels, and T is the length of the input sequence. Each upsampling stage then applies a leaky ReLU activation, a transposed convolution, and a bank of residual blocks whose outputs are averaged, and a final convolution produces the output waveform.
- Parameters:x (torch.Tensor) – Input tensor of shape (B, C, T).
- Returns: Output tensor of shape (B, 1, T_out), where T_out is the input length T multiplied by the product of upsample_rates.
- Return type: torch.Tensor
######### Examples
>>> generator = Generator(upsample_rates=[4, 4, 4],
... upsample_kernel_sizes=[8, 8, 8],
... upsample_initial_channel=256,
... resblock_num="2",
... resblock_kernel_sizes=[3, 5, 7],
... resblock_dilation_sizes=[[1, 2, 4],
... [1, 2, 4]],
... out_dim=80)
>>> input_tensor = torch.randn(1, 80, 100)  # channel count must match what conv_pre expects
>>> output_tensor = generator(input_tensor)
>>> output_tensor.shape  # length grows by prod(upsample_rates) = 4 * 4 * 4 = 64
torch.Size([1, 1, 6400])
NOTE
Ensure that the input tensor x has the correct shape and number of channels as expected by the generator.
- Raises: ValueError – If the input tensor does not have the expected number of dimensions or shape.
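In line with the NOTE above, a lightweight pre-flight check before calling forward might look like the following sketch (illustrative only, not part of the class API):

>>> import math
>>> x = torch.randn(1, 80, 100)
>>> assert x.dim() == 3, "expected input of shape (B, C, T)"
>>> y = generator(x)
>>> y.shape[-1] == x.shape[-1] * math.prod([4, 4, 4])  # product of upsample_rates
True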
remove_weight_norm()
Remove weight normalization from the layers of the generator.
This method removes weight normalization from all layers within the generator, including the upsampling layers, residual blocks, and the pre- and post-convolutional layers. Call it when preparing the model for inference, since weight normalization is a training-time reparameterization that is no longer needed once training is complete.
NOTE
This method will print a message indicating that weight normalization is being removed.
######### Examples
>>> generator = Generator(...) # Initialize the generator
>>> generator.remove_weight_norm() # Remove weight normalization
- Raises: RuntimeError – If any layer does not support weight normalization removal.
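A typical inference-preparation sequence might look like the following sketch (the constructor arguments are illustrative values, not canonical defaults):

>>> generator = Generator(
...     upsample_rates=[8, 8, 2],
...     upsample_kernel_sizes=[16, 16, 4],
...     upsample_initial_channel=256,
...     resblock_num='2',
...     resblock_kernel_sizes=[3, 5, 7],
...     resblock_dilation_sizes=[[1, 2], [1, 3]],
...     out_dim=1,
... )
>>> # load trained weights here, then strip weight norm for inference
>>> generator.remove_weight_norm()
>>> generator = generator.eval()  # disable training-time behavior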