espnet2.uasr.loss.gradient_penalty.UASRGradientPenalty
class espnet2.uasr.loss.gradient_penalty.UASRGradientPenalty(discriminator: AbsDiscriminator, weight: float = 1.0, probabilistic_grad_penalty_slicing: str2bool = False, reduction: str = 'sum')
Bases: AbsUASRLoss
Gradient penalty for unsupervised automatic speech recognition (UASR).
This class implements a gradient penalty for training the discriminator in UASR. The penalty encourages Lipschitz continuity of the discriminator, which is crucial for stable training of Generative Adversarial Networks (GANs).
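For intuition, below is a minimal, self-contained sketch of a WGAN-GP-style gradient penalty in PyTorch, which is the general technique this loss is based on. The stand-in discriminator, the helper name gradient_penalty_sketch, and the tensor layout are illustrative assumptions, not ESPnet's actual implementation:

import torch
from torch import autograd

# Stand-in discriminator (assumption; in ESPnet this is an AbsDiscriminator).
disc = torch.nn.Linear(64, 1)

def gradient_penalty_sketch(disc, real, fake):
    # Mix real and fake samples with a random per-example factor.
    alpha = torch.rand(real.size(0), 1, 1).expand(real.size())
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    disc_out = disc(interpolates)
    # Gradient of the discriminator output w.r.t. the interpolated input.
    grads = autograd.grad(
        outputs=disc_out,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_out),
        create_graph=True,
    )[0]
    # Penalize deviation of the gradient norm from 1 (Lipschitz-1 constraint).
    return ((grads.norm(2, dim=1) - 1) ** 2).sum()

real = torch.randn(8, 20, 64)  # (batch, time, feature) -- assumed layout
fake = torch.randn(8, 20, 64)
print(gradient_penalty_sketch(disc, real, fake))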
discriminator
The discriminator model.
- Type: AbsDiscriminator
weight
The weight applied to the gradient penalty.
- Type: float
probabilistic_grad_penalty_slicing
If True, uses probabilistic slicing for samples.
- Type: bool
reduction
Specifies the reduction method to apply to the output (‘sum’ or ‘mean’).
- Type: str
Parameters:
- discriminator (AbsDiscriminator) – The discriminator model.
- weight (float) – Weight for the gradient penalty (default is 1.0).
- probabilistic_grad_penalty_slicing (str2bool) – Flag for probabilistic gradient penalty slicing (default is False).
- reduction (str) – Reduction method (‘sum’ or ‘mean’, default is ‘sum’).
Returns: The computed gradient penalty value.
Return type: torch.Tensor
Examples
>>> import torch
>>> discriminator = SomeDiscriminatorModel()  # any AbsDiscriminator subclass
>>> loss_fn = UASRGradientPenalty(discriminator, weight=10.0)
>>> fake_samples = torch.randn(16, 100)  # Batch of generated samples
>>> real_samples = torch.randn(16, 100)  # Batch of real samples
>>> loss = loss_fn(fake_samples, real_samples, True, True)
>>> print(loss)
NOTE
The is_training argument should be set to True when the model is in training mode, and is_discrimininative_step should be set to True when the discriminator is being trained.
- Raises: ValueError – If the shapes of fake_sample and real_sample do not match.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(fake_sample: Tensor, real_sample: Tensor, is_training: str2bool, is_discrimininative_step: str2bool)
Computes the gradient penalty for UASR during training.
This method calculates the gradient penalty as part of the UASR loss function, which is crucial for training the discriminator in a GAN-like setup. The gradient penalty helps enforce the Lipschitz constraint, ensuring the model is more stable during training.
- Parameters:
- fake_sample (torch.Tensor) – A tensor representing the generated sample from the generator.
- real_sample (torch.Tensor) – A tensor representing the real sample.
- is_training (str2bool) – A boolean indicating whether the model is currently in the training phase.
- is_discrimininative_step (str2bool) – A boolean indicating whether the current step is focused on training the discriminator.
- Returns: The computed gradient penalty value: the squared deviation of the gradient norm from one (as in WGAN-GP), aggregated according to reduction, or zero if the conditions for computing the penalty are not met (i.e. not in training, or not a discriminative step).
- Return type: torch.Tensor
NOTE
The method uses probabilistic slicing if probabilistic_grad_penalty_slicing is set to True. Otherwise, it slices the samples based on the batch size and time length directly.
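As a rough illustration of these two slicing strategies: both crop the real and fake batches to a shared batch size and time length before interpolation. The helper slice_for_penalty and all shapes below are illustrative assumptions, not the actual ESPnet code:

import torch

def slice_for_penalty(sample, b_size, t_size, probabilistic=False):
    # Crop a (batch, time, ...) sample down to (b_size, t_size, ...).
    b, t = sample.size(0), sample.size(1)
    if probabilistic:
        # Probabilistic slicing: take a random window of the required size.
        start_b = torch.randint(0, b - b_size + 1, (1,)).item()
        start_t = torch.randint(0, t - t_size + 1, (1,)).item()
        return sample[start_b:start_b + b_size, start_t:start_t + t_size]
    # Direct slicing: take the leading examples and frames.
    return sample[:b_size, :t_size]

fake = torch.randn(10, 120, 64)
real = torch.randn(8, 100, 64)
b_size = min(fake.size(0), real.size(0))  # shared batch size
t_size = min(fake.size(1), real.size(1))  # shared time length
print(slice_for_penalty(fake, b_size, t_size, probabilistic=True).shape)
print(slice_for_penalty(real, b_size, t_size).shape)  # both (8, 100, 64)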
Examples
>>> import torch
>>> loss_fn = UASRGradientPenalty(discriminator, weight=10.0)
>>> fake_samples = torch.randn(32, 100)  # Batch of fake samples
>>> real_samples = torch.randn(32, 100)  # Batch of real samples
>>> loss = loss_fn(
...     fake_samples, real_samples,
...     is_training=True, is_discrimininative_step=True)
>>> print(loss)