espnet2.gan_svs.avocodo.avocodo.AvocodoDiscriminator
class espnet2.gan_svs.avocodo.avocodo.AvocodoDiscriminator(combd: Dict[str, Any] = {'combd_d_d': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], 'combd_d_g': [[1, 4, 16, 64, 256, 1], [1, 4, 16, 64, 256, 1], [1, 4, 16, 64, 256, 1]], 'combd_d_k': [[7, 11, 11, 11, 11, 5], [11, 21, 21, 21, 21, 5], [15, 41, 41, 41, 41, 5]], 'combd_d_p': [[3, 5, 5, 5, 5, 2], [5, 10, 10, 10, 10, 2], [7, 20, 20, 20, 20, 2]], 'combd_d_s': [[1, 1, 4, 4, 4, 1], [1, 1, 4, 4, 4, 1], [1, 1, 4, 4, 4, 1]], 'combd_h_u': [[16, 64, 256, 1024, 1024, 1024], [16, 64, 256, 1024, 1024, 1024], [16, 64, 256, 1024, 1024, 1024]], 'combd_op_f': [1, 1, 1], 'combd_op_g': [1, 1, 1], 'combd_op_k': [3, 3, 3]}, sbd: Dict[str, Any] = {'pqmf_config': {'fsbd': [64, 256, 0.1, 9.0], 'sbd': [16, 256, 0.03, 10.0]}, 'sbd_band_ranges': [[0, 6], [0, 11], [0, 16], [0, 64]], 'sbd_dilations': [[[5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11]], [[3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7]], [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3], [1, 2, 3], [2, 3, 5], [2, 3, 5]]], 'sbd_filters': [[64, 128, 256, 256, 256], [64, 128, 256, 256, 256], [64, 128, 256, 256, 256], [32, 64, 128, 128, 128]], 'sbd_kernel_sizes': [[[7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7], [7, 7, 7]], [[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]], [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], [[5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5], [5, 5, 5]]], 'sbd_strides': [[1, 1, 3, 3, 1], [1, 1, 3, 3, 1], [1, 1, 3, 3, 1], [1, 1, 3, 3, 1]], 'sbd_transpose': [False, False, False, True], 'segment_size': 8192, 'use_sbd': True}, pqmf_config: Dict[str, Any] = {'lv1': [2, 256, 0.25, 10.0], 'lv2': [4, 192, 0.13, 10.0]}, projection_filters: List[int] = [0, 1, 1, 1])
Bases: Module
Avocodo Discriminator module.
This class implements the Avocodo Discriminator, which combines the Collaborative Multi-band Discriminator (CoMBD) and the Sub-band Discriminator (SBD) to improve discrimination in generative adversarial networks for singing voice synthesis. It takes real and generated waveforms and predicts whether each is real or fake.
pqmf_lv2
PQMF for level 2 processing.
- Type: PQMF
pqmf_lv1
PQMF for level 1 processing.
- Type: PQMF
combd
Collaborative Multi-band Discriminator instance.
- Type: CoMBD
sbd
Sub-band Discriminator instance.
- Type: SBD
projection_filters
Filters for the projection layers.
- Type: List[int]
Parameters:
- combd (Dict[str, Any]) – Configuration dictionary for CoMBD.
- sbd (Dict[str, Any]) – Configuration dictionary for SBD.
- pqmf_config (Dict[str, Any]) – Configuration dictionary for PQMF.
- projection_filters (List[int]) – List of projection filters for the output.
Examples
>>> import torch
>>> discriminator = AvocodoDiscriminator()
>>> real_signal = torch.randn(1, 1, 8192) # Batch of real signal
>>> fake_signal = torch.randn(1, 1, 8192) # Batch of fake signal
>>> outs_real, outs_fake, fmaps_real, fmaps_fake = discriminator(
... real_signal, fake_signal
... )
- Returns: Outputs containing real and fake predictions along with feature maps for both.
- Return type: List[List[torch.Tensor]]
NOTE
The discriminator uses spectral normalization if specified in the configuration.
forward(y: Tensor, y_hats: Tensor) → List[List[Tensor]]
Perform forward propagation through the Avocodo Discriminator.
This method computes the outputs of the discriminator given the real and generated signals. It processes the inputs through multiple layers and combines the outputs from different discriminators.
- Parameters:
- y (torch.Tensor) – The real signals of shape (B, C, T).
- y_hats (torch.Tensor) – The generated signals of shape (B, C, T).
- Returns: A list containing the outputs of the discriminators. Specifically, it returns:
- outs_real: List of output tensors for real signals.
- outs_fake: List of output tensors for generated signals.
- fmaps_real: List of feature maps for real signals.
- fmaps_fake: List of feature maps for generated signals.
- Return type: List[List[torch.Tensor]]
Examples
>>> import torch
>>> discriminator = AvocodoDiscriminator(...)
>>> real_signals = torch.randn(8, 1, 256) # Batch of real signals
>>> generated_signals = torch.randn(8, 1, 256) # Batch of generated signals
>>> outs_real, outs_fake, fmaps_real, fmaps_fake = discriminator(
...     real_signals, generated_signals
... )
NOTE
The forward method expects input tensors of shape (B, C, T), where B is the batch size, C is the number of channels, and T is the time dimension.
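The nested output lists returned by forward are typically consumed by an adversarial loss that sums over every sub-discriminator. The snippet below is a minimal sketch of that aggregation, assuming an LSGAN-style objective (common in the HiFi-GAN family that Avocodo belongs to); `lsgan_discriminator_loss` is a hypothetical helper, not an ESPnet function, and the stand-in tensors only mimic the structure of the outs_real/outs_fake lists.

```python
import torch


def lsgan_discriminator_loss(outs_real, outs_fake):
    """Hypothetical helper: sum least-squares GAN losses over the
    per-discriminator output lists (assumed LSGAN objective)."""
    loss = torch.tensor(0.0)
    for d_real, d_fake in zip(outs_real, outs_fake):
        loss = loss + torch.mean((d_real - 1.0) ** 2)  # push real outputs to 1
        loss = loss + torch.mean(d_fake**2)            # push fake outputs to 0
    return loss


# Stand-in tensors with the same nested-list structure as the outputs
# of AvocodoDiscriminator.forward (one entry per sub-discriminator).
outs_real = [torch.rand(2, 10) for _ in range(3)]
outs_fake = [torch.rand(2, 10) for _ in range(3)]
loss = lsgan_discriminator_loss(outs_real, outs_fake)
```

In actual training, outs_fake would be recomputed with detached generator outputs for the discriminator step, and a mirrored generator loss would push the fake outputs toward 1.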