espnet2.gan_codec.soundstream.soundstream.SoundStreamDiscriminator
class espnet2.gan_codec.soundstream.soundstream.SoundStreamDiscriminator(scales: int = 3, scale_downsample_pooling: str = 'AvgPool1d', scale_downsample_pooling_params: Dict[str, Any] = {'kernel_size': 4, 'padding': 2, 'stride': 2}, scale_discriminator_params: Dict[str, Any] = {'bias': True, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1}, scale_follow_official_norm: bool = False, complexstft_discriminator_params: Dict[str, Any] = {'chan_mults': [1, 2, 4, 4, 8, 8], 'channels': 32, 'hop_length': 256, 'in_channels': 1, 'n_fft': 1024, 'stft_normalized': False, 'strides': [[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]], 'win_length': 1024})
Bases: Module
SoundStream discriminator module.
This module implements a multi-scale and complex STFT discriminator for the SoundStream model. It is designed to distinguish between real and generated audio signals by combining feature extraction at multiple time scales with a discriminator operating on the complex short-time Fourier transform.
msd
Multi-scale discriminator component.
complex_stft_d
Complex STFT discriminator component.
Parameters:
- scales (int) – Number of scales for the multi-scale discriminator.
- scale_downsample_pooling (str) – Name of the pooling module used to downsample the inputs between scales.
- scale_downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- scale_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN scale discriminator module.
- scale_follow_official_norm (bool) – Whether to follow the norm setting of the official implementation, where the first discriminator uses spectral norm and the remaining discriminators use weight norm.
- complexstft_discriminator_params (Dict[str, Any]) – Parameters for the complex STFT discriminator module.
####### Examples
>>> import torch
>>> discriminator = SoundStreamDiscriminator(scales=3)
>>> input_tensor = torch.randn(8, 1, 16000)  # Batch of 8 one-channel audio signals
>>> outputs = discriminator(input_tensor)
>>> len(outputs)  # A list containing the outputs of both the
...               # multi-scale and complex STFT discriminators.
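The defaults can be overridden per branch. Below is a minimal sketch of customizing the multi-scale branch; the parameter names come from the signature above, but the values are illustrative rather than recommended settings:
>>> from espnet2.gan_codec.soundstream.soundstream import SoundStreamDiscriminator
>>> discriminator = SoundStreamDiscriminator(
...     scales=2,
...     scale_downsample_pooling="AvgPool1d",
...     scale_downsample_pooling_params={"kernel_size": 4, "stride": 2, "padding": 2},
...     scale_follow_official_norm=True,
... )
>>> outputs = discriminator(torch.randn(2, 1, 8000))
>>> all(isinstance(out, list) for out in outputs)
True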
Initialize SoundStream Discriminator module.
- Parameters:
- scales (int) – Number of multi-scales.
- scale_downsample_pooling (str) – Name of the pooling module used to downsample the inputs between scales (see the sketch after this list).
- scale_downsample_pooling_params (Dict[str, Any]) – Parameters for the above pooling module.
- scale_discriminator_params (Dict[str, Any]) – Parameters for the HiFi-GAN scale discriminator module.
- scale_follow_official_norm (bool) – Whether to follow the norm setting of the official implementation, where the first discriminator uses spectral norm and the remaining discriminators use weight norm.
- complexstft_discriminator_params (Dict[str, Any]) – Parameters for the complex STFT discriminator module.
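The pooling name is resolved against torch.nn, following the pattern used by ESPnet's HiFi-GAN multi-scale discriminator; the lookup below is a sketch of that assumed behavior, shown outside the class:
>>> import torch
>>> pooling = getattr(torch.nn, "AvgPool1d")(kernel_size=4, stride=2, padding=2)
>>> x = torch.randn(8, 1, 16000)
>>> pooling(x).shape  # downsampled input for the next, coarser scale
torch.Size([8, 1, 8001])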
forward(x: Tensor) → List[List[Tensor]]
Perform forward propagation of the SoundStream discriminator.
The input waveform is fed to both the multi-scale discriminator and the complex STFT discriminator, and the outputs of all discriminators are gathered into a single list (see the sketch below).
- Parameters: x (torch.Tensor) – Input signal tensor of shape (B, 1, T), where B is the batch size and T is the number of audio samples.
- Returns: List of each discriminator's outputs, where each inner list contains that discriminator's layer-wise output tensors.
- Return type: List[List[Tensor]]
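A minimal sketch of that composition, inferred from the msd and complex_stft_d attributes documented above (illustrative; the actual source may differ in detail):

    def forward(self, x):
        # One list of layer-wise outputs per scale of the multi-scale branch.
        outs = self.msd(x)
        # Append the complex STFT discriminator's outputs to the same list.
        return outs + self.complex_stft_d(x)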
####### Examples
>>> discriminator = SoundStreamDiscriminator()
>>> x = torch.randn(4, 1, 16000)  # Batch of 4 one-channel waveforms
>>> outs = discriminator(x)
>>> isinstance(outs[0][-1], torch.Tensor)
True
NOTE
Ensure that the input tensor includes a channel dimension, i.e. has shape (B, 1, T), and is properly normalized before passing it to the forward method.
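For context, a common way to consume these outputs in adversarial training, assuming (as in ESPnet's HiFi-GAN-style discriminators) that the last tensor of each inner list is the final logit; the least-squares pairing and the real_audio/fake_audio tensors are illustrative only:
>>> import torch.nn.functional as F
>>> real_outs = discriminator(real_audio)  # real_audio: (B, 1, T), hypothetical
>>> fake_outs = discriminator(fake_audio)  # fake_audio: (B, 1, T), hypothetical
>>> d_loss = sum(
...     F.mse_loss(r[-1], torch.ones_like(r[-1]))
...     + F.mse_loss(f[-1], torch.zeros_like(f[-1]))
...     for r, f in zip(real_outs, fake_outs)
... )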