espnet2.gan_codec.shared.discriminator.stft_discriminator.ComplexSTFTDiscriminator
class espnet2.gan_codec.shared.discriminator.stft_discriminator.ComplexSTFTDiscriminator(*, in_channels=1, channels=32, strides=[[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]], chan_mults=[1, 2, 4, 4, 8, 8], n_fft=1024, hop_length=256, win_length=1024, stft_normalized=False, logits_abs=True)
Bases: Module
ComplexSTFT Discriminator used in SoundStream.
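This class implements a complex Short-Time Fourier Transform (STFT) discriminator for use in SoundStream. The input waveform is transformed into a complex spectrogram, which passes through an initial complex convolution and then through a stack of residual units built from complex convolutional layers (one unit per entry in strides, six by default). Because the spectrogram stays complex-valued, the discriminator sees both magnitude and phase information in the frequency domain.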
Adapted from https://github.com/alibaba-damo-academy/FunCodec.git.
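The frequency-domain front end described above can be sketched with `torch.stft`. This is a minimal illustration, not the module's own code: the helper name `complex_spectrogram` is hypothetical, the Hann window is an assumption (this page does not say which window the module uses), and the extra channel axis is assumed so that the spectrogram can feed a 2-D complex convolution.

```python
import torch


def complex_spectrogram(
    wav: torch.Tensor,
    n_fft: int = 1024,
    hop_length: int = 256,
    win_length: int = 1024,
    normalized: bool = False,
) -> torch.Tensor:
    """Hypothetical helper: (B, 1, T) waveform -> (B, 1, F, frames) complex spectrogram.

    The Hann window is an assumption made for this sketch.
    """
    x = wav.squeeze(1)  # (B, 1, T) -> (B, T); assumes a single input channel
    spec = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=torch.hann_window(win_length, device=x.device),
        normalized=normalized,
        return_complex=True,
    )  # complex tensor of shape (B, n_fft // 2 + 1, frames)
    return spec.unsqueeze(1)  # add a channel axis for a 2-D complex convolution stack


spec = complex_spectrogram(torch.randn(2, 1, 16000))
print(spec.shape, spec.dtype)  # torch.Size([2, 1, 513, 63]) torch.complex64
```

With the default settings, 16000 samples give 513 frequency bins and 63 frames (center padding adds `n_fft // 2` samples on each side, so there are `16000 // 256 + 1` frames).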
init_conv
Initial complex convolutional layer.
- Type: ComplexConv2d
layers
List of complex STFT residual units.
- Type: nn.ModuleList
stft_normalized
Flag to indicate if STFT output is normalized.
- Type: bool
logits_abs
Flag that determines whether the absolute value (magnitude) of the complex output logits is returned.
- Type: bool
n_fft
FFT size used in STFT computation.
- Type: int
hop_length
Hop length for STFT.
- Type: int
win_length
Window length for STFT.
- Type: int
Parameters:
- in_channels (int) – Input channel (default: 1).
- channels (int) – Base number of channels, i.e. the output channels of the initial convolution (default: 32).
- strides (List[List[int]]) – Per-layer 2-D strides of the complex Conv2d modules, one pair per residual unit (default: [[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]]).
- chan_mults (List[int]) – Channel multipliers applied to channels for each residual unit (default: [1, 2, 4, 4, 8, 8]).
- n_fft (int) – n_fft in the STFT (default: 1024).
- hop_length (int) – hop_length in the STFT (default: 256).
- win_length (int) – win_length in the STFT (default: 1024).
- stft_normalized (bool) – Whether to normalize the STFT output (default: False).
- logits_abs (bool) – Whether to return the absolute value (magnitude) of the complex output logits (default: True).
Returns: None
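How `channels`, `chan_mults` and `strides` combine can be illustrated with plain Python. This is a sketch under stated assumptions: the hypothetical `layer_plan` helper assumes that residual unit i maps the previous width to `channels * chan_mults[i]` (starting from `channels`), and that the per-layer strides multiply into an overall downsampling factor along each of the two spectrogram axes.

```python
from typing import List, Sequence, Tuple


def layer_plan(
    channels: int = 32,
    chan_mults: Sequence[int] = (1, 2, 4, 4, 8, 8),
    strides: Sequence[Sequence[int]] = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
) -> Tuple[List[Tuple[int, int]], Tuple[int, int]]:
    """Hypothetical helper: per-unit (in, out) channels and the overall 2-D stride.

    Assumes one residual unit per entry in `strides`, with widths scaled by `chan_mults`.
    """
    if len(chan_mults) != len(strides):
        raise ValueError("chan_mults and strides must have the same length")
    widths = [channels] + [channels * m for m in chan_mults]
    pairs = list(zip(widths[:-1], widths[1:]))  # (in_channels, out_channels) per unit
    total = [1, 1]
    for s0, s1 in strides:  # strides compose multiplicatively along each axis
        total[0] *= s0
        total[1] *= s1
    return pairs, (total[0], total[1])


pairs, total_stride = layer_plan()
print(pairs)  # [(32, 32), (32, 64), (64, 128), (128, 128), (128, 256), (256, 256)]
print(total_stride)  # (8, 64)
```

Under these assumptions the defaults yield six residual units whose widths grow from 32 to 256 channels, with a total stride of 8 along the first axis and 64 along the second.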
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.stft_discriminator import ComplexSTFTDiscriminator
>>> discriminator = ComplexSTFTDiscriminator()
>>> input_signal = torch.randn(1, 1, 16000)  # (B, C, T)
>>> output = discriminator(input_signal)
>>> print(len(output))  # Output: 1
>>> print(output[0][0].shape)  # Shape of the discriminator output
NOTE
The architecture follows the STFT discriminator described in the SoundStream paper (https://arxiv.org/pdf/2107.03312.pdf).
Initialize the complex STFT discriminator used in SoundStream.
Adapted from https://github.com/alibaba-damo-academy/FunCodec.git.
- Parameters:
- in_channels (int) – Number of input channels.
- channels (int) – Base number of channels (output channels of the initial convolution).
- strides (List[List[int]]) – Per-layer 2-D strides of the complex Conv2d modules.
- chan_mults (List[int]) – Channel multipliers.
- n_fft (int) – n_fft in the STFT.
- hop_length (int) – hop_length in the STFT.
- win_length (int) – win_length in the STFT.
- stft_normalized (bool) – Whether to normalize the STFT output.
- logits_abs (bool) – Whether to return the absolute value (magnitude) of the complex output logits.
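The two boolean flags can be illustrated on toy tensors. `logits_abs=True` corresponds to taking the magnitude of complex values; the `torch.view_as_real` line is only a plausible alternative for `logits_abs=False` and is an assumption, not something this page confirms. For `stft_normalized`, PyTorch documents that `torch.stft(..., normalized=True)` scales the result by the inverse square root of the frame length (here `n_fft`), which the comparison below checks.

```python
import torch

# Toy complex "logits" standing in for a discriminator output (hypothetical values).
logits = torch.complex(torch.tensor([3.0, -1.0]), torch.tensor([4.0, 0.0]))

# logits_abs=True: keep only the magnitude |z|, giving a real-valued tensor.
magnitude = logits.abs()

# Assumed alternative for logits_abs=False: real and imaginary parts
# exposed as a trailing dimension of size 2.
real_imag = torch.view_as_real(logits)

# stft_normalized=True: torch.stft multiplies its output by n_fft ** -0.5.
n_fft = 1024
x = torch.randn(1, 4096)
window = torch.hann_window(n_fft)
plain = torch.stft(x, n_fft, hop_length=256, window=window, return_complex=True)
scaled = torch.stft(
    x, n_fft, hop_length=256, window=window, normalized=True, return_complex=True
)

print(magnitude, real_imag.shape)
print(torch.allclose(scaled, plain * n_fft ** -0.5, atol=1e-5))
```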
forward(x)
Compute the output of the complex Short-Time Fourier Transform (STFT) discriminator used in SoundStream for an input waveform x of shape (B, C, T). The waveform is transformed into a complex spectrogram and passed through the complex convolutional layers; the result is returned as a list of lists of tensors for use in adversarial training.
init_conv
Initial complex convolution layer.
- Type: ComplexConv2d
layers
List of complex STFT residual units.
- Type: ModuleList
stft_normalized
Flag indicating whether to normalize the STFT output.
- Type: bool
logits_abs
Flag indicating whether to return absolute logits.
- Type: bool
n_fft
FFT size for the STFT.
- Type: int
hop_length
Hop length for the STFT.
- Type: int
win_length
Window length for the STFT.
- Type: int
Parameters:
- in_channels (int) – Number of input channels (default: 1).
- channels (int) – Number of output channels (default: 32).
- strides (List[List[int]]) – Strides for the convolutional layers (default: [[1, 2], [2, 2], [1, 2], [2, 2], [1, 2], [2, 2]]).
- chan_mults (List[int]) – Channel multipliers for each layer (default: [1, 2, 4, 4, 8, 8]).
- n_fft (int) – FFT size for the STFT (default: 1024).
- hop_length (int) – Hop length for the STFT (default: 256).
- win_length (int) – Window length for the STFT (default: 1024).
- stft_normalized (bool) – Whether to normalize the STFT output (default: False).
- logits_abs (bool) – Whether to return absolute values of output logits (default: True).
Returns: A list of lists containing the discriminator output.
Return type: List[List[Tensor]]
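Because forward returns List[List[Tensor]], downstream code has to walk both levels of nesting. The sketch below assumes that the last tensor in each inner list holds the logits; the least-squares objective and the name `lsgan_d_loss` are illustrative choices, not ESPnet's loss API. A target of 1 for real and 0 for generated audio suits the default `logits_abs=True`, where the logits are non-negative magnitudes.

```python
from typing import List

import torch


def lsgan_d_loss(
    real_out: List[List[torch.Tensor]], fake_out: List[List[torch.Tensor]]
) -> torch.Tensor:
    """Hypothetical least-squares discriminator loss over nested outputs.

    Assumes the last tensor of each inner list is the logits tensor.
    """
    terms = []
    for real_group, fake_group in zip(real_out, fake_out):
        real_logits, fake_logits = real_group[-1], fake_group[-1]
        # Push real logits toward 1 and generated logits toward 0.
        terms.append(((1.0 - real_logits) ** 2).mean() + (fake_logits ** 2).mean())
    return torch.stack(terms).mean()  # average over discriminator outputs


real = [[torch.ones(1, 1, 8, 2)]]
fake = [[torch.zeros(1, 1, 8, 2)]]
print(lsgan_d_loss(real, fake))  # tensor(0.): a perfect discriminator
```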
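Examples
Example usage of ComplexSTFTDiscriminator:
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.stft_discriminator import ComplexSTFTDiscriminator
>>> discriminator = ComplexSTFTDiscriminator()
>>> input_signal = torch.randn(1, 1, 1024)  # batch size 1, 1 channel, 1024 time steps
>>> output = discriminator(input_signal)
>>> print(output)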
NOTE
This module is adapted from the implementation found at https://github.com/alibaba-damo-academy/FunCodec.git.
Reference:
- Paper: https://arxiv.org/pdf/2107.03312.pdf
- Implementation: https://github.com/alibaba-damo-academy/FunCodec.git