espnet2.gan_codec.shared.discriminator.msmpmb_discriminator.MultiBandDiscriminator
class espnet2.gan_codec.shared.discriminator.msmpmb_discriminator.MultiBandDiscriminator(window_length: int, hop_factor: float = 0.25, sample_rate: int = 44100, bands: list = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], channel: int = 32)
Bases: Module
MultiBandDiscriminator is a complex multi-band spectrogram discriminator that
operates on audio signals to classify them based on their frequency bands.
window_length
Window length of the Short-Time Fourier Transform (STFT).
- Type: int
hop_factor
Hop factor of the STFT.
- Type: float
sample_rate
Sampling rate of audio in Hz.
- Type: int
bands
Frequency bands for the discriminator.
- Type: list
n_fft
Number of FFT points.
- Type: int
hop_length
Length of hop between STFT windows.
- Type: int
band_convs
List of convolutional layers for each frequency band.
- Type: ModuleList
conv_post
Final convolutional layer applied to the concatenated outputs.
- Type: Module
Parameters:
- window_length (int) – Window length of STFT.
- hop_factor (float , optional) – Hop factor of the STFT (the hop length is hop_factor * window_length); defaults to 0.25.
- sample_rate (int , optional) – Sampling rate of audio in Hz, defaults to 44100.
- bands (list , optional) – Bands to run discriminator over, defaults to [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)].
- channel (int , optional) – Number of channels in the convolutional layers, defaults to 32.
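The relationship between these parameters and the derived STFT settings can be sketched as follows. This is a minimal pure-Python illustration of the likely mapping (n_fft from window_length, hop_length from hop_factor, and fractional band boundaries to frequency-bin slices), not the actual espnet2 implementation:

```python
# Sketch (assumptions labeled): how the constructor parameters likely map to
# STFT settings and per-band frequency-bin ranges.

window_length = 1024
hop_factor = 0.25
bands = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]

n_fft = window_length                         # FFT size equals the window length
hop_length = int(window_length * hop_factor)  # 1024 * 0.25 = 256

# A spectrogram with n_fft points has n_fft // 2 + 1 frequency bins.
n_bins = n_fft // 2 + 1                       # 513

# Each (lo, hi) fraction pair selects a contiguous slice of frequency bins.
band_bins = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(hop_length)  # 256
print(band_bins)   # [(0, 51), (51, 128), (128, 256), (256, 384), (384, 513)]
```

Note that consecutive bands share boundaries, so together they cover every frequency bin exactly once.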
spectrogram(x)
Computes the spectrogram of the input audio and splits it into defined frequency bands.
forward(x)
Processes the input audio through the discriminator and returns feature maps.
######### Examples
>>> discriminator = MultiBandDiscriminator(window_length=1024)
>>> audio_input = torch.randn(1, 1, 44100) # Example audio tensor
>>> features = discriminator(audio_input)
####### NOTE This class is part of a larger GAN architecture and is used to distinguish real from generated audio based on multi-band spectrograms.
forward(x)
Forward pass for the MultiBandDiscriminator.
This method processes the input tensor x through the discriminator’s architecture, generating feature maps from multiple frequency bands. The input audio is first converted into a spectrogram, which is then split into defined frequency bands. Each band is processed through a series of convolutional layers, and the outputs are collected into a list of feature maps.
- Parameters:x (torch.Tensor) – Input tensor of shape (batch_size, 1, time) representing the audio signal.
- Returns: A list of feature maps, where each feature map corresponds to the output of the convolutional layers applied to the respective frequency band, followed by the final output of the post-processing convolution.
- Return type: List[torch.Tensor]
######### Examples
>>> discriminator = MultiBandDiscriminator(window_length=1024)
>>> input_audio = torch.randn(1, 1, 2048) # Example audio input
>>> output_feature_maps = discriminator.forward(input_audio)
>>> print(len(output_feature_maps)) # Number of bands + 1 for post-conv
####### NOTE The input audio should be a mono signal with a shape of (batch_size, 1, time).
- Raises:ValueError – If the input tensor does not have the correct shape.
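The data flow described above can be traced on tensor shapes alone. The sketch below is a shape-only stand-in (no torch, no real convolutions; the frame-count formula is an approximation) that shows why the output list has one feature map per band plus one for the post-processing convolution:

```python
# Shape-only sketch (assumption: x is (batch, 1, time)); traces the forward
# data flow described above without real convolutions.

def forward_shapes(batch, time, window_length=1024, hop_factor=0.25,
                   bands=((0.0, 0.1), (0.1, 0.25), (0.25, 0.5),
                          (0.5, 0.75), (0.75, 1.0))):
    n_fft = window_length
    hop = int(window_length * hop_factor)
    n_bins = n_fft // 2 + 1
    n_frames = time // hop + 1                # approximate STFT frame count
    fmaps = []
    for lo, hi in bands:
        bins = int(hi * n_bins) - int(lo * n_bins)
        # each band's conv stack yields a feature map over (frames, band bins)
        fmaps.append((batch, n_frames, bins))
    # final post-conv output over the concatenated bands
    fmaps.append((batch, n_frames, n_bins))
    return fmaps

maps = forward_shapes(batch=1, time=2048)
print(len(maps))  # 6 -> number of bands + 1 for the post conv
```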
spectrogram(x)
Compute the complex spectrogram of the input audio and split it into the defined frequency bands.

The input is transformed with an STFT (using n_fft and hop_length derived from window_length and hop_factor), and the resulting frequency bins are sliced according to the fractional ranges given in bands.
- Parameters: x (torch.Tensor) – Input tensor of shape (batch_size, 1, time) representing the audio signal.
- Returns: A list of spectrogram tensors, one per frequency band.
- Return type: List[torch.Tensor]
######### Examples
>>> discriminator = MultiBandDiscriminator(window_length=1024)
>>> input_signal = torch.randn(1, 1, 44100)  # Batch size 1, mono
>>> band_specs = discriminator.spectrogram(input_signal)
>>> len(band_specs)  # Should match the number of bands
####### NOTE This class relies on the PyTorch and torchaudio libraries for tensor operations and audio processing.
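The band split itself reduces to slicing the frequency axis at fractional boundaries. A pure-Python stand-in (a list of bin indices in place of a real complex spectrogram frame) makes the slicing and coverage explicit:

```python
# Sketch (assumption): splitting a spectrogram into bands amounts to slicing
# the frequency axis at fractional boundaries. Pure-Python stand-in.

n_bins = 513                              # n_fft // 2 + 1 for n_fft = 1024
bands = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]

# stand-in spectrogram frame: 513 frequency bins (their indices as values)
frame = list(range(n_bins))
band_slices = [frame[int(lo * n_bins):int(hi * n_bins)] for lo, hi in bands]

print([len(s) for s in band_slices])              # [51, 77, 128, 128, 129]
print(sum(len(s) for s in band_slices) == n_bins) # True: bands cover all bins
```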