espnet2.gan_codec.shared.discriminator.msmpmb_discriminator.MultiScaleMultiPeriodMultiBandDiscriminator
class espnet2.gan_codec.shared.discriminator.msmpmb_discriminator.MultiScaleMultiPeriodMultiBandDiscriminator(rates: list = [], fft_sizes: list = [2048, 1024, 512], sample_rate: int = 44100, periods: List[int] = [2, 3, 5, 7, 11], period_discriminator_params: Dict[str, Any] = {'bias': True, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, band_discriminator_params: Dict[str, Any] = {'bands': [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], 'channel': 32, 'hop_factor': 0.25, 'sample_rate': 24000})
Bases: Module
MultiScaleMultiPeriodMultiBandDiscriminator combines multi-scale, multi-period, and multi-band discriminators for audio signals. It runs these discriminators together to improve the training of generative adversarial networks (GANs) for audio tasks.
rates
List of sampling rates (in Hz) to run the multi-scale discriminator (MSD) at. If empty, MSD is not used.
- Type: list
fft_sizes
Window sizes of the FFT to run the multi-band discriminator (MRD) at. Defaults to [2048, 1024, 512].
- Type: list
sample_rate
Sampling rate of audio in Hz. Defaults to 44100.
- Type: int
periods
List of periods (of samples) to run the multi-period discriminator (MPD) at. Defaults to [2, 3, 5, 7, 11].
- Type: list
period_discriminator_params
Parameters for the multi-period discriminator.
- Type: Dict[str, Any]
band_discriminator_params
Parameters for the multi-band discriminator.
- Type: Dict[str, Any]
Parameters:
- rates (list, optional) – Sampling rates to run MSD at. Defaults to [].
- periods (list, optional) – Periods to run MPD at. Defaults to [2, 3, 5, 7, 11].
- fft_sizes (list, optional) – Window sizes of the FFT to run MRD at. Defaults to [2048, 1024, 512].
- sample_rate (int, optional) – Sampling rate of audio in Hz. Defaults to 44100.
- period_discriminator_params (Dict[str, Any], optional) – Parameters for the multi-period discriminator. Defaults to the dictionary shown in the signature.
- band_discriminator_params (Dict[str, Any], optional) – Parameters for the multi-band discriminator. Defaults to the dictionary shown in the signature.
######### Examples
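>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.msmpmb_discriminator import (
...     MultiScaleMultiPeriodMultiBandDiscriminator,
... )
>>> discriminator = MultiScaleMultiPeriodMultiBandDiscriminator(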
... rates=[16000, 24000],
... periods=[2, 3],
... fft_sizes=[2048, 1024]
... )
>>> output = discriminator(torch.randn(1, 1, 44100)) # Example input
>>> print(len(output)) # Check the number of outputs
NOTE
The input audio signal should be of shape (batch_size, channels, samples).
- Raises: ValueError – If any of the parameters are invalid.
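As a rough guide to what `len(output)` reports in the example above, the sketch below assumes the class builds one sub-discriminator per entry of `periods`, `rates`, and `fft_sizes`, and that the forward pass returns one item per sub-discriminator. This is consistent with "If empty, MSD is not used", but it is not stated on this page. `expected_num_outputs` is a hypothetical helper, not part of ESPnet.

```python
def expected_num_outputs(rates, periods, fft_sizes):
    """Count sub-discriminators under the one-per-entry assumption."""
    # Assumption: one MPD per period, one MSD per rate, one band
    # discriminator per FFT size. Verify against the installed source.
    return len(periods) + len(rates) + len(fft_sizes)


n = expected_num_outputs(
    rates=[16000, 24000],
    periods=[2, 3],
    fft_sizes=[2048, 1024],
)
print(n)  # 2 + 2 + 2 = 6 under this assumption
```

If the real count differs, the one-per-entry assumption does not hold for the installed version, and the class source should be consulted instead.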
Discriminator that combines multiple discriminators.
- Parameters:
- rates (list, optional) – Sampling rates (in Hz) to run MSD at, by default []. If empty, MSD is not used.
- periods (list, optional) – Periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
- fft_sizes (list, optional) – Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
- sample_rate (int, optional) – Sampling rate of audio in Hz, by default 44100
- bands (list, optional) – Bands to run MRD at, by default BANDS
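The default bands in the signature, `(0.0, 0.1)` through `(0.75, 1.0)`, are fractions rather than frequencies. One plausible reading, assumed here and not confirmed by this page, is that each pair is scaled by the number of one-sided STFT bins, `fft_size // 2 + 1`, to give a range of bin indices. `band_bin_ranges` and `DEFAULT_BANDS` are hypothetical names used only for illustration.

```python
def band_bin_ranges(fft_size, bands):
    """Map fractional band edges to STFT bin index ranges (assumed scheme)."""
    n_bins = fft_size // 2 + 1  # bin count of a one-sided spectrum
    return [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]


# Default 'bands' value copied from the class signature.
DEFAULT_BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]

ranges = band_bin_ranges(2048, DEFAULT_BANDS)
for lo, hi in ranges:
    print(lo, hi)
```

Under this reading the bands tile the spectrum without gaps, and the lowest band is the narrowest, so low frequencies get a dedicated discriminator branch.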
forward(x)
Discriminator that combines multiple discriminators.
This class integrates multiple discriminators, including multi-scale, multi-period, and multi-band discriminators to evaluate the quality of generated audio signals.
rates
Sampling rates (in Hz) to run the multi-scale discriminator. If empty, the multi-scale discriminator is not used.
- Type: list
fft_sizes
Window sizes of the FFT to run the multi-band discriminator. Default is [2048, 1024, 512].
- Type: list
sample_rate
Sampling rate of audio in Hz. Default is 44100.
- Type: int
periods
Periods (in samples) to run the multi-period discriminator. Default is [2, 3, 5, 7, 11].
- Type: list
period_discriminator_params
Parameters for the multi-period discriminator.
- Type: Dict[str, Any]
band_discriminator_params
Parameters for the multi-band discriminator.
- Type: Dict[str, Any]
Parameters:
- rates – List of sampling rates for the multi-scale discriminator.
- fft_sizes – List of FFT window sizes for the multi-band discriminator.
- sample_rate – Sampling rate of the audio.
- periods – List of periods for the multi-period discriminator.
- period_discriminator_params – Parameters for the multi-period discriminator.
- band_discriminator_params – Parameters for the multi-band discriminator.
######### Examples
>>> import torch
>>> discriminator = MultiScaleMultiPeriodMultiBandDiscriminator(
... rates=[1, 2, 4],
... fft_sizes=[2048, 1024],
... sample_rate=44100,
... )
>>> output = discriminator(torch.randn(1, 1, 44100)) # Random input
>>> print(len(output)) # Should show the number of outputs from discriminators
preprocess(y)
Preprocess the input audio by removing the DC offset and normalizing the volume.
This method performs two main operations on the input audio tensor y:
- It removes the DC offset by subtracting the mean of the audio signal.
- It peak-normalizes the audio so that its maximum absolute amplitude is 0.8.
The output is a tensor with the same shape as the input, where the audio is processed for better compatibility with the discriminators.
- Parameters: y (torch.Tensor) – Input audio tensor of shape (batch_size, num_samples).
- Returns: Preprocessed audio tensor of the same shape as input y.
- Return type: torch.Tensor
######### Examples
>>> import torch
>>> discriminator = MultiScaleMultiPeriodMultiBandDiscriminator()
>>> audio_input = torch.randn(4, 16000) # Example audio tensor
>>> preprocessed_audio = discriminator.preprocess(audio_input)
>>> print(preprocessed_audio.shape)
torch.Size([4, 16000])
NOTE
The normalization is performed to avoid clipping during processing.
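The two steps described for `preprocess` can be sketched without any framework, which makes the arithmetic easy to check by hand. The version below works on plain Python lists, one list per waveform. The small `eps` guarding against division by zero on silent input is an assumption, and `preprocess_sketch` is a hypothetical name; the real method operates on tensors.

```python
def preprocess_sketch(batch, peak=0.8, eps=1e-9):
    """Remove the DC offset, then peak-normalize each waveform to `peak`."""
    out = []
    for wave in batch:
        # Step 1: subtract the mean so the waveform is centered on zero.
        mean = sum(wave) / len(wave)
        centered = [s - mean for s in wave]
        # Step 2: scale so the largest absolute sample equals `peak`.
        # eps (assumed) avoids dividing by zero for an all-constant input.
        max_abs = max(abs(s) for s in centered)
        scale = peak / (max_abs + eps)
        out.append([s * scale for s in centered])
    return out


batch = [[0.5, 1.5, 2.5], [-1.0, 0.0, 4.0]]
processed = preprocess_sketch(batch)
for wave in processed:
    print(max(abs(s) for s in wave))  # close to 0.8 for each waveform
```

A tensor version would apply the same two operations along the last dimension, so the output keeps the input shape as described above.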