espnet2.gan_codec.encodec.encodec.EncodecDiscriminator
class espnet2.gan_codec.encodec.encodec.EncodecDiscriminator(msstft_discriminator_params: Dict[str, Any] = {'activation': 'LeakyReLU', 'activation_params': {'negative_slope': 0.3}, 'filters': 32, 'hop_lengths': [256, 512, 128, 64, 32], 'in_channels': 1, 'n_fft': [1024, 2048, 512, 256, 128], 'norm': 'weight_norm', 'out_channels': 1, 'win_lengths': [1024, 2048, 512, 256, 128]})
Bases: Module
Encodec Discriminator with only Multi-Scale STFT discriminator module.
This class implements the Encodec Discriminator, which utilizes a Multi-Scale Short-Time Fourier Transform (STFT) for analyzing the input signals. It is designed to work in conjunction with the Encodec model for adversarial training.
msstft
The Multi-Scale STFT discriminator module.
Parameters: msstft_discriminator_params (Dict[str, Any]) –
A dictionary of parameters for initializing the Multi-Scale STFT discriminator, with the following keys:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- filters (int): Number of filters in the convolutions.
- norm (str): Normalization choice for the convolutional layers.
- n_ffts (Sequence[int]): FFT size for each scale.
- hop_lengths (Sequence[int]): Hop length between STFT windows for each scale.
- win_lengths (Sequence[int]): Window size for each scale.
- activation (str): Activation function choice for the convolutional layers.
- activation_params (Dict[str, Any]): Parameters for the activation function.
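The three per-scale lists (n_ffts, hop_lengths, win_lengths) must line up, one entry per STFT scale. The sketch below is a hypothetical sanity check, `check_msstft_params`, which is not part of ESPnet; it only illustrates the consistency rules implied by the parameter list (equal lengths, a window no longer than its FFT size, a positive hop). Key names follow the parameter list above; check them against the installed ESPnet version before relying on them.

```python
from typing import Any, Dict


def check_msstft_params(params: Dict[str, Any]) -> int:
    """Hypothetical sanity check for msstft_discriminator_params.

    Returns the number of STFT scales, or raises ValueError if the
    per-scale settings are inconsistent.
    """
    n_ffts = params["n_ffts"]
    hop_lengths = params["hop_lengths"]
    win_lengths = params["win_lengths"]
    # Each list must provide exactly one entry per scale.
    if not (len(n_ffts) == len(hop_lengths) == len(win_lengths)):
        raise ValueError(
            "n_ffts, hop_lengths and win_lengths must have the same length, "
            f"got {len(n_ffts)}, {len(hop_lengths)}, {len(win_lengths)}"
        )
    for n_fft, hop, win in zip(n_ffts, hop_lengths, win_lengths):
        # A window longer than the FFT size cannot be applied to a frame.
        if win > n_fft:
            raise ValueError(f"win_length {win} exceeds n_fft {n_fft}")
        if hop <= 0:
            raise ValueError(f"hop_length must be positive, got {hop}")
    return len(n_ffts)


# Default values from the class signature.
params = {
    "in_channels": 1,
    "out_channels": 1,
    "filters": 32,
    "norm": "weight_norm",
    "n_ffts": [1024, 2048, 512, 256, 128],
    "hop_lengths": [256, 512, 128, 64, 32],
    "win_lengths": [1024, 2048, 512, 256, 128],
    "activation": "LeakyReLU",
    "activation_params": {"negative_slope": 0.3},
}
print(check_msstft_params(params))  # 5
```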
####### Examples
>>> import torch
>>> from espnet2.gan_codec.encodec.encodec import EncodecDiscriminator
>>> discriminator = EncodecDiscriminator()
>>> input_tensor = torch.randn(8, 1, 16000)  # Batch size of 8, 1 channel, 16000 samples
>>> outputs = discriminator(input_tensor)
>>> print(len(outputs))  # Number of scales
>>> print(len(outputs[0]))  # Number of outputs for the first scale
- Returns: A list of lists of discriminator outputs, where each inner list holds the output tensors of every layer. Only one discriminator is used here, but the output is still structured as a list of lists for consistency.
- Return type: List[List[Tensor]]
- Raises: ValueError – If any of the parameters are invalid during initialization.
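The return value is a list of lists of per-layer tensors. The pure-Python sketch below shows one common way such an output can be consumed in adversarial training. It assumes that the last tensor in each inner list is the final discriminator output and the preceding ones are intermediate feature maps; plain lists of floats stand in for tensors, and the helper names (`split_outputs`, `lsgan_generator_loss`, `feature_matching_loss`) are hypothetical, not ESPnet APIs. The formulas are the standard least-squares GAN generator loss and an L1 feature-matching loss, not necessarily the exact normalization ESPnet uses.

```python
from typing import List, Sequence, Tuple

# Plain lists of floats stand in for tensors in this illustration.
Values = List[float]
ScaleOutputs = Sequence[Values]


def split_outputs(
    outputs: Sequence[ScaleOutputs],
) -> Tuple[List[List[Values]], List[Values]]:
    """Split each per-scale list into (feature maps, final output).

    Assumption: the last element of each inner list is the final output.
    """
    feats = [list(scale[:-1]) for scale in outputs]
    finals = [scale[-1] for scale in outputs]
    return feats, finals


def mean(xs: Sequence[float]) -> float:
    return sum(xs) / len(xs)


def lsgan_generator_loss(fake_outputs: Sequence[ScaleOutputs]) -> float:
    """Least-squares generator loss mean((1 - D(fake))^2), averaged over scales."""
    _, finals = split_outputs(fake_outputs)
    return mean([mean([(1.0 - v) ** 2 for v in out]) for out in finals])


def feature_matching_loss(
    real_outputs: Sequence[ScaleOutputs], fake_outputs: Sequence[ScaleOutputs]
) -> float:
    """Mean L1 distance between real and fake feature maps, per layer then per scale."""
    real_feats, _ = split_outputs(real_outputs)
    fake_feats, _ = split_outputs(fake_outputs)
    per_scale = []
    for real_layers, fake_layers in zip(real_feats, fake_feats):
        per_layer = [
            mean([abs(r - f) for r, f in zip(real_map, fake_map)])
            for real_map, fake_map in zip(real_layers, fake_layers)
        ]
        per_scale.append(mean(per_layer))
    return mean(per_scale)


# Two scales, each with two feature maps followed by a final output.
real = [[[1.0, 2.0], [0.5], [0.9, 1.1]], [[0.0, 0.0], [1.0], [1.0]]]
fake = [[[1.0, 1.0], [0.0], [0.5, 0.5]], [[1.0, 0.0], [1.0], [0.0]]]
print(lsgan_generator_loss(fake))  # 0.625
print(feature_matching_loss(real, fake))  # 0.375
```

Keeping the layer outputs (not just the final score) is what makes a feature-matching term possible, which is one reason the output is returned as nested lists rather than a single tensor.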
Initialize Encodec Discriminator module.
Args: msstft_discriminator_params (Dict[str, Any]) with the following keys:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- filters (int): Number of filters in convolutions.
- norm (str): Normalization choice of convolutional layers.
- n_ffts (Sequence[int]): Size of FFT for each scale.
- hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
- win_lengths (Sequence[int]): Window size for each scale.
- activation (str): Activation function choice of convolutional layer.
- activation_params (Dict[str, Any]): Parameters for activation function.
forward(x: Tensor) → List[List[Tensor]]
Run the input audio through the Multi-Scale STFT discriminator (msstft), which evaluates the signal at every STFT scale. During adversarial training, the discriminator is used to distinguish real audio from generated audio.
Parameters: x (Tensor) – Input audio signal of shape (B, in_channels, T), i.e. (B, 1, T) with the default parameters.
Returns: A list of lists of discriminator outputs, where each inner list holds the output tensors of every layer.
Return type: List[List[Tensor]]
####### Examples
>>> import torch
>>> from espnet2.gan_codec.encodec.encodec import EncodecDiscriminator
>>> discriminator = EncodecDiscriminator()
>>> input_signal = torch.randn(8, 1, 16000)  # Batch of 8, 1 channel, 16000 samples
>>> outputs = discriminator(input_signal)
>>> print(len(outputs))  # Number of scales
>>> print(len(outputs[0]))  # Number of layers in the first scale
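Each scale sees a spectrogram with a different time/frequency trade-off. The pure-Python sketch below, a hypothetical helper named `stft_shapes` that needs neither ESPnet nor torch, computes the one-sided spectrogram size per scale for a given input length. It assumes by default that no centering padding is applied (center=False); whether ESPnet's module pads the input is an assumption to check against its source. Under that assumption, an input shorter than the largest n_fft (2048 by default) cannot be framed at all, which is why the example above uses 16000 samples.

```python
from typing import List, Sequence, Tuple


def stft_shapes(
    num_samples: int,
    n_ffts: Sequence[int],
    hop_lengths: Sequence[int],
    center: bool = False,
) -> List[Tuple[int, int]]:
    """Return (frequency_bins, frames) of a one-sided STFT at each scale."""
    shapes = []
    for n_fft, hop in zip(n_ffts, hop_lengths):
        # A one-sided (real-input) STFT keeps n_fft // 2 + 1 frequency bins.
        freq_bins = n_fft // 2 + 1
        # With center=True the signal is padded by n_fft // 2 on both sides.
        length = num_samples + 2 * (n_fft // 2) if center else num_samples
        if length < n_fft:
            raise ValueError(
                f"{num_samples} samples are too short for n_fft={n_fft}"
            )
        frames = 1 + (length - n_fft) // hop
        shapes.append((freq_bins, frames))
    return shapes


# Default scales from the signature; 16000 samples as in the example above.
print(stft_shapes(16000, [1024, 2048, 512, 256, 128], [256, 512, 128, 64, 32]))
# [(513, 59), (1025, 28), (257, 122), (129, 247), (65, 497)]
```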