espnet2.gan_tts.hifigan.loss.GeneratorAdversarialLoss
class espnet2.gan_tts.hifigan.loss.GeneratorAdversarialLoss(average_by_discriminators: bool = True, loss_type: str = 'mse')
Bases: Module
Generator adversarial loss module.
This module computes the adversarial loss for the generator in a GAN setting. The loss is computed with either mean squared error (MSE) or hinge loss, selected by loss_type at initialization.
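The two criteria can be sketched in plain Python as follows. This is an illustration, not ESPnet's code. It assumes the MSE variant compares every discriminator score with a target of 1 (least-squares GAN) and the hinge variant returns the negated mean score. Lists of floats stand in for tensors, and the function names are hypothetical.

```python
from statistics import fmean
from typing import Sequence


def mse_criterion(scores: Sequence[float]) -> float:
    """Hypothetical helper: mean squared error against a target of 1."""
    return fmean((s - 1.0) ** 2 for s in scores)


def hinge_criterion(scores: Sequence[float]) -> float:
    """Hypothetical helper: generator hinge loss, the negated mean score."""
    return -fmean(scores)


scores = [0.5, 0.8]  # discriminator scores for generated samples
print(mse_criterion(scores))    # ((0.5 - 1)^2 + (0.8 - 1)^2) / 2
print(hinge_criterion(scores))  # -(0.5 + 0.8) / 2
```

With the MSE criterion the loss reaches 0 only when every score equals 1. The hinge criterion keeps decreasing as the scores grow.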
Attributes:
- average_by_discriminators (bool) – Whether to average the loss by the number of discriminators.
- criterion (callable) – Loss function used to compute the adversarial loss, either MSE or hinge.
Parameters:
- average_by_discriminators (bool) – Whether to average the loss by the number of discriminators.
- loss_type (str) – Loss type, either “mse” or “hinge”.
Returns: Generator adversarial loss value.
Return type: Tensor
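The effect of average_by_discriminators can be sketched for the MSE case. This plain-Python illustration assumes one list of scores per discriminator and MSE against a target of 1. The name generator_mse_loss is hypothetical, not part of ESPnet.

```python
from typing import List


def generator_mse_loss(outputs: List[List[float]], average_by_discriminators: bool = True) -> float:
    """Hypothetical sketch: sum per-discriminator MSE terms, optionally averaged."""
    total = 0.0
    for scores in outputs:
        # One MSE term per discriminator, each against a target of 1.
        total += sum((s - 1.0) ** 2 for s in scores) / len(scores)
    # Divide by the number of discriminators only when requested.
    return total / len(outputs) if average_by_discriminators else total


outputs = [[0.5], [0.8]]  # one score list per discriminator
print(generator_mse_loss(outputs))                                   # (0.25 + 0.04) / 2
print(generator_mse_loss(outputs, average_by_discriminators=False))  # 0.25 + 0.04
```

Averaging keeps the loss scale roughly independent of how many discriminators are combined. Summing lets the adversarial term grow with their count.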
####### Examples
>>> loss_fn = GeneratorAdversarialLoss(loss_type="mse")
>>> outputs = [torch.tensor([0.5]), torch.tensor([0.8])]
>>> loss = loss_fn(outputs)
>>> print(loss)
- Raises: AssertionError – If an unsupported loss_type is provided.
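A minimal sketch of the loss_type validation. It assumes a plain assert, as the documented AssertionError suggests. check_loss_type is a hypothetical helper, not an ESPnet function.

```python
SUPPORTED_LOSS_TYPES = ("mse", "hinge")


def check_loss_type(loss_type: str) -> str:
    """Hypothetical helper: reject unsupported loss types with AssertionError."""
    assert loss_type in SUPPORTED_LOSS_TYPES, f"{loss_type} is not supported."
    return loss_type


try:
    check_loss_type("l1")
except AssertionError as err:
    print(err)  # l1 is not supported.
```

Python skips assert statements when run with the -O flag, so a check written this way does not fire in optimized mode.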
Initialize GeneratorAdversarialLoss module.
- Parameters:
- average_by_discriminators (bool) – Whether to average the loss by the number of discriminators.
- loss_type (str) – Loss type, “mse” or “hinge”.
forward(outputs: List[List[Tensor]] | List[Tensor] | Tensor) → Tensor
Calculate generator adversarial loss.
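This method computes the generator adversarial loss from the discriminator output(s). Each output is scored with the criterion selected by loss_type at initialization: mean squared error (MSE) against a target of 1, or hinge loss (the negated mean of the output). For a list of outputs, the per-discriminator losses are summed and, if average_by_discriminators is True, divided by the number of discriminators.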
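- Parameters: outputs (Union[List[List[Tensor]], List[Tensor], Tensor]) – Discriminator outputs, given as a single tensor, a list of tensors (one per discriminator), or a list of lists of tensors. In the nested case, only the last element of each inner list (the final discriminator output) is used, and any preceding elements, such as intermediate feature maps, are ignored.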
- Returns: The calculated generator adversarial loss value.
- Return type: Tensor
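The input handling described above can be sketched in plain Python. The sketch makes several assumptions. A tuple of floats stands in for one output tensor, so lists are unambiguous containers. Only the MSE criterion is shown. generator_adversarial_loss is a hypothetical name, not ESPnet's forward method.

```python
from statistics import fmean
from typing import List, Tuple, Union

# Hypothetical stand-in: a tuple of floats plays the role of one output tensor.
Scores = Tuple[float, ...]
Outputs = Union[List[List[Scores]], List[Scores], Scores]


def mse_criterion(scores: Scores) -> float:
    # Least-squares target: every score should approach 1.
    return fmean((s - 1.0) ** 2 for s in scores)


def generator_adversarial_loss(outputs: Outputs, average_by_discriminators: bool = True) -> float:
    if not isinstance(outputs, list):
        # A single output: apply the criterion directly.
        return mse_criterion(outputs)
    total = 0.0
    for out in outputs:
        if isinstance(out, list):
            # Nested case: keep only the last element (the final output).
            out = out[-1]
        total += mse_criterion(out)
    # Optionally average over the number of discriminators.
    return total / len(outputs) if average_by_discriminators else total


single = (0.5,)
flat = [(0.5,), (0.3,)]
nested = [[(0.1, 0.2), (0.5,)], [(0.7,), (0.3,)]]  # [feature map, final output] per discriminator
print(generator_adversarial_loss(single))  # (0.5 - 1)^2
print(generator_adversarial_loss(flat))    # ((0.5 - 1)^2 + (0.3 - 1)^2) / 2
print(generator_adversarial_loss(nested))  # same as flat: only the last elements are scored
```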
####### Examples
>>> loss_fn = GeneratorAdversarialLoss(loss_type="mse")
>>> outputs = [torch.tensor([0.5]), torch.tensor([0.3])]
>>> loss = loss_fn(outputs)
>>> print(loss) # Outputs the loss value
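The value printed above can be worked out by hand. This derivation assumes MSE against a target of 1 and hinge loss as the negated mean. D_k denotes the final output of the k-th discriminator, and K is the number of discriminators; the 1/K factor applies only when average_by_discriminators is True. The second line evaluates both criteria for the example outputs 0.5 and 0.3 (K = 2).

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{mse}} &= \frac{1}{K}\sum_{k=1}^{K}\operatorname{mean}\big((D_k - 1)^2\big),
&\qquad
\mathcal{L}_{\mathrm{hinge}} &= -\frac{1}{K}\sum_{k=1}^{K}\operatorname{mean}(D_k),\\
\mathcal{L}_{\mathrm{mse}} &= \tfrac{1}{2}\big[(0.5-1)^2 + (0.3-1)^2\big] = \tfrac{1}{2}(0.25 + 0.49) = 0.37,
&\qquad
\mathcal{L}_{\mathrm{hinge}} &= -\tfrac{1}{2}(0.5 + 0.3) = -0.4.
\end{aligned}
```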