espnet2.gan_codec.shared.discriminator.msstft_discriminator.MultiDiscriminator
class espnet2.gan_codec.shared.discriminator.msstft_discriminator.MultiDiscriminator
Bases: ABC, Module
Base implementation for discriminators composed of sub-discriminators acting at different scales.
This class serves as a base for multi-scale discriminators that use several sub-discriminators to analyze the input signal at different scales (for example, different STFT resolutions in MultiScaleSTFTDiscriminator). It defines the interface that derived classes must implement: the forward method and the num_discriminators property.
- Parameters: None
- Raises:
- TypeError – On instantiation, if a derived class does not implement the forward method (raised by Python's abc machinery).
- TypeError – On instantiation, if a derived class does not implement the num_discriminators property.
######### Examples
To create a custom multi-discriminator, derive from this class and implement the forward method and the num_discriminators property:
```python
import torch

from espnet2.gan_codec.shared.discriminator.msstft_discriminator import MultiDiscriminator


class CustomMultiDiscriminator(MultiDiscriminator):
    def forward(self, x: torch.Tensor):
        # Custom forward implementation
        pass

    @property
    def num_discriminators(self) -> int:
        return 2  # Example number of discriminators
```
####### NOTE This is an abstract class and cannot be instantiated directly.
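The abstract contract can be illustrated with the standard abc module alone. The sketch below is a minimal, torch-free stand-in: MultiDiscriminatorContract, Incomplete, Complete, and can_instantiate are hypothetical names, and the real class additionally inherits from torch.nn.Module, which is omitted here on the assumption that it does not change how abstract members are enforced. It shows that Python rejects an incomplete subclass with TypeError at construction time, before forward is ever called.

```python
from abc import ABC, abstractmethod
from typing import List


class MultiDiscriminatorContract(ABC):
    """Hypothetical stand-in for the abstract interface (no torch dependency)."""

    @abstractmethod
    def forward(self, x) -> List[list]:
        """Return one [feature maps..., logits] entry per sub-discriminator."""

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        """Number of sub-discriminators."""


class Incomplete(MultiDiscriminatorContract):
    # Implements forward but forgets num_discriminators, so it stays abstract.
    def forward(self, x):
        return [[x]]


class Complete(MultiDiscriminatorContract):
    # Implements both abstract members, so it can be instantiated.
    def forward(self, x):
        return [[x], [x]]

    @property
    def num_discriminators(self) -> int:
        return 2


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if abc rejects it with TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(Incomplete), can_instantiate(Complete))
```

Because num_discriminators is declared with @property stacked on @abstractmethod, a subclass must also override it as a property; it is then read as an attribute, without parentheses.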
abstract forward(x: Tensor)
Forward pass through all sub-discriminators.
Derived classes such as MultiScaleSTFTDiscriminator implement this method by passing the input tensor x through each sub-discriminator, producing outputs that capture different frequency and time scales. Each sub-discriminator computes its own feature maps and final logits.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, C, T), where:
- B: Batch size
- C: Number of input channels
- T: Length of the input sequence
- Returns: A list containing feature maps and logits from each sub-discriminator. Each entry in the list corresponds to a sub-discriminator and contains:
- Feature maps from each layer of the sub-discriminator
- Final logits from the last layer of the sub-discriminator
- Return type: List[List[torch.Tensor]]
######### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.msstft_discriminator import MultiScaleSTFTDiscriminator
>>> discriminator = MultiScaleSTFTDiscriminator(filters=64)
>>> input_tensor = torch.randn(8, 1, 16000)  # Batch of 8, 1 channel, 16000 samples
>>> output = discriminator(input_tensor)
>>> print(len(output))  # Number of sub-discriminators
>>> print(len(output[0]))  # Number of feature maps + 1 for the logits
####### NOTE The input tensor must already have shape (B, C, T) before it is passed to this method. Each sub-discriminator processes the input independently, and their outputs are collected into a list.
- Raises:
- RuntimeError – If the input tensor does not have the correct number of dimensions or if it contains invalid values.
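The return convention described above (one list per sub-discriminator, feature maps first and logits last) can be made concrete with a toy implementation. This is a minimal sketch, not ESPnet code: MultiDiscriminatorSketch, ToySubDiscriminator, and ToyMultiDiscriminator are hypothetical names, and the strided 1-D convolutions are placeholders assumed purely for illustration, not the STFT-based sub-discriminators that MultiScaleSTFTDiscriminator uses.

```python
from abc import ABC, abstractmethod
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiDiscriminatorSketch(ABC, nn.Module):
    """Hypothetical stand-in mirroring the documented interface."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: ...

    @property
    @abstractmethod
    def num_discriminators(self) -> int: ...


class ToySubDiscriminator(nn.Module):
    """Hypothetical sub-discriminator: two strided conv layers plus a logit head."""

    def __init__(self, in_channels: int = 1, hidden: int = 8, stride: int = 1):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv1d(in_channels, hidden, kernel_size=5, stride=stride, padding=2),
            nn.Conv1d(hidden, hidden, kernel_size=5, stride=stride, padding=2),
        ])
        self.out = nn.Conv1d(hidden, 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outs = []
        for layer in self.layers:
            x = F.leaky_relu(layer(x), 0.2)
            outs.append(x)  # intermediate feature map
        outs.append(self.out(x))  # final logits are stored last
        return outs


class ToyMultiDiscriminator(MultiDiscriminatorSketch):
    """Each sub-discriminator uses a different stride, i.e. a different time scale."""

    def __init__(self, strides=(1, 2, 4)):
        super().__init__()
        self.discriminators = nn.ModuleList(ToySubDiscriminator(stride=s) for s in strides)

    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
        # One entry per sub-discriminator: [fmap_1, ..., fmap_k, logits]
        return [disc(x) for disc in self.discriminators]

    @property
    def num_discriminators(self) -> int:
        return len(self.discriminators)


disc = ToyMultiDiscriminator()
outs = disc(torch.randn(2, 1, 1024))  # (B, C, T)
print(len(outs), [o[-1].shape for o in outs])
```

Keeping the logits as the last element lets a caller take entry[-1] for an adversarial loss and entry[:-1] for a feature-matching loss without knowing how many layers each sub-discriminator has.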
abstract property num_discriminators : int
Number of sub-discriminators that make up this multi-discriminator.
Derived classes must implement this property. It should equal the number of entries in the list returned by forward, since each entry corresponds to one sub-discriminator.
- Returns: The number of sub-discriminators.
- Return type: int
######### Examples
To create a derived class from MultiDiscriminator, implement the forward method and the num_discriminators property:
```python
import torch

from espnet2.gan_codec.shared.discriminator.msstft_discriminator import MultiDiscriminator


class MyDiscriminator(MultiDiscriminator):
    def forward(self, x: torch.Tensor):
        # Custom forward implementation
        pass

    @property
    def num_discriminators(self) -> int:
        return 3  # Example number of discriminators
```
####### NOTE This class cannot be instantiated directly, as it is intended to be a base class for specific multi-discriminator implementations.
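Assuming num_discriminators is meant to equal the number of entries forward returns (one per sub-discriminator, each ending with its logits), a caller can use it to sanity-check the output before computing losses. The helper below is hypothetical, not part of ESPnet; it works on any nested sequence, so it is shown with plain strings standing in for tensors.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_discriminator_outputs(
    outs: Sequence[Sequence[T]], num_discriminators: int
) -> Tuple[List[List[T]], List[T]]:
    """Separate feature maps from logits (hypothetical helper).

    Assumes each entry is [fmap_1, ..., fmap_k, logits], i.e. logits last.
    """
    if len(outs) != num_discriminators:
        raise ValueError(
            f"expected {num_discriminators} sub-discriminator outputs, got {len(outs)}"
        )
    if any(len(entry) == 0 for entry in outs):
        raise ValueError("each sub-discriminator output must contain at least the logits")
    feature_maps = [list(entry[:-1]) for entry in outs]  # everything but the last item
    logits = [entry[-1] for entry in outs]  # last item of each entry
    return feature_maps, logits


fmaps, logits = split_discriminator_outputs([["f1", "f2", "l1"], ["g1", "l2"]], 2)
print(fmaps, logits)
```

With a real discriminator this would be called as split_discriminator_outputs(disc(x), disc.num_discriminators); note that num_discriminators is a property and is read without parentheses.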