espnet2.gan_codec.funcodec.funcodec.FunCodecDiscriminator
class espnet2.gan_codec.funcodec.funcodec.FunCodecDiscriminator(scales: int = 3, scale_downsample_pooling: str = 'AvgPool1d', scale_downsample_pooling_params: Dict[str, Any] = {'kernel_size': 4, 'padding': 2, 'stride': 2}, scale_discriminator_params: Dict[str, Any] = {'bias': True, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1}, scale_follow_official_norm: bool = False, periods: List[int] = [2, 3, 5, 7, 11], period_discriminator_params: Dict[str, Any] = {'bias': True, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, complexstft_discriminator_params: Dict[str, Any] = {'chan_mults': [1, 2, 4, 4, 8, 8], 'channels': 32, 'hop_length': 256, 'in_channels': 1, 'n_fft': 1024, 'stft_normalized': False, 'strides': [[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]], 'win_length': 1024})
Bases: Module
FunCodec discriminator module.
This class combines three discriminators for the FunCodec model: a multi-scale discriminator, a multi-period discriminator, and a complex STFT discriminator.
msd
Multi-scale discriminator.
mpd
Multi-period discriminator.
complex_stft_d
Complex STFT discriminator.
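The multi-period discriminator (mpd) listed above is HiFi-GAN style. In HiFi-GAN, each period discriminator views the waveform with a fixed period p by reflect-padding it to a multiple of p and folding it into a (T/p) x p grid before its 2-D convolutions. The following pure-Python sketch shows only that folding step, on plain lists instead of tensors. The helper names `reflect_pad_right` and `fold_by_period` are hypothetical, and it assumes ESPnet's period discriminator pads on the right in reflect mode, which should be checked against the source.

```python
from typing import List


def reflect_pad_right(signal: List[float], n_pad: int) -> List[float]:
    """Append n_pad samples mirrored about the last sample (edge sample not repeated)."""
    if not 0 <= n_pad < len(signal):
        # Reflect padding is only defined when the pad is shorter than the input.
        raise ValueError("n_pad must be in [0, len(signal))")
    # For [a, b, c, d] and n_pad=2 this appends [c, b].
    return signal + signal[-2 : -2 - n_pad : -1]


def fold_by_period(signal: List[float], period: int) -> List[List[float]]:
    """Fold a 1-D signal into rows of `period` samples: shape (ceil(T / period), period).

    Assumption: padding is applied on the right only, as in HiFi-GAN period discriminators.
    """
    remainder = len(signal) % period
    if remainder:
        signal = reflect_pad_right(signal, period - remainder)
    return [signal[i : i + period] for i in range(0, len(signal), period)]


if __name__ == "__main__":
    print(fold_by_period([0.0, 1.0, 2.0, 3.0, 4.0], period=3))
    for p in [2, 3, 5, 7, 11]:  # default `periods`
        rows = fold_by_period([0.0] * 16000, p)
        print(f"period={p}: grid {len(rows)} x {p}")
```

For a 16000-sample input and the default periods, the grids are 8000 x 2, 5334 x 3, 3200 x 5, 2286 x 7, and 1455 x 11.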
Parameters:
- scales (int) – Number of multi-scales.
- scale_downsample_pooling (str) – Pooling module name for downsampling of the inputs.
- scale_downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- scale_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN scale discriminator module.
- scale_follow_official_norm (bool) – Whether to follow the norm setting of the official implementation. The first discriminator uses spectral norm and the other discriminators use weight norm.
- periods (List[int]) – List of periods for the multi-period discriminator.
- period_discriminator_params (Dict[str, Any]) – Parameters for the multi-period discriminator.
- complexstft_discriminator_params (Dict[str, Any]) – Parameters for the complex STFT discriminator module.
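To see what the scale-related parameters do, this sketch computes the signal length each scale discriminator receives and the per-scale norm flags implied by `scale_follow_official_norm`. It is plain Python, not ESPnet code: `avgpool1d_out_len`, `scale_input_lengths`, and `per_scale_params` are hypothetical helpers. It assumes the HiFi-GAN scheme in which the first scale sees the raw waveform and the input is pooled once more (with `scale_downsample_pooling_params`) before each subsequent scale.

```python
from typing import Any, Dict, List


def avgpool1d_out_len(length: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of torch.nn.AvgPool1d with the default ceil_mode=False."""
    return (length + 2 * padding - kernel_size) // stride + 1


def scale_input_lengths(length: int, scales: int, pool: Dict[str, int]) -> List[int]:
    """Signal length seen by each scale discriminator.

    Assumption (HiFi-GAN scheme): scale 0 sees the raw waveform, and the input is
    pooled once more before each subsequent scale.
    """
    lengths = [length]
    for _ in range(scales - 1):
        lengths.append(avgpool1d_out_len(lengths[-1], **pool))
    return lengths


def per_scale_params(
    base: Dict[str, Any], scales: int, follow_official_norm: bool
) -> List[Dict[str, Any]]:
    """Per-scale params; norm flags are overridden only when following the official setting."""
    configs = []
    for i in range(scales):
        params = dict(base)  # shallow copy so the scales do not share one dict
        if follow_official_norm:
            # First scale: spectral norm; remaining scales: weight norm.
            params["use_spectral_norm"] = i == 0
            params["use_weight_norm"] = i != 0
        configs.append(params)
    return configs


if __name__ == "__main__":
    pool = {"kernel_size": 4, "stride": 2, "padding": 2}  # default pooling params
    print(scale_input_lengths(16000, scales=3, pool=pool))
    for cfg in per_scale_params({"channels": 128}, scales=3, follow_official_norm=True):
        print(cfg)
```

With the default pooling parameters, a 16000-sample input gives lengths [16000, 8001, 4001] for the three scales.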
####### Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecDiscriminator
>>> discriminator = FunCodecDiscriminator()
>>> input_tensor = torch.randn(1, 1, 16000)  # (B, 1, T) waveform
>>> outputs = discriminator(input_tensor)
>>> print(len(outputs))  # Number of sub-discriminator outputs
Initialize FunCodec Discriminator module.
- Parameters:
- scales (int) – Number of multi-scales.
- scale_downsample_pooling (str) – Pooling module name for downsampling of the inputs.
- scale_downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- scale_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN scale discriminator module.
- scale_follow_official_norm (bool) – Whether to follow the norm setting of the official implementation. The first discriminator uses spectral norm and the other discriminators use weight norm.
- periods (List[int]) – List of periods for the multi-period discriminator.
- period_discriminator_params (Dict[str, Any]) – Parameters for the multi-period discriminator.
- complexstft_discriminator_params (Dict[str, Any]) – Parameters for the complex STFT discriminator module.
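The defaults in `complexstft_discriminator_params` can be turned into concrete sizes with a little arithmetic. This sketch assumes the discriminator computes a one-sided STFT with `center=True` and reflect padding (the `torch.stft` defaults), that layer widths are `channels * mult` for each entry of `chan_mults`, and that each `strides` pair gives the downsampling along the two spectrogram axes. None of this is confirmed by the text above, and the helper names are hypothetical.

```python
from math import prod
from typing import List, Tuple


def stft_grid(num_samples: int, n_fft: int, hop_length: int) -> Tuple[int, int]:
    """(frequency bins, frames) of a one-sided STFT computed with center=True."""
    # center=True pads n_fft // 2 samples on each side, so frames = 1 + T // hop.
    return n_fft // 2 + 1, 1 + num_samples // hop_length


def min_reflect_input(n_fft: int) -> int:
    """Shortest input for which reflect padding by n_fft // 2 is valid (pad < length)."""
    return n_fft // 2 + 1


def layer_channels(channels: int, chan_mults: List[int]) -> List[int]:
    """Assumed conv width per layer: base `channels` times each multiplier."""
    return [channels * m for m in chan_mults]


def overall_strides(strides: List[List[int]]) -> Tuple[int, int]:
    """Total downsampling along each axis (product of per-layer strides)."""
    return prod(s[0] for s in strides), prod(s[1] for s in strides)


if __name__ == "__main__":
    cfg = {  # subset of the default complexstft_discriminator_params
        "chan_mults": [1, 2, 4, 4, 8, 8],
        "channels": 32,
        "hop_length": 256,
        "n_fft": 1024,
        "strides": [[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]],
    }
    print(stft_grid(16000, cfg["n_fft"], cfg["hop_length"]))
    print(min_reflect_input(cfg["n_fft"]))
    print(layer_channels(cfg["channels"], cfg["chan_mults"]))
    print(overall_strides(cfg["strides"]))
```

For a 16000-sample input and the defaults, this yields a 513 x 63 (bins x frames) grid, a minimum input length of 513 samples for reflect padding, widths [32, 64, 128, 128, 256, 256], and overall strides (8, 64), ignoring kernel and padding effects.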
forward(x: Tensor) → List[List[Tensor]]
Calculate the outputs of all sub-discriminators.
The input waveform is passed to the multi-scale discriminator (msd), the multi-period discriminator (mpd), and the complex STFT discriminator (complex_stft_d), and their outputs are collected into a single list.
- Parameters:
- x (torch.Tensor) – Input waveform tensor of shape (B, 1, T), where B is the batch size and T is the number of samples.
- Returns: List with one entry per sub-discriminator. Each entry is itself a list of tensors, typically the intermediate feature maps followed by the final output.
- Return type: List[List[Tensor]]
####### Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecDiscriminator
>>> discriminator = FunCodecDiscriminator()
>>> x = torch.randn(2, 1, 16000)  # Batch of 2 waveforms, shape (B, 1, T)
>>> outs = discriminator(x)
>>> final_outputs = [out[-1] for out in outs]  # Last tensor of each sub-discriminator
NOTE
Ensure that the input waveform tensor has shape (B, 1, T) and is normalized.
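The nested `List[List[Tensor]]` return type is easiest to see with stand-ins. This sketch replaces each sub-discriminator with a hypothetical stub that returns string labels instead of tensors. It assumes `forward` concatenates the outputs in the order msd, mpd, complex_stft_d, with one inner list per sub-discriminator; the stubs also ignore the pooling applied between scales, and the layer counts passed to `make_stub` are arbitrary.

```python
from typing import Callable, List, Sequence

# Stand-in for List[Tensor]: labels for one sub-discriminator's layer outputs.
LayerOutputs = List[str]
SubDiscriminator = Callable[[Sequence[float]], LayerOutputs]


def make_stub(name: str, num_hidden: int) -> SubDiscriminator:
    """Hypothetical sub-discriminator: hidden feature-map labels, then a final-output label."""
    def stub(x: Sequence[float]) -> LayerOutputs:
        return [f"{name}/hidden{i}" for i in range(num_hidden)] + [f"{name}/final"]
    return stub


def combined_forward(
    x: Sequence[float],
    scale_ds: List[SubDiscriminator],
    period_ds: List[SubDiscriminator],
    stft_ds: List[SubDiscriminator],
) -> List[LayerOutputs]:
    """Assumed combination: multi-scale, then multi-period, then complex STFT outputs."""
    outs: List[LayerOutputs] = []
    for group in (scale_ds, period_ds, stft_ds):
        outs.extend(d(x) for d in group)
    return outs


if __name__ == "__main__":
    outs = combined_forward(
        [0.0] * 16000,
        scale_ds=[make_stub(f"msd{i}", 4) for i in range(3)],
        period_ds=[make_stub(f"mpd_p{p}", 5) for p in [2, 3, 5, 7, 11]],
        stft_ds=[make_stub("complex_stft", 6)],
    )
    print(len(outs))
    print([o[-1] for o in outs])
```

With three scales, five periods, and one complex STFT discriminator, `combined_forward` returns 9 inner lists, and `[o[-1] for o in outs]` picks the final output of each.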