espnet2.gan_tts.melgan.melgan.MelGANMultiScaleDiscriminator
class espnet2.gan_tts.melgan.melgan.MelGANMultiScaleDiscriminator(in_channels: int = 1, out_channels: int = 1, scales: int = 3, downsample_pooling: str = 'AvgPool1d', downsample_pooling_params: Dict[str, Any] = {'count_include_pad': False, 'kernel_size': 4, 'padding': 1, 'stride': 2}, kernel_sizes: List[int] = [5, 3], channels: int = 16, max_downsample_channels: int = 1024, bias: bool = True, downsample_scales: List[int] = [4, 4, 4, 4], nonlinear_activation: str = 'LeakyReLU', nonlinear_activation_params: Dict[str, Any] = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d', pad_params: Dict[str, Any] = {}, use_weight_norm: bool = True)
Bases: Module
MelGAN multi-scale discriminator module.
This class implements a multi-scale discriminator for the MelGAN architecture, which judges audio signals at several temporal resolutions. It holds one MelGANDiscriminator per scale. The first discriminator receives the input signal as-is, and each subsequent one receives the input after one more application of the downsampling pooling layer.
discriminators
A list of MelGANDiscriminator instances, one for each scale.
- Type: torch.nn.ModuleList
pooling
The pooling layer used for downsampling the input signal.
- Type: torch.nn.Module
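To make the scale structure concrete, the following pure-Python sketch computes the signal length seen by each discriminator. It assumes the default downsample_pooling='AvgPool1d' with kernel_size=4, stride=2 and padding=1, and uses the standard AvgPool1d output-length formula (ceil_mode=False); the helper names are hypothetical and not part of ESPnet.

```python
import math
from typing import List


def avgpool1d_out_len(length: int, kernel_size: int = 4, stride: int = 2,
                      padding: int = 1) -> int:
    """Output length of torch.nn.AvgPool1d (ceil_mode=False) for a given input length."""
    return math.floor((length + 2 * padding - kernel_size) / stride) + 1


def input_lengths_per_scale(length: int, scales: int = 3) -> List[int]:
    """Length of the signal fed to each discriminator (hypothetical helper)."""
    lengths = []
    for _ in range(scales):
        lengths.append(length)              # this scale sees the current signal
        length = avgpool1d_out_len(length)  # the next scale sees it pooled once more
    return lengths


print(input_lengths_per_scale(1024))  # [1024, 512, 256]
```

With count_include_pad=False, the pooled values at the borders average only real samples; this changes the values, not the lengths.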
Parameters:
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- scales (int) – Number of multi-scales.
- downsample_pooling (str) – Pooling module name for downsampling of the inputs.
- downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- kernel_sizes (List[int]) – List of two odd kernel sizes. Their product is used as the kernel size of the first conv layer, and the first and the second kernel sizes are used for the last two layers (e.g., [5, 3] gives 15 for the first layer, then 5 and 3).
- channels (int) – Initial number of channels for the first conv layer.
- max_downsample_channels (int) – Maximum number of channels for downsampling layers.
- bias (bool) – Whether to add bias parameter in convolution layers.
- downsample_scales (List[int]) – List of downsampling scales.
- nonlinear_activation (str) – Activation function module name.
- nonlinear_activation_params (Dict[str, Any]) – Hyperparameters for activation function.
- pad (str) – Padding module name applied before the first conv layer.
- pad_params (Dict[str, Any]) – Hyperparameters for padding function.
- use_weight_norm (bool) – Whether to use weight norm.
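The following pure-Python sketch derives the per-layer layout implied by these parameters. It assumes, based on the upstream MelGANDiscriminator code, that the first convolution uses the product of kernel_sizes, that each entry of downsample_scales adds one strided convolution with min(in_channels * scale, max_downsample_channels) output channels, and that two final convolutions follow. discriminator_layout is a hypothetical helper, not an ESPnet function.

```python
import math
from typing import Dict, List, Sequence


def discriminator_layout(kernel_sizes: Sequence[int] = (5, 3), channels: int = 16,
                         max_downsample_channels: int = 1024,
                         downsample_scales: Sequence[int] = (4, 4, 4, 4),
                         out_channels: int = 1) -> Dict[str, object]:
    """First-conv kernel size and output channels of every layer (assumed layout)."""
    first_kernel = math.prod(kernel_sizes)   # e.g. 5 * 3 = 15
    layer_channels: List[int] = [channels]   # first conv layer
    in_chs = channels
    for scale in downsample_scales:          # one strided conv per scale
        out_chs = min(in_chs * scale, max_downsample_channels)
        layer_channels.append(out_chs)
        in_chs = out_chs
    layer_channels.append(min(in_chs * 2, max_downsample_channels))  # conv with kernel_sizes[0]
    layer_channels.append(out_channels)                              # conv with kernel_sizes[1]
    return {"first_kernel": first_kernel, "layer_channels": layer_channels}


layout = discriminator_layout()
print(layout["first_kernel"])    # 15
print(layout["layer_channels"])  # [16, 64, 256, 1024, 1024, 1024, 1]
```

Each layer contributes one tensor to a discriminator's output list, so under these assumptions every inner list returned by forward holds seven tensors with the default settings.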
Example
>>> import torch
>>> from espnet2.gan_tts.melgan.melgan import MelGANMultiScaleDiscriminator
>>> discriminator = MelGANMultiScaleDiscriminator(in_channels=1,
...     out_channels=1, scales=3)
>>> input_tensor = torch.randn(8, 1, 1024)  # batch of 8, 1 channel, 1024 samples
>>> outputs = discriminator(input_tensor)
>>> len(outputs)  # number of scales
3
- Returns: List of lists containing the outputs of every layer of each discriminator, one inner list per scale.
- Return type: List[List[Tensor]]
- Raises: AssertionError – If any of the hyperparameters are invalid (e.g., kernel_sizes does not contain exactly two odd values).
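As an illustration of what "invalid" means for kernel_sizes, the following hypothetical helper restates the checks we assume MelGANDiscriminator performs (exactly two values, both odd). It is a sketch for reading purposes, not the ESPnet code itself.

```python
from typing import Sequence


def check_kernel_sizes(kernel_sizes: Sequence[int]) -> None:
    """Hypothetical re-statement of the assumed kernel-size assertions."""
    assert len(kernel_sizes) == 2, "kernel_sizes must hold exactly two values"
    # Odd sizes let (k - 1) // 2 padding preserve the sequence length.
    assert kernel_sizes[0] % 2 == 1, "kernel_sizes[0] must be odd"
    assert kernel_sizes[1] % 2 == 1, "kernel_sizes[1] must be odd"


check_kernel_sizes([5, 3])      # passes silently
try:
    check_kernel_sizes([4, 3])  # even first kernel size
except AssertionError as err:
    print(err)                  # kernel_sizes[0] must be odd
```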
NOTE
This class follows the official implementation's manner of initializing parameters, as described in the original MelGAN repository.
Initialize MelGANMultiScaleDiscriminator module.
apply_weight_norm()
Apply weight normalization to all of the layers.
This method iterates through all the layers of the MelGANMultiScaleDiscriminator and applies weight normalization to each convolutional layer. Weight normalization reparameterizes each weight tensor as w = g * v / ||v||, decoupling its magnitude g from its direction v, which can help stabilize training.
It targets layers of type torch.nn.Conv1d and torch.nn.ConvTranspose1d and applies torch.nn.utils.weight_norm to each of them.
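The reparameterization can be illustrated in pure Python for a single output channel, with its weight flattened to a list. This sketch covers the math only; it assumes the default dim=0 behaviour of torch.nn.utils.weight_norm, where each output channel of a Conv1d weight gets its own magnitude g, and the helper names are hypothetical.

```python
import math
from typing import List, Tuple


def l2_norm(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def weight_norm_split(w: List[float]) -> Tuple[float, List[float]]:
    """Initial (g, v) for one output channel: g = ||w||, v = w."""
    return l2_norm(w), list(w)


def weight_norm_combine(g: float, v: List[float]) -> List[float]:
    """Effective weight w = g * v / ||v||."""
    norm = l2_norm(v)
    return [g * x / norm for x in v]


g, v = weight_norm_split([3.0, 4.0])
print(g)                             # 5.0
print(weight_norm_combine(g, v))     # [3.0, 4.0]  (unchanged at application time)
print(weight_norm_combine(10.0, v))  # [6.0, 8.0]  (g sets the magnitude only)
```

Because g starts at ||w|| and v at w, applying weight normalization does not change the layer's output; only the way gradients update magnitude and direction changes.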
Example
>>> discriminator = MelGANMultiScaleDiscriminator(use_weight_norm=False)
>>> discriminator.apply_weight_norm()  # applies weight normalization
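NOTE
This method is called automatically during initialization when use_weight_norm is True (the default). Calling it again on such a model raises a RuntimeError from torch.nn.utils.weight_norm, because a parameter cannot carry two weight-norm hooks.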
forward(x: Tensor) → List[List[Tensor]]
Calculate forward propagation.
This method takes an input tensor and passes it through the multi-scale discriminators, returning the outputs of each layer for all discriminators.
- Parameters: x (Tensor) – Input signal (B, 1, T), where B is the batch size, 1 is the number of input channels, and T is the length of the input signal.
- Returns: A list of lists, where each inner list contains the output tensors from every layer of one discriminator. The outer list has one entry per scale of the multi-scale setup.
- Return type: List[List[Tensor]]
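The nesting of the returned lists can be sketched without PyTorch. In this hypothetical pure-Python model, each discriminator is a list of layer functions and pairwise_mean stands in for the pooling layer; the loop keeps every layer output and then downsamples the input before the next scale. In the real module, the last tensor of each inner list is the discriminator's final output and the earlier ones are intermediate feature maps.

```python
from typing import Callable, List

Signal = List[float]
Layer = Callable[[Signal], Signal]


def multi_scale_forward(x: Signal, discriminators: List[List[Layer]],
                        pooling: Layer) -> List[List[Signal]]:
    """Collect every layer output of every discriminator, one inner list per scale."""
    outs: List[List[Signal]] = []
    for layers in discriminators:
        feats: List[Signal] = []
        h = x
        for layer in layers:
            h = layer(h)
            feats.append(h)  # intermediate outputs are kept, not just the last one
        outs.append(feats)
        x = pooling(x)       # the next scale sees a downsampled input
    return outs


def pairwise_mean(s: Signal) -> Signal:
    """Toy stand-in for AvgPool1d with kernel_size=2, stride=2."""
    return [(s[i] + s[i + 1]) / 2 for i in range(0, len(s) - 1, 2)]


toy_layers: List[Layer] = [lambda s: [2 * v for v in s], pairwise_mean, lambda s: [sum(s)]]
outs = multi_scale_forward([1.0] * 8, [toy_layers] * 3, pairwise_mean)
print([len(o) for o in outs])  # [3, 3, 3]
print([o[-1] for o in outs])   # [[8.0], [4.0], [2.0]]
```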
Example
>>> discriminator = MelGANMultiScaleDiscriminator()
>>> input_tensor = torch.randn(4, 1, 1024)  # batch of 4, 1 channel, T=1024
>>> outputs = discriminator(input_tensor)
>>> len(outputs)  # equal to the number of scales
3
>>> len(outputs[0])  # 1 initial conv + 4 downsampling convs + 2 final convs (defaults)
7
remove_weight_norm()
Remove weight normalization module from all of the layers.
This method iterates through all layers of the MelGAN multi-scale discriminator and removes the weight normalization applied to the convolutional layers. If a layer does not have weight normalization, it will be skipped without raising an error.
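The skip-if-absent behaviour can be sketched with hypothetical stand-in objects instead of real torch modules. The sketch assumes, as in the upstream code, that torch.nn.utils.remove_weight_norm raises ValueError for a layer without a weight-norm hook and that the traversal simply catches it; ToyLayer and toy_remove_weight_norm are illustrative names only.

```python
from typing import Iterable, List


class ToyLayer:
    """Hypothetical layer stand-in; has_weight_norm marks an attached hook."""

    def __init__(self, name: str, has_weight_norm: bool):
        self.name = name
        self.has_weight_norm = has_weight_norm


def toy_remove_weight_norm(layer: ToyLayer) -> None:
    """Mimics remove_weight_norm: raises ValueError when no hook is attached."""
    if not layer.has_weight_norm:
        raise ValueError(f"weight_norm of 'weight' not found in {layer.name}")
    layer.has_weight_norm = False


def remove_all(layers: Iterable[ToyLayer]) -> List[str]:
    """Remove weight norm wherever present and return the names of affected layers."""
    removed = []
    for layer in layers:
        try:
            toy_remove_weight_norm(layer)
            removed.append(layer.name)
        except ValueError:
            continue  # no weight norm on this layer: skip without raising
    return removed


layers = [ToyLayer("conv1", True), ToyLayer("act", False), ToyLayer("conv2", True)]
print(remove_all(layers))  # ['conv1', 'conv2']
print(remove_all(layers))  # [] (a second pass is a harmless no-op)
```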
Example
>>> discriminator = MelGANMultiScaleDiscriminator(use_weight_norm=False)
>>> discriminator.apply_weight_norm()  # apply weight normalization
>>> discriminator.remove_weight_norm()  # remove weight normalization
NOTE
This method is useful when weight normalization is needed only during training, e.g., before inference or export, when each layer can use a plain weight tensor again.
reset_parameters()
Reset parameters of the model.
This method re-draws the weights of every torch.nn.Conv1d and torch.nn.ConvTranspose1d layer in the discriminator from a normal distribution with a mean of 0 and a standard deviation of 0.02. Biases are left unchanged.
This follows the initialization method specified in the official MelGAN implementation: https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
This method is called automatically at the end of initialization, so a manual call is only needed to re-draw the initial weights. Because it overwrites the convolution weights, it should not be called after loading pretrained weights that are meant to be kept.
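For intuition, the following pure-Python sketch draws a flattened weight from the same N(0, 0.02^2) distribution with random.gauss. It stands in for the in-place normal_(0.0, 0.02) call that we assume the upstream code applies to each convolution weight; init_conv_weight is a hypothetical helper.

```python
import random
import statistics
from typing import List


def init_conv_weight(num_elements: int, std: float = 0.02, seed: int = 0) -> List[float]:
    """Draw a flattened conv weight from N(0, std^2) with a fixed seed."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, std) for _ in range(num_elements)]


w = init_conv_weight(16 * 1 * 15)  # e.g. a Conv1d(1, 16, kernel_size=15) weight, flattened
print(len(w))                      # 240
print(statistics.pstdev(init_conv_weight(20000)))  # roughly 0.02
```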
Example
>>> model = MelGANMultiScaleDiscriminator()
>>> model.reset_parameters()  # re-draws all convolution weights