espnet2.gan_codec.hificodec.hificodec.HiFiCodecDiscriminator
class espnet2.gan_codec.hificodec.hificodec.HiFiCodecDiscriminator(msstft_discriminator_params: Dict[str, Any] = {'activation': 'LeakyReLU', 'activation_params': {'negative_slope': 0.2}, 'filters': 32, 'hop_lengths': [256, 512, 128, 64, 32], 'in_channels': 1, 'n_ffts': [1024, 2048, 512, 256, 128], 'norm': 'weight_norm', 'out_channels': 1, 'win_lengths': [1024, 2048, 512, 256, 128]}, scales: int = 3, scale_downsample_pooling: str = 'AvgPool1d', scale_downsample_pooling_params: Dict[str, Any] = {'kernel_size': 4, 'padding': 2, 'stride': 2}, scale_discriminator_params: Dict[str, Any] = {'bias': False, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, scale_follow_official_norm: bool = False, periods: List[int] = [2, 3, 5, 7, 11], periods_discriminator_params: Dict[str, Any] = {'bias': False, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True})
Bases: Module
HiFiCodec discriminator module.
This class implements the discriminator for the HiFiCodec model. It combines three sub-discriminators (a multi-scale STFT discriminator, a multi-scale discriminator, and a multi-period discriminator) to distinguish real audio signals from generated ones.
Attributes:
- msstft – Multi-scale STFT discriminator.
- msd – Multi-scale discriminator.
- mpd – Multi-period discriminator.
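Judging from the default `msstft_discriminator_params`, the `msstft` sub-discriminator analyses the input at five STFT resolutions, one per entry of `n_ffts`, `hop_lengths`, and `win_lengths`. The sketch below is a hedged illustration of what each resolution sees for a 16000-sample clip. The helper `stft_shapes` is hypothetical, and the framing convention (`torch.stft`-style, where `center=True` pads `n_fft // 2` samples on each side) is an assumption, not necessarily what the module does internally.

```python
from typing import Dict, List, Tuple

# Default multi-scale STFT settings, copied from the class signature above.
MSSTFT_PARAMS: Dict[str, List[int]] = {
    "n_ffts": [1024, 2048, 512, 256, 128],
    "hop_lengths": [256, 512, 128, 64, 32],
    "win_lengths": [1024, 2048, 512, 256, 128],
}


def stft_shapes(
    num_samples: int, params: Dict[str, List[int]], center: bool = True
) -> List[Tuple[int, int]]:
    """Return (frequency_bins, frames) for each STFT resolution (hypothetical helper).

    Assumes torch.stft-style framing: with center=True there are
    1 + T // hop frames, otherwise 1 + (T - n_fft) // hop frames.
    """
    n_ffts, hops, wins = params["n_ffts"], params["hop_lengths"], params["win_lengths"]
    # Every resolution needs a matching hop length and window length.
    assert len(n_ffts) == len(hops) == len(wins)
    shapes = []
    for n_fft, hop in zip(n_ffts, hops):
        freq_bins = n_fft // 2 + 1  # one-sided spectrum of a real signal
        if center:
            frames = 1 + num_samples // hop
        else:
            frames = 1 + (num_samples - n_fft) // hop
        shapes.append((freq_bins, frames))
    return shapes


for n_fft, shape in zip(MSSTFT_PARAMS["n_ffts"], stft_shapes(16000, MSSTFT_PARAMS)):
    print(f"n_fft={n_fft:5d} -> (bins, frames) = {shape}")
```

Small FFTs give fine time resolution and coarse frequency resolution, large FFTs the opposite, which is the usual motivation for using several resolutions at once.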
Parameters:
- msstft_discriminator_params (Dict[str, Any]) – Parameters for the multi-scale STFT discriminator module.
- scales (int) – Number of scales for the multi-scale discriminator.
- scale_downsample_pooling (str) – Name of the pooling module used to downsample the input between scales.
- scale_downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- scale_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN scale discriminator module.
- scale_follow_official_norm (bool) – Whether to follow the normalization setting of the official HiFi-GAN implementation, in which the first scale discriminator uses spectral normalization and the others use weight normalization.
- periods (List[int]) – List of periods for the multi-period discriminator.
- periods_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN period discriminator module. The period parameter will be overwritten.
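With the default `scale_downsample_pooling="AvgPool1d"` and `scale_downsample_pooling_params={'kernel_size': 4, 'padding': 2, 'stride': 2}`, each successive scale sees a roughly half-length signal. The sketch below computes those lengths using the standard `AvgPool1d` output-length formula (with `ceil_mode=False`). It assumes the HiFi-GAN convention that scale 0 sees the raw waveform and each later scale sees the previous input pooled once more; the helper names are hypothetical.

```python
from typing import List


def avgpool1d_out_len(length: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of a 1-D average pooling layer with ceil_mode=False."""
    return (length + 2 * padding - kernel_size) // stride + 1


def scale_input_lengths(
    num_samples: int, scales: int, kernel_size: int = 4, stride: int = 2, padding: int = 2
) -> List[int]:
    """Input length seen by each scale discriminator (hypothetical helper).

    Assumption: scale 0 gets the raw waveform; every following scale
    gets the previous scale's input pooled once more.
    """
    lengths = [num_samples]
    for _ in range(scales - 1):
        lengths.append(avgpool1d_out_len(lengths[-1], kernel_size, stride, padding))
    return lengths


print(scale_input_lengths(16000, scales=3))  # [16000, 8001, 4001]
```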
####### Examples
>>> import torch
>>> from espnet2.gan_codec.hificodec.hificodec import HiFiCodecDiscriminator
>>> discriminator = HiFiCodecDiscriminator()
>>> input_tensor = torch.randn(8, 1, 16000)  # (B, 1, T): batch of 8 mono signals
>>> outputs = discriminator(input_tensor)
>>> len(outputs)  # 5 STFT + 3 scale + 5 period sub-discriminators
13
Initialize the HiFiCodec discriminator module.
- Parameters:
- msstft_discriminator_params (Dict[str, Any]) – Parameters for the multi-scale STFT discriminator module.
- scales (int) – Number of scales for the multi-scale discriminator.
- scale_downsample_pooling (str) – Name of the pooling module used to downsample the input between scales.
- scale_downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- scale_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN scale discriminator module.
- scale_follow_official_norm (bool) – Whether to follow the normalization setting of the official HiFi-GAN implementation.
- periods (List[int]) – List of periods for the multi-period discriminator.
- periods_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN period discriminator module. The period parameter will be overwritten.
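"The period parameter will be overwritten" means that one period discriminator is built per entry of `periods`, each from a copy of the shared parameter dictionary with its own period filled in. The sketch below shows that expansion, plus the folding step a HiFi-GAN period discriminator applies to its input: the length-T signal is reflect-padded to a multiple of p and reshaped into rows of p samples, so sample t lands in column t mod p. Both helpers are hypothetical, operate on plain Python lists instead of tensors, and are an illustration of the technique rather than the ESPnet code.

```python
import copy
from typing import Any, Dict, List

# A subset of the default periods_discriminator_params from the signature above.
PERIOD_PARAMS: Dict[str, Any] = {
    "channels": 32,
    "kernel_sizes": [5, 3],
    "downsample_scales": [3, 3, 3, 3, 1],
}


def expand_period_params(periods: List[int], params: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build one config per period, overwriting any "period" key in params."""
    configs = []
    for period in periods:
        cfg = copy.deepcopy(params)  # independent copy per sub-discriminator
        cfg["period"] = period
        configs.append(cfg)
    return configs


def fold_by_period(signal: List[float], period: int) -> List[List[float]]:
    """Fold a 1-D signal into rows of `period` samples (T -> ceil(T/p) x p).

    If T is not a multiple of period, the tail is reflect-padded first
    (appending signal[-2], signal[-3], ...); this requires n_pad < T.
    """
    if len(signal) % period != 0:
        n_pad = period - len(signal) % period
        signal = signal + [signal[-2 - i] for i in range(n_pad)]
    return [signal[i:i + period] for i in range(0, len(signal), period)]


configs = expand_period_params([2, 3, 5, 7, 11], {**PERIOD_PARAMS, "period": 99})
print([cfg["period"] for cfg in configs])  # [2, 3, 5, 7, 11]
print(fold_by_period([0.0, 1.0, 2.0, 3.0, 4.0], 3))  # [[0.0, 1.0, 2.0], [3.0, 4.0, 3.0]]
```

Using mutually prime periods such as 2, 3, 5, 7, and 11 keeps the sub-discriminators from looking at overlapping periodic structure.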
forward(x: Tensor) → List[List[Tensor]]
Calculate forward propagation through all sub-discriminators.
The input signal is passed to the multi-scale STFT discriminator, the multi-scale discriminator, and the multi-period discriminator, and their outputs are concatenated into a single list.
- Parameters:
- x (Tensor) – Input signal of shape (B, 1, T), where B is the batch size and T is the number of samples.
- Returns: List of lists of discriminator outputs. Each inner list holds the layer output tensors of one sub-discriminator; the outputs of the multi-scale STFT, multi-scale, and multi-period discriminators are concatenated.
- Return type: List[List[Tensor]]
####### Examples
>>> discriminator = HiFiCodecDiscriminator()
>>> x = torch.randn(8, 1, 16000)  # (B, 1, T)
>>> outs = discriminator(x)
>>> final_outs = [layer_outs[-1] for layer_outs in outs]  # last-layer output of each sub-discriminator
NOTE
Ensure that the input tensor has shape (B, 1, T) and contains valid waveform data. A (B, T) waveform must be given a channel dimension first, e.g. with x.unsqueeze(1).
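Because `forward` returns one flat list, loss code may need to know which entries came from which sub-discriminator. The sketch below splits such a list back into three groups. It assumes the concatenation order is msstft, then msd, then mpd, and that the group sizes follow the defaults (`len(n_ffts)=5`, `scales=3`, `len(periods)=5`); both the order and the helper `split_discriminator_outputs` are assumptions for illustration. String labels stand in for the real per-sub-discriminator output lists.

```python
from typing import List, Sequence, Tuple, TypeVar

Out = TypeVar("Out")


def split_discriminator_outputs(
    outs: Sequence[Out], num_stft: int = 5, num_scales: int = 3, num_periods: int = 5
) -> Tuple[List[Out], List[Out], List[Out]]:
    """Split a flat output list into (stft, scale, period) groups.

    Assumption: outputs are ordered msstft, msd, mpd, with the given sizes.
    """
    expected = num_stft + num_scales + num_periods
    if len(outs) != expected:
        raise ValueError(f"expected {expected} sub-discriminator outputs, got {len(outs)}")
    stft = list(outs[:num_stft])
    scale = list(outs[num_stft:num_stft + num_scales])
    period = list(outs[num_stft + num_scales:])
    return stft, scale, period


# Stand-in for discriminator(x): one label per sub-discriminator output.
fake_outs = (
    [f"stft{i}" for i in range(5)]
    + [f"scale{i}" for i in range(3)]
    + [f"period{p}" for p in [2, 3, 5, 7, 11]]
)
stft, scale, period = split_discriminator_outputs(fake_outs)
print(len(stft), len(scale), len(period))  # 5 3 5
```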