espnet2.gan_codec.shared.discriminator.msstft_discriminator.MultiScaleSTFTDiscriminator
class espnet2.gan_codec.shared.discriminator.msstft_discriminator.MultiScaleSTFTDiscriminator(filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False, n_ffts: List[int] = [1024, 2048, 512], hop_lengths: List[int] = [256, 512, 128], win_lengths: List[int] = [1024, 2048, 512], **kwargs)
Bases: MultiDiscriminator
Multi-Scale STFT (MS-STFT) discriminator.
This class implements a multi-scale Short-Time Fourier Transform (STFT) discriminator made up of several sub-discriminators. Each sub-discriminator operates on an STFT computed with its own FFT size, hop length, and window length. In Generative Adversarial Networks (GANs), the discriminator judges how realistic generated audio signals are.
sep_channels
If True, split the input channels into distinct samples (for stereo support).
- Type: bool
discriminators
List of STFT sub-discriminators, one per scale.
- Type: nn.ModuleList
Parameters:
- filters (int) – Number of filters in the convolutional layers.
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- sep_channels (bool) – If True, split the input channels into distinct samples for stereo support.
- n_ffts (Sequence[int]) – FFT size for each scale.
- hop_lengths (Sequence[int]) – Hop length between consecutive STFT frames for each scale.
- win_lengths (Sequence[int]) – Window length for each scale.
- **kwargs – Additional keyword arguments for STFTDiscriminator.
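Each scale takes the i-th entry of n_ffts, hop_lengths, and win_lengths, so the sub-discriminators see spectrograms with different time and frequency resolutions. The plain-Python sketch below estimates the spectrogram size at each scale. The helper `stft_shapes` is hypothetical, and it assumes an STFT without center padding and a one-sided spectrum, which may not match the padding the actual sub-discriminators use. win_lengths does not appear in the formula because implementations such as `torch.stft` zero-pad a shorter window to n_fft.

```python
# Hypothetical helper: estimate the spectrogram shape each scale sees.
# Assumes an STFT without center padding (frames = 1 + (T - n_fft) // hop)
# and a one-sided spectrum (n_fft // 2 + 1 frequency bins).
from typing import List, Tuple


def stft_shapes(
    num_samples: int,
    n_ffts: List[int],
    hop_lengths: List[int],
) -> List[Tuple[int, int]]:
    """Return (frequency_bins, frames) for every scale."""
    shapes = []
    for n_fft, hop in zip(n_ffts, hop_lengths):
        if num_samples < n_fft:
            raise ValueError(f"input shorter than n_fft={n_fft}")
        freq_bins = n_fft // 2 + 1
        frames = 1 + (num_samples - n_fft) // hop
        shapes.append((freq_bins, frames))
    return shapes


# Default scales from the class signature, applied to 16000 samples.
print(stft_shapes(16000, [1024, 2048, 512], [256, 512, 128]))
# → [(513, 59), (1025, 28), (257, 122)]
```

Larger n_fft values give finer frequency resolution but fewer frames, which is the trade-off the multiple scales are meant to cover.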
######### Examples
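>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.msstft_discriminator import MultiScaleSTFTDiscriminator
>>> discriminator = MultiScaleSTFTDiscriminator(filters=64)
>>> input_tensor = torch.randn(1, 1, 16000)  # (batch, channels, time)
>>> output = discriminator(input_tensor)
>>> print(len(output))  # Number of scales (sub-discriminators)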
####### NOTE
The input tensor is expected to have shape (batch_size, channels, time).
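The following sketch shows the shape bookkeeping implied by the note above and by sep_channels. `prepare_shape` is a hypothetical helper, and folding the channels into the batch dimension when sep_channels is True is an assumption based on the description "separate channels to distinct samples", not a copy of the library code.

```python
# Hypothetical shape check for the discriminator input.
# Assumption: with sep_channels=True each channel is folded into the batch,
# so a stereo batch of shape (B, 2, T) is treated as (2 * B, 1, T).
from typing import Tuple


def prepare_shape(
    shape: Tuple[int, ...], in_channels: int, sep_channels: bool
) -> Tuple[int, int, int]:
    """Validate a (batch, channels, time) shape and return the shape fed to each scale."""
    if len(shape) != 3:
        raise ValueError(f"expected (batch, channels, time), got {shape}")
    batch, channels, time = shape
    if sep_channels:
        # Every channel becomes its own mono sample.
        return (batch * channels, 1, time)
    if channels != in_channels:
        raise ValueError(f"expected {in_channels} channels, got {channels}")
    return (batch, channels, time)


print(prepare_shape((4, 2, 16000), in_channels=1, sep_channels=True))   # → (8, 1, 16000)
print(prepare_shape((4, 1, 16000), in_channels=1, sep_channels=False))  # → (4, 1, 16000)
```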
- Raises: AssertionError – If the lengths of n_ffts, hop_lengths, and win_lengths are not equal.
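The documented AssertionError comes from pairing the three lists scale by scale. A minimal sketch of that pairing, using a hypothetical `build_scale_configs` helper that returns plain dictionaries instead of sub-discriminator modules:

```python
# Hypothetical sketch of the per-scale configuration step.
# It mirrors the documented AssertionError: all three lists must have equal length.
from typing import Dict, List


def build_scale_configs(
    n_ffts: List[int], hop_lengths: List[int], win_lengths: List[int]
) -> List[Dict[str, int]]:
    """Pair the i-th FFT size, hop length and window length into one config per scale."""
    assert len(n_ffts) == len(hop_lengths) == len(win_lengths), (
        "n_ffts, hop_lengths and win_lengths must have the same length"
    )
    return [
        {"n_fft": n, "hop_length": h, "win_length": w}
        for n, h, w in zip(n_ffts, hop_lengths, win_lengths)
    ]


configs = build_scale_configs([1024, 2048, 512], [256, 512, 128], [1024, 2048, 512])
print(len(configs))  # → 3, one sub-discriminator per scale
```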
forward(x: Tensor)
Forward pass for the MultiScaleSTFTDiscriminator.
This method passes the input tensor x to every sub-discriminator. Each sub-discriminator computes the STFT of x at its own scale and returns its intermediate feature maps together with its final logit.
- Parameters: x (torch.Tensor) – Input tensor of shape [B, C, T], where B is the batch size, C is the number of channels, and T is the number of time samples.
- Returns: A list with one element per sub-discriminator. Each element holds that sub-discriminator's feature maps and its final logit.
- Return type: list
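Training code usually needs the feature maps and the logits separately. The sketch below assumes each per-scale element is a list whose last item is the logit and whose earlier items are feature maps; that ordering is an assumption, so check it against the installed version. `split_outputs` is hypothetical, and placeholder strings stand in for tensors.

```python
# Hypothetical helper for consuming the forward() output.
# Assumption: each inner list holds the feature maps first and the logit last.
from typing import Any, List, Tuple


def split_outputs(outputs: List[List[Any]]) -> Tuple[List[List[Any]], List[Any]]:
    """Separate per-scale feature maps from per-scale logits."""
    feature_maps = [scale_out[:-1] for scale_out in outputs]
    logits = [scale_out[-1] for scale_out in outputs]
    return feature_maps, logits


# Placeholder strings stand in for tensors: 3 scales, 2 feature maps each.
dummy = [["f0a", "f0b", "logit0"], ["f1a", "f1b", "logit1"], ["f2a", "f2b", "logit2"]]
fmaps, logits = split_outputs(dummy)
print(logits)    # → ['logit0', 'logit1', 'logit2']
print(fmaps[0])  # → ['f0a', 'f0b']
```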
######### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.msstft_discriminator import MultiScaleSTFTDiscriminator
>>> discriminator = MultiScaleSTFTDiscriminator(filters=64)
>>> input_tensor = torch.randn(8, 1, 16000)  # Batch of 8, 1 channel, 16000 samples
>>> outputs = discriminator(input_tensor)
>>> for output in outputs:
...     print(len(output))  # Feature maps plus the final logit of one sub-discriminator
####### NOTE
The input should be a floating-point tensor of shape [B, C, T] on the same device as the discriminator's parameters.
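One common way GAN audio models use the feature maps is a feature-matching loss: the mean absolute difference between the discriminator's feature maps for real and generated audio, averaged over layers and scales. The sketch below illustrates that idea with nested lists of floats in place of tensors; the L1 form, the averaging, and the helper names (`mean_abs_diff`, `feature_matching_loss`) are assumptions and not necessarily the loss ESPnet pairs with this discriminator.

```python
# Hypothetical feature-matching loss over multi-scale discriminator outputs.
# Plain lists of floats stand in for tensors; each scale is
# [feature_map_1, ..., feature_map_k, logit], with the logit last (assumption).
from typing import List

Scale = List[List[float]]  # one scale: feature maps followed by the logit


def mean_abs_diff(a: List[float], b: List[float]) -> float:
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def feature_matching_loss(real: List[Scale], fake: List[Scale]) -> float:
    """Average L1 distance between real and fake feature maps, skipping logits."""
    total, count = 0.0, 0
    for real_scale, fake_scale in zip(real, fake):
        for real_fmap, fake_fmap in zip(real_scale[:-1], fake_scale[:-1]):
            total += mean_abs_diff(real_fmap, fake_fmap)
            count += 1
    return total / count if count else 0.0


# 2 scales, each with 1 feature map followed by a logit.
real = [[[1.0, 2.0], [0.5]], [[0.0, 0.0], [1.0]]]
fake = [[[1.5, 2.5], [0.1]], [[1.0, -1.0], [0.9]]]
print(feature_matching_loss(real, fake))  # → 0.75
```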
property num_discriminators
Number of STFT sub-discriminators, i.e. the number of scales.
- Return type: int
######### Examples
>>> from espnet2.gan_codec.shared.discriminator.msstft_discriminator import MultiScaleSTFTDiscriminator
>>> discriminator = MultiScaleSTFTDiscriminator(filters=64)
>>> discriminator.num_discriminators  # One per entry in n_ffts
3
####### NOTE
The number of discriminators equals the common length of the n_ffts, hop_lengths, and win_lengths parameters.
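A minimal sketch of the relationship stated in the note, using a hypothetical `ToyMultiScale` class that stores placeholder tuples where the real class stores sub-discriminator modules:

```python
# Hypothetical toy class, not the ESPnet implementation: it only shows that the
# property counts sub-discriminators, one per (n_fft, hop_length, win_length) triple.
from typing import List, Tuple


class ToyMultiScale:
    def __init__(self, n_ffts: List[int], hop_lengths: List[int], win_lengths: List[int]):
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        # Placeholder tuples stand in for the real STFT sub-discriminator modules.
        self.discriminators: List[Tuple[int, int, int]] = list(
            zip(n_ffts, hop_lengths, win_lengths)
        )

    @property
    def num_discriminators(self) -> int:
        return len(self.discriminators)


toy = ToyMultiScale([1024, 2048, 512], [256, 512, 128], [1024, 2048, 512])
print(toy.num_discriminators)  # → 3
```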