espnet2.uasr.loss.phoneme_diversity_loss.UASRPhonemeDiversityLoss
class espnet2.uasr.loss.phoneme_diversity_loss.UASRPhonemeDiversityLoss(weight: float = 1.0)
Bases: AbsUASRLoss
Phoneme diversity loss for UASR (unsupervised automatic speech recognition).
This loss encourages the model to spread its phoneme predictions across the phoneme inventory rather than collapsing onto a small subset of phonemes.
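To make "diversity" concrete, the plain-Python illustration below assumes the penalty follows the code-perplexity penalty used in wav2vec-U (fairseq), which this loss appears to mirror. The per-frame phoneme distributions are averaged, the perplexity exp(H) of that average is taken, and its shortfall from the number of phonemes C is penalized as (C - ppl) / C. The helpers `average_distribution`, `perplexity`, and `diversity_penalty` are hypothetical and are not part of ESPnet. Zero-probability terms are skipped, because 0 · log 0 is taken as 0.

```python
import math
from typing import List, Sequence


def average_distribution(frames: Sequence[Sequence[float]]) -> List[float]:
    """Average per-frame phoneme probability vectors (each sums to 1)."""
    num_frames = len(frames)
    num_phones = len(frames[0])
    return [sum(frame[c] for frame in frames) / num_frames for c in range(num_phones)]


def perplexity(probs: Sequence[float]) -> float:
    """exp(entropy): 1 for a one-hot distribution, len(probs) for a uniform one."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return math.exp(entropy)


def diversity_penalty(frames: Sequence[Sequence[float]]) -> float:
    """Assumed penalty (C - ppl) / C: about 0 when all phonemes are used equally,
    and (C - 1) / C when every frame predicts the same phoneme."""
    avg = average_distribution(frames)
    num_phones = len(avg)
    return (num_phones - perplexity(avg)) / num_phones


# Collapsed output: every frame predicts phoneme 0 with certainty.
collapsed = [[1.0, 0.0, 0.0, 0.0]] * 3
# Diverse output: the frames cover all four phonemes, so the average is uniform.
diverse = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
           [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]

print(diversity_penalty(collapsed))  # 0.75, i.e. (4 - 1) / 4
print(math.isclose(diversity_penalty(diverse), 0.0, abs_tol=1e-9))  # True
```

The penalty depends only on the batch-averaged distribution. Individual frames may be confident (one-hot) as long as different frames pick different phonemes.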
weight
Weighting factor for the loss, default is 1.0.
Type: float
Parameters: weight (float) – A scalar that scales the contribution of the loss.
Returns: The computed phoneme diversity loss, which is a scalar value.
Return type: torch.Tensor
Raises: ValueError – If dense_x is not a 3D tensor.
####### Examples
>>> import torch
>>> from espnet2.uasr.loss.phoneme_diversity_loss import UASRPhonemeDiversityLoss
>>> loss_function = UASRPhonemeDiversityLoss(weight=0.5)
>>> dense_x = torch.randn(10, 20, 30) # (batch_size, time_length, channels)
>>> sample_size = 10
>>> is_discriminative_step = False
>>> loss = loss_function(dense_x, sample_size, is_discriminative_step)
>>> print(loss)
NOTE
The loss is computed only when weight is greater than 0 and is_discriminative_step is False. Otherwise, it returns 0.
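The gating described in this note can be sketched in plain Python. Two parts of the sketch are assumptions made only for illustration: the alternating schedule (even steps update the discriminator, odd steps update the generator) and the multiplicative use of `weight`. The helpers `gated_diversity_loss` and `losses_over_schedule` are hypothetical and are not ESPnet APIs.

```python
from typing import Callable, List


def gated_diversity_loss(weight: float, is_discriminative_step: bool,
                         compute_penalty: Callable[[], float]) -> float:
    """Documented gating: the penalty is computed only when weight > 0 and the
    current step is not a discriminator step. Otherwise 0 is returned."""
    if weight > 0 and not is_discriminative_step:
        # Assumption: weight scales the penalty multiplicatively.
        return weight * compute_penalty()
    # The penalty callable is never invoked on skipped steps.
    return 0


def losses_over_schedule(num_steps: int, weight: float, penalty: float) -> List[float]:
    """Hypothetical schedule: even steps update the discriminator,
    odd steps update the generator."""
    return [
        gated_diversity_loss(weight, step % 2 == 0, lambda: penalty)
        for step in range(num_steps)
    ]


print(losses_over_schedule(4, weight=0.5, penalty=0.2))  # [0, 0.1, 0, 0.1]
print(losses_over_schedule(4, weight=0.0, penalty=0.2))  # [0, 0, 0, 0]
```

Because the check happens before any computation, discriminator steps pay no cost for the diversity term. Setting weight to 0 disables the term entirely.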
forward(dense_x: Tensor, sample_size: int, is_discriminative_step: str2bool)
Computes the phoneme diversity loss for unsupervised automatic speech recognition (UASR).
This loss encourages diversity in the phonemes generated by the model by penalizing low-entropy (that is, concentrated) phoneme distributions derived from the predicted logits. It is applied only during generator updates, that is, when is_discriminative_step is False.
- Parameters:
  - dense_x (torch.Tensor) – Predicted logits of generated samples with shape (batch_size, time_length, channel_size).
  - sample_size (int) – The batch size used in the current training iteration.
  - is_discriminative_step (str2bool) – A boolean indicating whether the model is currently in a discriminative training step. If True, the diversity loss is not calculated.
- Returns: The computed phoneme diversity loss. Returns 0 if the weight is less than or equal to 0 or if the model is in a discriminative training step.
- Return type: torch.Tensor
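The following is a minimal torch sketch of what forward plausibly computes. It assumes the wav2vec-U code-perplexity penalty: (channel_size - perplexity) / channel_size, scaled by sample_size and by weight. It is a hypothetical re-implementation, not the ESPnet source. In particular, the name `phoneme_diversity_loss_sketch`, the epsilon of 1e-7, and where `weight` is applied are assumptions.

```python
import torch


def phoneme_diversity_loss_sketch(
    dense_x: torch.Tensor,
    sample_size: int,
    is_discriminative_step: bool,
    weight: float = 1.0,
):
    """Hypothetical sketch of the documented forward pass.

    Assumes the wav2vec-U code-perplexity penalty; this is not the ESPnet source.
    """
    # Documented gating: skip the loss for weight <= 0 or discriminator steps.
    if weight <= 0 or is_discriminative_step:
        return 0
    # Documented error case: dense_x must be (batch_size, time_length, channel_size).
    if dense_x.dim() != 3:
        raise ValueError(f"dense_x must be 3D, got shape {tuple(dense_x.shape)}")
    channel_size = dense_x.size(-1)
    # Per-frame softmax over phonemes, averaged over every frame in the batch.
    frames = dense_x.reshape(-1, channel_size).float()
    avg_probs = torch.softmax(frames, dim=-1).mean(dim=0)
    # Perplexity exp(H) of the averaged distribution; the epsilon guards log(0).
    entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-7))
    phoneme_ppl = torch.exp(entropy)
    # Assumed penalty: shortfall from full diversity, scaled by batch size and weight.
    return weight * (channel_size - phoneme_ppl) / channel_size * sample_size


logits = torch.randn(4, 10, 20)  # (batch_size, time_length, channel_size)
loss = phoneme_diversity_loss_sketch(logits, sample_size=4, is_discriminative_step=False)
print(loss.shape)  # torch.Size([])
print(phoneme_diversity_loss_sketch(logits, 4, is_discriminative_step=True))  # 0
```

Averaging before taking the entropy is the key design choice. A model can be confident on every frame and still incur almost no penalty, provided that different frames use different phonemes.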
####### Examples
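>>> import torch
>>> from espnet2.uasr.loss.phoneme_diversity_loss import UASRPhonemeDiversityLoss
>>> loss_fn = UASRPhonemeDiversityLoss(weight=1.0)
>>> logits = torch.randn(4, 10, 20)  # (batch_size, time_length, channel_size)
>>> loss = loss_fn(logits, sample_size=4,
...                is_discriminative_step=False)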
>>> print(loss)
NOTE
The loss calculation is performed only if the weight is greater than 0 and the model is not in a discriminative step.