espnet2.gan_codec.dac.dac.DACDiscriminator
class espnet2.gan_codec.dac.dac.DACDiscriminator(msmpmb_discriminator_params: Dict[str, Any] = {'band_discriminator_params': {'bands': [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], 'channel': 32, 'hop_factor': 0.25, 'sample_rate': 24000}, 'fft_sizes': [2048, 1024, 512], 'period_discriminator_params': {'bias': True, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, 'periods': [2, 3, 5, 7, 11], 'rates': [], 'sample_rate': 24000}, scale_follow_official_norm: bool = False)
Bases: Module
DAC discriminator module.
This class implements the DAC discriminator, which distinguishes real audio from generated audio. It wraps a MultiScaleMultiPeriodMultiBandDiscriminator, which analyzes the input waveform at multiple scales (resampling rates), multiple periods, and multiple frequency bands of multi-resolution STFTs.
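The "multi-period" part follows the HiFi-GAN period-discriminator idea. The waveform is folded into a 2-D grid whose width equals the period, so each column sees every p-th sample. The sketch below shows only that folding step on a plain Python list. It assumes the tail is reflect-padded up to a multiple of the period; the exact padding rule in the actual implementation may differ. `fold_by_period` is a hypothetical helper, not an ESPnet function, and the real module operates on (B, C, T) tensors.

```python
from typing import List


def fold_by_period(x: List[float], period: int) -> List[List[float]]:
    """Fold a 1-D signal into rows of length `period`.

    Hypothetical helper illustrating the multi-period idea. If len(x) is not
    a multiple of `period`, the tail is reflect-padded (edge sample excluded,
    as with torch.nn.functional.pad(..., mode="reflect")). This padding rule
    is an assumption for illustration.
    """
    t = len(x)
    n_pad = (period - t % period) % period
    if n_pad > 0 and n_pad >= t:
        raise ValueError("reflect padding must be shorter than the signal")
    # Reflect-pad the right edge: [..., b, c, d] -> [..., b, c, d, c, b]
    padded = x + [x[-2 - i] for i in range(n_pad)]
    # Row r holds samples r*period .. r*period + period - 1, so column j is
    # the signal subsampled every `period` samples, starting at offset j.
    return [padded[i : i + period] for i in range(0, len(padded), period)]


signal = [float(v) for v in range(10)]
for p in [2, 3, 5, 7, 11]:  # default `periods` from the signature
    grid = fold_by_period(signal, p)
    print(f"period={p}: {len(grid)} rows x {p} columns")
```

Folding into a grid lets 2-D convolutions with kernels of width 1 along the period axis process each phase of the period independently.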
- Parameters:
- msmpmb_discriminator_params (Dict[str, Any]) – Parameters for the MultiScaleMultiPeriodMultiBandDiscriminator, including the resampling rates, periods, FFT sizes, sample rate, and the settings of the period and band sub-discriminators.
- scale_follow_official_norm (bool) – If True, follow the normalization setting of the official implementation.
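In `band_discriminator_params`, `bands` gives band edges as fractions of the frequency axis, and `hop_factor` sets the STFT hop relative to the window. The following is a minimal sketch of how those defaults could translate into concrete bin ranges for each entry in `fft_sizes`. It assumes the rule used in the original DAC multi-band discriminator: multiply each fraction by the number of one-sided STFT bins (`fft_size // 2 + 1`) and truncate to an integer. It also assumes the hop length is `int(fft_size * hop_factor)`. `band_bin_ranges` and `bin_to_hz` are hypothetical helpers for illustration.

```python
from typing import List, Tuple

# Defaults copied from band_discriminator_params / fft_sizes in the signature.
BANDS: List[Tuple[float, float]] = [
    (0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)
]
FFT_SIZES = [2048, 1024, 512]
HOP_FACTOR = 0.25
SAMPLE_RATE = 24000


def band_bin_ranges(
    fft_size: int, bands: List[Tuple[float, float]]
) -> List[Tuple[int, int]]:
    """Map fractional band edges to [start, end) STFT bin indices (assumed rule)."""
    n_bins = fft_size // 2 + 1  # one-sided spectrum: DC up to Nyquist
    return [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]


def bin_to_hz(bin_index: int, fft_size: int, sample_rate: int) -> float:
    """Frequency in Hz corresponding to an STFT bin index."""
    return bin_index * sample_rate / fft_size


for fft_size in FFT_SIZES:
    hop_length = int(fft_size * HOP_FACTOR)  # assumed hop rule
    ranges = band_bin_ranges(fft_size, BANDS)
    # Lower band edges in Hz, rounded for readability.
    edges_hz = [round(bin_to_hz(lo, fft_size, SAMPLE_RATE)) for lo, _ in ranges]
    print(fft_size, hop_length, ranges, edges_hz)
```

At 24 kHz the first band (0.0, 0.1) covers only roughly the lowest 1.2 kHz. The remaining bands split the higher frequencies more coarsely.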
msmpmb_discriminator
An instance of the MultiScaleMultiPeriodMultiBandDiscriminator configured with the provided parameters.
Returns: The output of the forward pass. This is a list of lists with one inner list per sub-discriminator (scale, period, or band), and each inner list contains the outputs of that sub-discriminator's layers.
Return type: List[List[torch.Tensor]]
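GAN losses typically consume this nested output in two ways. They read `outs[-1]` of each inner list for the adversarial term and `outs[:-1]` for feature matching. The sketch below illustrates that convention with least-squares GAN losses and an L1 feature-matching loss, using plain Python lists of floats in place of tensors. It assumes the final score is the last element of each inner list, which is the usual HiFi-GAN-style convention. The function names and the averaging over sub-discriminators are illustrative choices, not ESPnet's loss classes.

```python
from typing import List, Sequence

# Stand-in for List[List[Tensor]]: each "tensor" is a flat list of floats.
Outputs = List[List[List[float]]]


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def lsgan_generator_loss(fake_outs: Outputs) -> float:
    """Push the final score of every sub-discriminator on fake audio toward 1."""
    per_disc = [_mean([(1.0 - s) ** 2 for s in outs[-1]]) for outs in fake_outs]
    return _mean(per_disc)


def lsgan_discriminator_loss(real_outs: Outputs, fake_outs: Outputs) -> float:
    """Push real final scores toward 1 and fake final scores toward 0."""
    total = 0.0
    for real, fake in zip(real_outs, fake_outs):
        total += _mean([(1.0 - s) ** 2 for s in real[-1]])
        total += _mean([s ** 2 for s in fake[-1]])
    return total / len(real_outs)


def feature_matching_loss(real_outs: Outputs, fake_outs: Outputs) -> float:
    """Mean L1 distance between intermediate layer outputs (final scores excluded)."""
    per_disc = []
    for real, fake in zip(real_outs, fake_outs):
        layer_losses = [
            _mean([abs(r - f) for r, f in zip(r_layer, f_layer)])
            for r_layer, f_layer in zip(real[:-1], fake[:-1])
        ]
        per_disc.append(_mean(layer_losses))
    return _mean(per_disc)


# Two toy sub-discriminators, each with one hidden layer and one final score.
real_outs = [[[1.0, 2.0], [1.0, 1.0]], [[0.0], [1.0]]]
fake_outs = [[[1.0, 0.0], [0.0, 0.5]], [[2.0], [0.0]]]
print(lsgan_generator_loss(fake_outs))
print(lsgan_discriminator_loss(real_outs, fake_outs))
print(feature_matching_loss(real_outs, fake_outs))
```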
Examples
>>> import torch
>>> from espnet2.gan_codec.dac.dac import DACDiscriminator
>>> discriminator = DACDiscriminator()
>>> input_tensor = torch.randn(1, 1, 24000)  # (B, 1, T): one second at 24 kHz
>>> outputs = discriminator(input_tensor)
>>> print(len(outputs))  # One inner list per sub-discriminator
NOTE
The DAC Discriminator is a critical component in the DAC model’s adversarial training process, enabling the generator to learn more effectively by providing feedback on the quality of generated audio.
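`msmpmb_discriminator_params` is a single dict argument. Passing a partial dict therefore replaces the whole default shown in the signature instead of being merged into it. The following is a minimal sketch of one way to override a few keys while keeping the rest, assuming you start from a copy of the defaults in the signature. `merge_params` and `with_sample_rate` are hypothetical helpers, not part of ESPnet. `sample_rate` appears both at the top level and inside `band_discriminator_params`, so the sketch updates both together.

```python
import copy
from typing import Any, Dict

# Defaults copied from the DACDiscriminator signature.
DEFAULT_MSMPMB_PARAMS: Dict[str, Any] = {
    "rates": [],
    "periods": [2, 3, 5, 7, 11],
    "fft_sizes": [2048, 1024, 512],
    "sample_rate": 24000,
    "period_discriminator_params": {
        "in_channels": 1,
        "out_channels": 1,
        "kernel_sizes": [5, 3],
        "channels": 32,
        "downsample_scales": [3, 3, 3, 3, 1],
        "max_downsample_channels": 1024,
        "bias": True,
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.1},
        "use_weight_norm": True,
        "use_spectral_norm": False,
    },
    "band_discriminator_params": {
        "hop_factor": 0.25,
        "sample_rate": 24000,
        "bands": [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
        "channel": 32,
    },
}


def merge_params(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge `overrides` into a deep copy of `base` (hypothetical helper)."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_params(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def with_sample_rate(params: Dict[str, Any], sample_rate: int) -> Dict[str, Any]:
    """Set both sample_rate entries so they stay consistent (hypothetical helper)."""
    return merge_params(
        params,
        {"sample_rate": sample_rate,
         "band_discriminator_params": {"sample_rate": sample_rate}},
    )


params = with_sample_rate(DEFAULT_MSMPMB_PARAMS, 44100)
params = merge_params(params, {"fft_sizes": [4096, 2048, 1024]})
# With ESPnet installed: DACDiscriminator(msmpmb_discriminator_params=params)
```

Deep-copying keeps the module-level defaults untouched, so repeated overrides never leak into each other.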
forward(x: Tensor) → List[List[Tensor]]
Calculate the forward pass of the discriminator.
The input waveform is passed to the wrapped MultiScaleMultiPeriodMultiBandDiscriminator (msmpmb_discriminator). The method returns the outputs of all of its sub-discriminators.
- Parameters:
- x (Tensor) – Input waveform tensor of shape (B, 1, T_wav), where B is the batch size and T_wav is the number of samples.
- Returns: List of lists of tensors. Each inner list holds the layer outputs of one sub-discriminator. The outputs of the multi-scale, multi-period, and multi-band sub-discriminators are concatenated into a single outer list.
- Return type: List[List[Tensor]]
Examples
>>> import torch
>>> from espnet2.gan_codec.dac.dac import DACDiscriminator
>>> discriminator = DACDiscriminator()
>>> audio_input = torch.randn(8, 1, 24000)  # Batch of 8, one second at 24 kHz
>>> outputs = discriminator(audio_input)
>>> print(len(outputs[0]))  # Number of layer outputs of the first sub-discriminator