espnet2.gan_tts.hifigan.loss.FeatureMatchLoss
class espnet2.gan_tts.hifigan.loss.FeatureMatchLoss(average_by_layers: bool = True, average_by_discriminators: bool = True, include_final_outputs: bool = False)
Bases: Module
Feature matching loss module.
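This module calculates the feature matching loss used in generative adversarial networks (GANs) such as HiFi-GAN. It measures the L1 distance between the intermediate feature maps that each discriminator produces for the generated output and for the ground truth, which encourages the generator to produce outputs whose discriminator features match those of the ground truth. The loss can be averaged over layers and over discriminators, and it can optionally include each discriminator's final output in the calculation.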
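As a worked formula, and assuming the per-layer distance is the mean absolute (L1) error as in HiFi-GAN, the loss with the default flags can be written as follows. The notation is ours, not ESPnet's: $K$ is the number of discriminators, $D_k^{(l)}(\cdot)$ is the $l$-th of the $L_k$ outputs of discriminator $k$ (the last one being its final output), $N_{k,l}$ is the number of elements in that output, and $y$ and $\hat{y}$ are the ground-truth and generated signals.

```latex
\mathcal{L}_{\mathrm{FM}}(y, \hat{y})
  = \frac{1}{K} \sum_{k=1}^{K}
    \frac{1}{L_k - 1} \sum_{l=1}^{L_k - 1}
    \frac{1}{N_{k,l}}
    \left\lVert D_k^{(l)}(y) - D_k^{(l)}(\hat{y}) \right\rVert_1
```

Setting average_by_layers=False drops the $1/(L_k - 1)$ factor, setting average_by_discriminators=False drops the $1/K$ factor, and include_final_outputs=True extends the inner sum (and its averaging count) up to $l = L_k$.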
average_by_layers
Whether to average the loss by the number of layers.
- Type: bool
average_by_discriminators
Whether to average the loss by the number of discriminators.
- Type: bool
include_final_outputs
Whether to include the final output of each discriminator for loss calculation.
- Type: bool
Parameters:
- average_by_layers (bool) – Whether to average the loss by the number of layers.
- average_by_discriminators (bool) – Whether to average the loss by the number of discriminators.
- include_final_outputs (bool) – Whether to include the final output of each discriminator for loss calculation.
Returns: Feature matching loss value.
Return type: Tensor
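To make the effect of the three flags concrete, here is a framework-free sketch in plain Python that mirrors the documented behavior, using lists of floats in place of tensors. It is not ESPnet's implementation: the helper names `l1_mean` and `feature_match_loss` are hypothetical, each feature map is assumed to be already flattened, and the per-layer distance is assumed to be the mean absolute error.

```python
from typing import List, Sequence

# Hypothetical, framework-free sketch of the documented behavior; not ESPnet's API.
# Assumption: the per-layer distance is the mean absolute (L1) error.


def l1_mean(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute error between two equally sized, flattened feature maps."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def feature_match_loss(
    feats_hat: List[List[List[float]]],
    feats: List[List[List[float]]],
    average_by_layers: bool = True,
    average_by_discriminators: bool = True,
    include_final_outputs: bool = False,
) -> float:
    """feats_hat[k][l] / feats[k][l] is the l-th output of discriminator k for
    the generated / ground-truth signal; the last entry is its final output."""
    total = 0.0
    for disc_hat, disc in zip(feats_hat, feats):
        if not include_final_outputs:
            # Keep only the intermediate feature maps of this discriminator.
            disc_hat, disc = disc_hat[:-1], disc[:-1]
        disc_loss = sum(l1_mean(h, t) for h, t in zip(disc_hat, disc))
        if average_by_layers:
            disc_loss /= len(disc_hat)  # mean over the layers that were used
        total += disc_loss
    if average_by_discriminators:
        total /= len(feats_hat)  # mean over discriminators
    return total


# Two discriminators, each with two intermediate feature maps plus a final output.
feats_hat = [[[1.0, 2.0], [0.0, 0.0], [9.0]], [[1.0, 1.0], [2.0, 2.0], [9.0]]]
feats = [[[1.0, 0.0], [1.0, 1.0], [0.0]], [[1.0, 1.0], [0.0, 0.0], [0.0]]]

print(feature_match_loss(feats_hat, feats))  # 1.0
print(feature_match_loss(feats_hat, feats,
                         average_by_layers=False,
                         average_by_discriminators=False))  # 4.0
```

Plain lists are used only to keep the sketch dependency-free; in the module itself the same nested structure holds Tensor feature maps.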
Examples
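>>> import torch
>>> from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss
>>> feature_match_loss = FeatureMatchLoss()
>>> # Outer list: discriminators; inner list: that discriminator's outputs.
>>> # By default the last entry of each inner list (the final output) is skipped.
>>> feats_hat = [[torch.randn(2, 16, 100) for _ in range(3)] for _ in range(2)]
>>> feats = [[torch.randn(2, 16, 100) for _ in range(3)] for _ in range(2)]
>>> loss = feature_match_loss(feats_hat, feats)
>>> print(loss)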
forward(feats_hat: List[List[Tensor]] | List[Tensor], feats: List[List[Tensor]] | List[Tensor]) → Tensor
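Calculate feature matching loss.
- Parameters:
- feats_hat (Union[List[List[Tensor]], List[Tensor]]) – List of lists of discriminator outputs, or list of discriminator outputs, calculated from the generator's outputs.
- feats (Union[List[List[Tensor]], List[Tensor]]) – List of lists of discriminator outputs, or list of discriminator outputs, calculated from the ground truth.
- Returns: Feature matching loss value.
- Return type: Tensor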