espnet2.gan_codec.shared.discriminator.msstft_discriminator.DiscriminatorSTFT
class espnet2.gan_codec.shared.discriminator.msstft_discriminator.DiscriminatorSTFT(filters: int, in_channels: int = 1, out_channels: int = 1, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, filters_scale: int = 1, kernel_size: Tuple[int, int] = (3, 9), dilations: List = [1, 2, 4], stride: Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2})
Bases: Module
STFT sub-discriminator for evaluating audio signals.
This class implements a short-time Fourier transform (STFT) based discriminator that processes audio input through multiple convolutional layers. It is designed to work as part of a generative adversarial network (GAN) setup, where it assesses the quality of generated audio signals.
filters
Number of filters in convolutions.
- Type: int
in_channels
Number of input channels.
- Type: int
out_channels
Number of output channels.
- Type: int
n_fft
Size of FFT for each scale.
- Type: int
hop_length
Length of hop between STFT windows for each scale.
- Type: int
win_length
Window size for each scale.
- Type: int
normalized
Whether to normalize by magnitude after STFT.
- Type: bool
activation
Activation function used in the convolutional layers.
- Type: callable
convs
List of convolutional layers for feature extraction.
- Type: ModuleList
conv_post
Final convolutional layer to output logits.
- Type: NormConv2d
Parameters:
- filters (int) – Number of filters in convolutions.
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- n_fft (int) – Size of FFT for each scale.
- hop_length (int) – Length of hop between STFT windows for each scale.
- kernel_size (tuple of int) – Inner Conv2d kernel sizes.
- stride (tuple of int) – Inner Conv2d strides.
- dilations (list of int) – Inner Conv2d dilation on the time dimension.
- win_length (int) – Window size for each scale.
- normalized (bool) – Whether to normalize by magnitude after STFT.
- norm (str) – Normalization method.
- activation (str) – Activation function.
- activation_params (dict) – Parameters to provide to the activation function.
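- filters_scale (int) – Growth factor for the filters (the docstring names this parameter `growth`, which does not appear in the signature).
- max_filters (int) – Upper limit on the number of filters in any convolution.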
Returns: A tuple containing the output logits and a list of feature maps from each layer.
Return type: Tuple[torch.Tensor, List[torch.Tensor]]
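As a rough illustration of how `filters`, `filters_scale`, `max_filters`, and `dilations` could interact, the sketch below computes a per-layer channel plan. It assumes an EnCodec-style layout: real and imaginary STFT parts stacked as channels, one input convolution, one convolution per dilation, one final convolution, and then `conv_post`. This page does not confirm that layout, and `channel_plan` is a hypothetical helper, not part of ESPnet.

```python
from typing import List, Sequence, Tuple


def channel_plan(
    filters: int,
    in_channels: int = 1,
    out_channels: int = 1,
    max_filters: int = 1024,
    filters_scale: int = 1,
    dilations: Sequence[int] = (1, 2, 4),
) -> Tuple[List[Tuple[int, int]], Tuple[int, int]]:
    """Return ((in, out) channels per conv in `convs`, (in, out) of `conv_post`).

    Assumption: widths grow as filters * filters_scale**k, capped at max_filters.
    """
    # Assumption: real and imaginary parts are concatenated on the channel axis.
    spec_channels = 2 * in_channels
    layers = [(spec_channels, filters)]
    in_chs = filters
    # One convolution per dilation value.
    for i, _ in enumerate(dilations):
        out_chs = min(filters_scale ** (i + 1) * filters, max_filters)
        layers.append((in_chs, out_chs))
        in_chs = out_chs
    # One last convolution before the logit projection.
    out_chs = min(filters_scale ** (len(dilations) + 1) * filters, max_filters)
    layers.append((in_chs, out_chs))
    return layers, (out_chs, out_channels)


layers, post = channel_plan(filters=32, filters_scale=2, max_filters=128)
print(layers)  # [(2, 32), (32, 64), (64, 128), (128, 128), (128, 128)]
print(post)    # (128, 1)
```

Under these assumptions the default `dilations=[1, 2, 4]` yields five convolutions in `convs`, so the returned feature-map list would hold five tensors.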
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.msstft_discriminator import DiscriminatorSTFT
>>> discriminator = DiscriminatorSTFT(filters=64, in_channels=1)
>>> input_tensor = torch.randn(1, 1, 16000)  # Example audio input: (batch, channels, samples)
>>> output, feature_maps = discriminator(input_tensor)
>>> print(output.shape)  # Output logits shape
torch.Size([1, 1, H, W]) # H and W depend on the input and config
NOTE
Ensure the input tensor is of the correct shape (batch_size, in_channels, audio_length) before passing to the forward method.
forward(x: Tensor)
Forward pass for the STFT discriminator.
This method processes the input tensor through the series of convolutional layers defined in the DiscriminatorSTFT class. It transforms the input using the Short-Time Fourier Transform (STFT), applies a series of convolutions, and collects feature maps at each stage. The output is a tuple containing the final output and the list of feature maps.
Parameters:
- x (torch.Tensor) – Input tensor of shape [B, C, T], where B is the batch size, C is the number of input channels, and T is the length of the input signal.
Returns: A tuple containing:
- torch.Tensor: The output tensor after processing through the convolutional layers.
- list: A list of feature maps collected at each layer.
Return type: tuple
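To get a feel for the time-frequency grid that the convolutions in `forward` operate on, the sketch below estimates the number of STFT frames and frequency bins for a signal of length T. It assumes a one-sided STFT computed without centre padding (`center=False`), which this page does not state; `stft_grid` is a hypothetical helper, not part of ESPnet.

```python
from typing import Tuple


def stft_grid(num_samples: int, n_fft: int = 1024, hop_length: int = 256) -> Tuple[int, int]:
    """Return (frames, freq_bins) for an STFT without centre padding.

    Assumption: each frame needs n_fft samples, so shorter inputs are rejected.
    """
    if num_samples < n_fft:
        raise ValueError(
            f"signal length {num_samples} is shorter than n_fft={n_fft}"
        )
    # The first frame covers n_fft samples; each further frame advances by hop_length.
    frames = 1 + (num_samples - n_fft) // hop_length
    # A one-sided spectrum keeps the non-negative frequencies only.
    freq_bins = n_fft // 2 + 1
    return frames, freq_bins


print(stft_grid(16000))  # (59, 513)
print(stft_grid(1024))   # (1, 513)
```

Under this assumption, an input shorter than `n_fft` samples produces no complete frame, which is one way an input shape can be incompatible with the configuration.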
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.msstft_discriminator import DiscriminatorSTFT
>>> discriminator = DiscriminatorSTFT(filters=64)
>>> input_tensor = torch.randn(8, 1, 1024)  # Batch of 8, 1 channel, 1024 samples
>>> output, feature_maps = discriminator(input_tensor)
>>> print(output.shape)  # Output shape depends on the architecture
>>> print(len(feature_maps))  # Number of feature maps collected
NOTE
Ensure that the input tensor is properly shaped and normalized as per the requirements of the STFT.
Raises:
- ValueError – If the input tensor shape is not compatible with the expected dimensions.