espnet2.uasr.loss.pseudo_label_loss.UASRPseudoLabelLoss
class espnet2.uasr.loss.pseudo_label_loss.UASRPseudoLabelLoss(weight: float = 1.0, input_dim: int = 128, output_dim: int = 64, downsample_rate: int = 2, ignore_index: int = -1, reduction: str = 'none')
Bases: AbsUASRLoss
Auxiliary pseudo label loss for Unsupervised Automatic Speech Recognition (UASR).
This loss projects the intermediate features from input_dim to output_dim and computes the cross-entropy between the resulting logits and the pseudo labels passed to forward. It provides an auxiliary frame-level training signal in the unsupervised setting, where no transcribed speech is available.
weight
Weight of the loss. If set to 0, the loss will not contribute to the overall loss.
- Type: float
input_dim
Dimensionality of the intermediate features inter_x (their last dimension).
- Type: int
output_dim
Number of pseudo-label classes, i.e., the size of the output over which the cross-entropy is computed.
- Type: int
downsample_rate
Factor by which the pseudo labels are downsampled along the time axis before being matched against inter_x.
- Type: int
ignore_index
Index that is ignored and does not contribute to the loss computation.
- Type: int
reduction
Specifies the reduction to apply to the output. Options are ‘none’, ‘mean’, ‘sum’.
- Type: str
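The ignore_index and reduction attributes follow the usual PyTorch cross-entropy semantics. The sketch below calls torch.nn.functional.cross_entropy directly (not the ESPnet class) on a tiny, arbitrary batch to show their effect: an ignored frame contributes 0 under 'none' and is excluded from the average under 'mean'.

```python
import torch
import torch.nn.functional as F

# Tiny batch: 1 utterance, 3 frames, 4 pseudo-label classes.
# cross_entropy expects logits laid out as (batch, num_classes, time).
logits = torch.zeros(1, 4, 3)  # uniform logits: each valid frame costs log(4)
targets = torch.tensor([[2, -1, 0]])  # the middle frame uses ignore_index = -1

per_frame = F.cross_entropy(logits, targets, ignore_index=-1, reduction="none")
mean_loss = F.cross_entropy(logits, targets, ignore_index=-1, reduction="mean")
sum_loss = F.cross_entropy(logits, targets, ignore_index=-1, reduction="sum")

print(per_frame)  # shape (1, 3); the ignored frame contributes 0
print(mean_loss)  # averaged over the 2 non-ignored frames only
print(sum_loss)   # summed over the 2 non-ignored frames
```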
Parameters:
- weight (float) – Weight of the loss. Default is 1.0.
- input_dim (int) – Dimensionality of the input features. Default is 128.
- output_dim (int) – Dimensionality of the output features. Default is 64.
- downsample_rate (int) – Rate at which to downsample the pseudo labels. Default is 2.
- ignore_index (int) – Index that is ignored in the loss computation. Default is -1.
- reduction (str) – Specifies the reduction method. Default is ‘none’.
Returns: The computed pseudo label loss.
Return type: torch.Tensor
Examples
>>> import torch
>>> from espnet2.uasr.loss.pseudo_label_loss import UASRPseudoLabelLoss
>>> loss_fn = UASRPseudoLabelLoss(weight=1.0)
>>> inter_x = torch.randn(10, 20, 128)  # batch 10, 20 frames, input_dim 128
>>> pseudo_labels = torch.randint(0, 64, (10, 40))  # batch 10, 40 labels; downsample_rate=2 reduces them to 20
>>> is_discriminative_step = False
>>> loss = loss_fn(inter_x, pseudo_labels, is_discriminative_step)
>>> print(loss)
NOTE
The loss returns 0 if the weight is 0, during a discriminative step, or when pseudo_labels is None.
forward(inter_x: Tensor, pseudo_labels: Tensor, is_discriminative_step: str2bool)
Forward pass for calculating the pseudo label loss.
This method projects inter_x to output_dim logits, downsamples the pseudo labels by downsample_rate along the time axis, and computes the cross-entropy between the two using the configured ignore_index and reduction.
- Parameters:
- inter_x (torch.Tensor) – Intermediate features from the model. Shape should be (batch_size, time_steps, input_dim).
- pseudo_labels (torch.Tensor) – Integer pseudo-label indices in [0, output_dim), or equal to ignore_index. Shape should be (batch_size, label_length); the label sequence is downsampled by downsample_rate before being matched against time_steps.
- is_discriminative_step (str2bool) – Whether the current step is a discriminator update. If True, the loss is skipped and 0 is returned; if False, the loss is computed.
- Returns: The computed pseudo label loss. Returns 0 if the weight is 0, if the step is discriminative, or if pseudo labels are None.
- Return type: torch.Tensor
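The forward computation described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the ESPnet source: the class name PseudoLabelLossSketch is hypothetical, the projection is assumed to be a single torch.nn.Linear layer, downsampling is assumed to keep every downsample_rate-th label, and both sequences are assumed to be truncated to their common length. This page does not say where weight scales the loss, so the sketch only uses it as an on/off switch.

```python
from typing import Optional, Union

import torch
import torch.nn.functional as F


class PseudoLabelLossSketch(torch.nn.Module):
    """Hypothetical re-implementation for illustration; not ESPnet's code."""

    def __init__(self, weight=1.0, input_dim=128, output_dim=64,
                 downsample_rate=2, ignore_index=-1, reduction="none"):
        super().__init__()
        self.weight = weight
        self.downsample_rate = downsample_rate
        self.ignore_index = ignore_index
        self.reduction = reduction
        # Assumption: one linear layer maps input_dim features to output_dim logits.
        self.proj = torch.nn.Linear(input_dim, output_dim)

    def forward(self, inter_x: torch.Tensor,
                pseudo_labels: Optional[torch.Tensor],
                is_discriminative_step: bool) -> Union[torch.Tensor, int]:
        # Skip when disabled (non-positive weight), in a discriminative step,
        # or when no pseudo labels are given.
        if self.weight <= 0 or is_discriminative_step or pseudo_labels is None:
            return 0
        logits = self.proj(inter_x)  # (B, T, output_dim)
        if self.downsample_rate > 1:
            # Assumption: keep every downsample_rate-th label along time.
            pseudo_labels = pseudo_labels[:, :: self.downsample_rate]
        # Truncate both sequences to the shorter of the two lengths.
        t = min(logits.size(1), pseudo_labels.size(1))
        return F.cross_entropy(
            logits[:, :t].transpose(1, 2),  # cross_entropy wants (B, C, T)
            pseudo_labels[:, :t],
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )


torch.manual_seed(0)
loss_fn = PseudoLabelLossSketch()
inter_x = torch.randn(10, 20, 128)
pseudo_labels = torch.randint(0, 64, (10, 40))  # twice the frame rate of inter_x
loss = loss_fn(inter_x, pseudo_labels, is_discriminative_step=False)
print(loss.shape)  # per-frame losses of shape (10, 20) with reduction="none"
```

With reduction="none" the sketch returns one loss value per retained frame, so a caller would still need to reduce it before adding it to a scalar objective.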
Examples
>>> import torch
>>> from espnet2.uasr.loss.pseudo_label_loss import UASRPseudoLabelLoss
>>> loss_fn = UASRPseudoLabelLoss(weight=1.0)
>>> inter_x = torch.randn(10, 20, 128)  # batch 10, 20 time steps, input_dim 128
>>> pseudo_labels = torch.randint(0, 64, (10, 40))  # class indices in [0, 64), before downsampling
>>> is_discriminative_step = False
>>> loss = loss_fn(inter_x, pseudo_labels, is_discriminative_step)
>>> print(loss)
NOTE
Ensure that the last dimension of inter_x equals input_dim, that the pseudo-label values are valid class indices in [0, output_dim) or equal ignore_index, and that weight reflects the desired importance of the pseudo-label loss in the overall objective.
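The requirements in the note above can be checked before calling the loss. The helper below, check_pseudo_label_inputs, is hypothetical and not part of ESPnet; its default arguments simply mirror the constructor defaults documented on this page.

```python
import torch


def check_pseudo_label_inputs(inter_x, pseudo_labels, input_dim=128,
                              output_dim=64, ignore_index=-1):
    """Hypothetical pre-flight check; not part of ESPnet."""
    if inter_x.dim() != 3 or inter_x.size(-1) != input_dim:
        raise ValueError(
            f"inter_x must be (batch, time, {input_dim}), got {tuple(inter_x.shape)}"
        )
    if pseudo_labels.dim() != 2 or pseudo_labels.size(0) != inter_x.size(0):
        raise ValueError("pseudo_labels must be (batch, label_length) with a matching batch size")
    if pseudo_labels.dtype != torch.long:
        raise ValueError("pseudo_labels must hold integer class indices (torch.long)")
    # Every label must be a valid class index or the ignore_index.
    in_range = (pseudo_labels >= 0) & (pseudo_labels < output_dim)
    valid = in_range | (pseudo_labels == ignore_index)
    if not bool(valid.all()):
        raise ValueError(f"pseudo_labels must lie in [0, {output_dim}) or equal {ignore_index}")


# Passes: shapes and label range match the defaults.
check_pseudo_label_inputs(torch.randn(10, 20, 128), torch.randint(0, 64, (10, 40)))
```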