espnet2.gan_tts.hifigan.loss.DiscriminatorAdversarialLoss
class espnet2.gan_tts.hifigan.loss.DiscriminatorAdversarialLoss(average_by_discriminators: bool = True, loss_type: str = 'mse')
Bases: Module
Discriminator adversarial loss module.
This module computes the adversarial loss for the discriminator in a GAN setup. It can handle both “mse” and “hinge” loss types and allows for averaging across multiple discriminators.
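As a rough illustration of the two loss types, the sketch below writes out the least-squares ("mse") and hinge objectives as they are commonly defined for GAN discriminators: scores for real samples are pushed toward 1 and scores for generated samples toward 0 (MSE), or past the +1 / -1 margins (hinge). It uses plain Python lists of floats instead of tensors so that it runs without PyTorch, and the function names are hypothetical. It is assumed, not confirmed by this page, that the module follows the same formulas internally.

```python
# Sketch only: common definitions of the two discriminator objectives,
# written with plain floats. Function names are hypothetical and are not
# part of the ESPnet API.

def mse_real_loss(scores):
    """Mean squared distance of real-sample scores from the target 1."""
    return sum((s - 1.0) ** 2 for s in scores) / len(scores)


def mse_fake_loss(scores):
    """Mean squared distance of generated-sample scores from the target 0."""
    return sum(s ** 2 for s in scores) / len(scores)


def hinge_real_loss(scores):
    """Penalize real-sample scores that fall below the +1 margin."""
    return sum(max(0.0, 1.0 - s) for s in scores) / len(scores)


def hinge_fake_loss(scores):
    """Penalize generated-sample scores that rise above the -1 margin."""
    return sum(max(0.0, 1.0 + s) for s in scores) / len(scores)


real_scores = [0.9, 0.8]  # discriminator scores for ground-truth audio
fake_scores = [0.2, 0.1]  # discriminator scores for generated audio

print(mse_real_loss(real_scores))    # approximately 0.025
print(mse_fake_loss(fake_scores))    # approximately 0.025
print(hinge_real_loss(real_scores))  # approximately 0.15
print(hinge_fake_loss(fake_scores))  # approximately 1.15
```

With either loss type, a discriminator that scores real samples high and generated samples low drives both terms toward zero.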
average_by_discriminators
Whether to average the loss by the number of discriminators.
- Type: bool
loss_type
Loss type, either “mse” or “hinge”.
- Type: str
Parameters:
- average_by_discriminators (bool) – Whether to average the loss by the number of discriminators.
- loss_type (str) – Loss type, “mse” or “hinge”.
Returns: Discriminator real loss value and discriminator fake loss value.
Return type: Tuple[torch.Tensor, torch.Tensor]
####### Examples
>>> import torch
>>> from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss
>>> discriminator_loss = DiscriminatorAdversarialLoss(loss_type='hinge')
>>> outputs_hat = [torch.tensor([0.9, 0.1]), torch.tensor([0.8, 0.2])]
>>> outputs = [torch.tensor([1.0, 0.0]), torch.tensor([1.0, 0.0])]
>>> real_loss, fake_loss = discriminator_loss(outputs_hat, outputs)
- Raises: AssertionError – If the provided loss_type is not “mse” or “hinge”.
Initialize DiscriminatorAdversarialLoss module.
- Parameters:
- average_by_discriminators (bool) – Whether to average the loss by the number of discriminators.
- loss_type (str) – Loss type, “mse” or “hinge”.
forward(outputs_hat: List[List[Tensor]] | List[Tensor] | Tensor, outputs: List[List[Tensor]] | List[Tensor] | Tensor) → Tuple[Tensor, Tensor]
Calculate discriminator adversarial loss.
This method computes the discriminator's adversarial loss from two sets of discriminator outputs: those calculated from generated samples (outputs_hat) and those calculated from ground-truth samples (outputs). The real loss and the fake loss are computed separately, using the configured loss type (MSE or hinge), and returned as a pair.
- Parameters:
- outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]) – Discriminator outputs, list of discriminator outputs, or list of list of discriminator outputs calculated from the generator.
- outputs (Union[List[List[Tensor]], List[Tensor], Tensor]) – Discriminator outputs, list of discriminator outputs, or list of list of discriminator outputs calculated from the ground truth.
- Returns: A tuple containing the discriminator real loss value and the discriminator fake loss value.
- Return type: Tuple[torch.Tensor, torch.Tensor]
####### Examples
>>> import torch
>>> from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss
>>> discriminator_loss = DiscriminatorAdversarialLoss()
>>> real_outputs = [torch.tensor([[0.9], [0.8]])]
>>> fake_outputs = [torch.tensor([[0.2], [0.1]])]
>>> real_loss, fake_loss = discriminator_loss(fake_outputs, real_outputs)
>>> print(real_loss, fake_loss)
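To make the effect of average_by_discriminators concrete, the following sketch accumulates a per-discriminator MSE loss over a list of discriminators and optionally divides by their count. It is a simplified stand-in under stated assumptions: each discriminator's output is a plain list of floats rather than a tensor, only the flat list-of-outputs case is covered, and the helper names are hypothetical rather than taken from ESPnet.

```python
# Sketch only: how averaging over discriminators changes the loss scale.
# Plain lists of floats stand in for tensors; all names are hypothetical.

def mse_real_loss(scores):
    """Mean squared distance of real-sample scores from the target 1."""
    return sum((s - 1.0) ** 2 for s in scores) / len(scores)


def mse_fake_loss(scores):
    """Mean squared distance of generated-sample scores from the target 0."""
    return sum(s ** 2 for s in scores) / len(scores)


def discriminator_adversarial_loss(outputs_hat, outputs,
                                   average_by_discriminators=True):
    """Return (real_loss, fake_loss) accumulated over all discriminators.

    outputs_hat: one list of scores per discriminator, from generated audio.
    outputs:     one list of scores per discriminator, from ground truth.
    """
    real_loss = 0.0
    fake_loss = 0.0
    for scores_hat, scores in zip(outputs_hat, outputs):
        real_loss += mse_real_loss(scores)
        fake_loss += mse_fake_loss(scores_hat)
    if average_by_discriminators:
        # Keep the loss scale independent of the number of discriminators.
        num_discriminators = len(outputs)
        real_loss /= num_discriminators
        fake_loss /= num_discriminators
    return real_loss, fake_loss


# Two discriminators, two scores each.
fake = [[0.2, 0.1], [0.4, 0.0]]
real = [[0.9, 0.8], [1.0, 0.6]]

avg_real, avg_fake = discriminator_adversarial_loss(fake, real)
sum_real, sum_fake = discriminator_adversarial_loss(
    fake, real, average_by_discriminators=False
)
print(avg_real, avg_fake)  # each approximately 0.0525
print(sum_real, sum_fake)  # each approximately 0.105
```

Averaging keeps the loss magnitude comparable when the number of discriminators changes, which can make loss weights easier to reuse across configurations.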