espnet2.gan_tts.style_melgan.style_melgan.StyleMelGANDiscriminator
class espnet2.gan_tts.style_melgan.style_melgan.StyleMelGANDiscriminator(repeats: int = 2, window_sizes: List[int] = [512, 1024, 2048, 4096], pqmf_params: List[List[int]] = [[1, None, None, None], [2, 62, 0.267, 9.0], [4, 62, 0.142, 9.0], [8, 62, 0.07949, 9.0]], discriminator_params: Dict[str, Any] = {'bias': True, 'channels': 16, 'downsample_scales': [4, 4, 4, 1], 'kernel_sizes': [5, 3], 'max_downsample_channels': 512, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.2}, 'out_channels': 1, 'pad': 'ReflectionPad1d', 'pad_params': {}}, use_weight_norm: bool = True)
Bases: Module
Style MelGAN discriminator module.
This module is the discriminator of the Style MelGAN architecture: it scores audio waveforms as real or generated. For each of several window sizes, it crops a random segment of the input, decomposes it into one or more frequency subbands with a PQMF (Pseudo Quadrature Mirror Filter) analysis filter bank, and passes the result to its own base discriminator.
repeats
Number of repetitions to apply Random Window Discrimination (RWD).
- Type: int
window_sizes
List of random window sizes for analysis.
- Type: List[int]
pqmfs
List of PQMF modules that split each window into subbands (each subband downsampled by the number of subbands).
- Type: ModuleList
discriminators
List of base discriminators.
- Type: ModuleList
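The attributes above form one branch per window size. The plain-Python sketch below shows how such branches could be described, assuming each window_sizes[i] is paired with pqmf_params[i], the i-th base discriminator's in_channels is set to the subband count pqmf_params[i][0], and a single-band branch uses an identity instead of a PQMF. build_branch_configs is a hypothetical helper that returns dicts rather than torch modules.

```python
import copy
from typing import Any, Dict, List, Optional, Sequence


def build_branch_configs(
    window_sizes: Sequence[int],
    pqmf_params: Sequence[Sequence[Optional[float]]],
    discriminator_params: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Describe each (window, PQMF, base discriminator) branch as a plain dict."""
    branches = []
    for ws, pqmf_param in zip(window_sizes, pqmf_params):
        subbands = int(pqmf_param[0])
        # Deep-copy so the caller's dict (and its nested dicts) stay untouched.
        d_params = copy.deepcopy(discriminator_params)
        # Assumed: the base discriminator sees one input channel per subband.
        d_params["in_channels"] = subbands
        branches.append(
            {
                "window_size": ws,
                # Assumed: a single-band branch skips PQMF analysis (identity).
                "pqmf": "Identity" if subbands == 1 else f"PQMF{tuple(pqmf_param)}",
                "discriminator_params": d_params,
            }
        )
    return branches


if __name__ == "__main__":
    defaults = [[1, None, None, None], [2, 62, 0.267, 9.0],
                [4, 62, 0.142, 9.0], [8, 62, 0.07949, 9.0]]
    for branch in build_branch_configs([512, 1024, 2048, 4096], defaults, {"channels": 16}):
        print(branch["window_size"], branch["pqmf"],
              branch["discriminator_params"]["in_channels"])
```

Deep-copying the shared parameter dict keeps the branches independent, so setting in_channels for one branch cannot leak into another.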
Parameters:
- repeats (int) – Number of repetitions to apply RWD.
- window_sizes (List[int]) – List of random window sizes.
- pqmf_params (List[List[int]]) – Parameters for the PQMF modules, one list per window size, in the order [subbands, taps, cutoff_ratio, beta].
- discriminator_params (Dict[str, Any]) – Parameters for the base discriminator.
- use_weight_norm (bool) – Whether to apply weight normalization.
Raises: AssertionError – If window_sizes and pqmf_params differ in length, or if window_size // subbands is not the same for every window size.
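The consistency check described under Raises can be sketched in plain Python. This assumes the first element of each pqmf_params entry is the subband count and that the check compares window_size // subbands across all branches; check_window_sizes is a hypothetical name, not part of ESPnet.

```python
from typing import Optional, Sequence


def check_window_sizes(
    window_sizes: Sequence[int],
    pqmf_params: Sequence[Sequence[Optional[float]]],
) -> int:
    """Return the shared per-subband window length or raise AssertionError."""
    if not window_sizes:
        raise AssertionError("at least one window size is required")
    # One PQMF configuration is required for each window size.
    if len(window_sizes) != len(pqmf_params):
        raise AssertionError("window_sizes and pqmf_params differ in length")
    # pqmf_params[i][0] is taken to be the number of subbands of branch i.
    sizes = [ws // int(p[0]) for ws, p in zip(window_sizes, pqmf_params)]
    # Every branch must hand its discriminator the same samples per subband.
    if any(size != sizes[0] for size in sizes):
        raise AssertionError(f"inconsistent per-subband sizes: {sizes}")
    return sizes[0]
```

With the default arguments, every branch sees 512 samples per subband (512 / 1, 1024 / 2, 2048 / 4, 4096 / 8), so the check passes.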
########### Examples
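>>> import torch
>>> from espnet2.gan_tts.style_melgan.style_melgan import StyleMelGANDiscriminator
>>> discriminator = StyleMelGANDiscriminator()
>>> input_tensor = torch.randn(8, 1, 8192)  # batch of 8; T must exceed the largest window (4096)
>>> outputs = discriminator(input_tensor)
>>> print(len(outputs))  # repeats * number of discriminators = 2 * 4 = 8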
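####### NOTE The pqmf_params must be chosen so that every window size divided by its subband count (the first element of its PQMF parameters) gives the same value. With the defaults, 512 / 1 = 1024 / 2 = 2048 / 4 = 4096 / 8 = 512.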
Initialize StyleMelGANDiscriminator module.
- Parameters:
- repeats (int) – Number of repetitions to apply RWD.
- window_sizes (List[int]) – List of random window sizes.
- pqmf_params (List[List[int]]) – List of lists of parameters for the PQMF modules.
- discriminator_params (Dict[str, Any]) – Parameters for the base discriminator module.
- use_weight_norm (bool) – Whether to apply weight normalization.
apply_weight_norm()
Apply weight normalization to all of the layers.
This method applies weight normalization to every convolutional layer in the network. It visits all submodules and applies weight normalization to each one that is an instance of torch.nn.Conv1d or torch.nn.ConvTranspose1d. Weight normalization can help stabilize training and improve convergence.
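Weight normalization reparameterizes each weight as w = g · v / ‖v‖. The pure-Python sketch below assumes the method relies on PyTorch's torch.nn.utils.weight_norm with its default dim=0, which keeps one magnitude g per output channel and initializes g = ‖w‖ and v = w so the layer's output does not change at the moment normalization is applied. Each channel's kernel is flattened into a list, and weight_norm_init / weight_norm_compose are hypothetical names used only for illustration.

```python
import math
from typing import List, Tuple

# One row per output channel, holding in_channels * kernel_size values.
Matrix = List[List[float]]


def weight_norm_init(w: Matrix) -> Tuple[List[float], Matrix]:
    """Split w into magnitudes g and directions v so that w is unchanged at first."""
    g = [math.sqrt(sum(x * x for x in row)) for row in w]
    v = [list(row) for row in w]
    return g, v


def weight_norm_compose(g: List[float], v: Matrix) -> Matrix:
    """Recompose the effective weight w[o] = g[o] * v[o] / ||v[o]||."""
    w = []
    for g_o, v_o in zip(g, v):
        norm = math.sqrt(sum(x * x for x in v_o))
        w.append([g_o * x / norm for x in v_o])
    return w
```

During training, g and v are the learned parameters, so each channel's scale (g) is optimized separately from its direction (v).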
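####### NOTE The constructor calls this method automatically when use_weight_norm is True. Calling it again on such a model raises a RuntimeError, because torch.nn.utils.weight_norm refuses to register weight normalization twice on the same parameter.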
########### Examples
>>> from espnet2.gan_tts.style_melgan.style_melgan import StyleMelGANDiscriminator
>>> model = StyleMelGANDiscriminator(use_weight_norm=False)
>>> model.apply_weight_norm()  # applies weight normalization to all convolutional layers
forward(x: Tensor) → List[Tensor]
Calculate forward propagation.
This method runs the input through all base discriminators `repeats` times. On each pass, for each branch i, it crops a window of window_sizes[i] samples starting at a random position, decomposes it into subbands with the i-th PQMF module, and passes the result to the i-th base discriminator. New random positions are drawn on every call.
- Parameters: x (Tensor) – Input waveform tensor of shape (B, 1, T), where B is the batch size and T is the number of samples. T must be larger than the largest window size.
- Returns: A list of discriminator outputs containing repeats * (number of discriminators) items.
- Return type: List
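The Random Window Discrimination loop described above can be sketched without torch by tracking only the window positions and shapes. This assumes each start index is drawn uniformly from [0, T - ws), which is why T must exceed every window size, and that a PQMF with n subbands turns a window of ws samples into n channels of ws // n samples. sample_windows is a hypothetical helper, not an ESPnet function.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sample_windows(
    signal_length: int,
    window_sizes: Sequence[int],
    subbands: Sequence[int],
    repeats: int,
    rng: Optional[random.Random] = None,
) -> List[Tuple[int, int, int]]:
    """Return (start, subbands, samples_per_subband) for every discriminator call."""
    if rng is None:
        rng = random.Random()
    calls = []
    for _ in range(repeats):
        for ws, n in zip(window_sizes, subbands):
            # Assumed: start is uniform in [0, T - ws), so T must exceed ws.
            if signal_length <= ws:
                raise ValueError(f"signal length {signal_length} too short for window {ws}")
            start = rng.randrange(signal_length - ws)
            # The crop x[:, :, start:start + ws] becomes n subbands of ws // n samples.
            calls.append((start, n, ws // n))
    return calls
```

With the default settings, one forward pass makes 2 * 4 = 8 discriminator calls, each of which sees 512 samples per subband.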
########### Examples
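>>> import torch
>>> from espnet2.gan_tts.style_melgan.style_melgan import StyleMelGANDiscriminator
>>> discriminator = StyleMelGANDiscriminator()
>>> input_tensor = torch.randn(8, 1, 8192)  # batch of 8; T must exceed the largest window (4096)
>>> outputs = discriminator(input_tensor)
>>> print(len(outputs))  # repeats * number of discriminators = 2 * 4 = 8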
reset_parameters()
Reset parameters of the discriminator’s convolutional layers.
This method visits every module in the discriminator and re-initializes the weights of each convolutional layer (torch.nn.Conv1d and torch.nn.ConvTranspose1d) by sampling from a normal distribution with mean 0.0 and standard deviation 0.02, a common initialization for GAN discriminators.
It also logs a debug message indicating that the parameters have been reset for each convolutional layer.
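The "visit every module, re-initialize only convolutions" pattern can be sketched with hypothetical stand-in classes instead of torch modules. This assumes the method walks the module tree the way torch.nn.Module.apply does (children first, then the module itself) and filters by type; Module, Conv1d, LeakyReLU, and make_reset_fn below are illustrative stubs, and the real method also matches ConvTranspose1d layers.

```python
import random
from typing import Callable, List


class Module:
    """Hypothetical stand-in for torch.nn.Module: a node with child modules."""

    def __init__(self, *submodules: "Module") -> None:
        self.submodules: List["Module"] = list(submodules)

    def apply(self, fn: Callable[["Module"], None]) -> "Module":
        # Like torch.nn.Module.apply: recurse into children, then call fn on self.
        for child in self.submodules:
            child.apply(fn)
        fn(self)
        return self


class Conv1d(Module):
    """Hypothetical convolution stub holding a flat list of weights."""

    def __init__(self, num_weights: int) -> None:
        super().__init__()
        self.weight = [1.0] * num_weights


class LeakyReLU(Module):
    """Hypothetical activation stub without weights."""


def make_reset_fn(rng: random.Random) -> Callable[[Module], None]:
    def _reset(m: Module) -> None:
        # Only convolution layers are re-initialized; other modules are skipped.
        if isinstance(m, Conv1d):
            m.weight = [rng.gauss(0.0, 0.02) for _ in m.weight]
    return _reset


if __name__ == "__main__":
    net = Module(Conv1d(8), LeakyReLU(), Module(Conv1d(8)))
    net.apply(make_reset_fn(random.Random(0)))
```

Filtering by type inside the visitor keeps the traversal generic: the same walk also serves weight normalization, with only the per-module function swapped.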
########### Examples
>>> from espnet2.gan_tts.style_melgan.style_melgan import StyleMelGANDiscriminator
>>> discriminator = StyleMelGANDiscriminator()
>>> discriminator.reset_parameters()
####### NOTE The constructor calls this method automatically, so a newly created discriminator already uses this initialization. Call it again only to re-initialize the convolutional weights.