espnet2.gan_tts.jets.loss.VarianceLoss
class espnet2.gan_tts.jets.loss.VarianceLoss(use_masking: bool = True, use_weighted_masking: bool = False)
Bases: Module
VarianceLoss is a PyTorch module that computes the variance loss for JETS, covering the duration, pitch, and energy predictors. It can exclude padded positions from the loss through masking, and it optionally supports weighted masking, which scales each non-padded element by a weight instead of averaging uniformly over all non-padded elements.
use_masking
Flag indicating whether to apply masking to the loss calculation.
- Type: bool
use_weighted_masking
Flag indicating whether to use weighted masking in the loss calculation.
- Type: bool
mse_criterion
Mean Squared Error loss function.
- Type: torch.nn.MSELoss
duration_criterion
Duration prediction loss function.
- Type: DurationPredictorLoss
Parameters:
- use_masking (bool) – Whether to mask out the padded part in the loss calculation. Defaults to True.
- use_weighted_masking (bool) – Whether to apply weighted masking in the loss calculation. Defaults to False.
Returns: The calculated loss values for the duration, pitch, and energy predictions.
Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
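The effect of masking padded positions can be illustrated with a small, dependency-free sketch. It is not the module's implementation: `non_pad_mask`, `masked_mse`, and `plain_mse` are hypothetical helper names, and plain Python lists stand in for tensors so the arithmetic is explicit.

```python
from typing import List


def non_pad_mask(lengths: List[int], max_len: int) -> List[List[bool]]:
    """Return a (B, T) mask: True at real tokens, False at padding."""
    return [[t < n for t in range(max_len)] for n in lengths]


def masked_mse(outs: List[List[float]], targets: List[List[float]],
               lengths: List[int]) -> float:
    """Mean squared error over non-padded positions only."""
    mask = non_pad_mask(lengths, len(outs[0]))
    sq_errs = [
        (o - t) ** 2
        for row_o, row_t, row_m in zip(outs, targets, mask)
        for o, t, m in zip(row_o, row_t, row_m)
        if m  # skip padded positions
    ]
    return sum(sq_errs) / len(sq_errs)


def plain_mse(outs: List[List[float]], targets: List[List[float]]) -> float:
    """Mean squared error over every position, padding included."""
    sq_errs = [(o - t) ** 2
               for row_o, row_t in zip(outs, targets)
               for o, t in zip(row_o, row_t)]
    return sum(sq_errs) / len(sq_errs)


# Two sequences padded to length 3; the second has only one real token.
p_outs = [[1.0, 2.0, 3.0], [1.0, 9.0, 9.0]]
ps = [[1.0, 2.0, 5.0], [2.0, 0.0, 0.0]]
ilens = [3, 1]

masked = masked_mse(p_outs, ps, ilens)   # averages 4 valid squared errors
unmasked = plain_mse(p_outs, ps)         # padding garbage inflates the loss
print(masked, unmasked)
```

Without the mask, the arbitrary predictions at padded positions dominate the average, which is why masking is enabled by default.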
####### Examples
>>> import torch
>>> from espnet2.gan_tts.jets.loss import VarianceLoss
>>> variance_loss = VarianceLoss(use_masking=True, use_weighted_masking=False)
>>> d_outs = torch.randn(2, 5) # Example duration outputs
>>> ds = torch.randint(1, 6, (2, 5)) # Example durations
>>> p_outs = torch.randn(2, 5, 1) # Example pitch outputs
>>> ps = torch.randn(2, 5, 1) # Example target pitch
>>> e_outs = torch.randn(2, 5, 1) # Example energy outputs
>>> es = torch.randn(2, 5, 1) # Example target energy
>>> ilens = torch.tensor([5, 5]) # Example input lengths
>>> duration_loss, pitch_loss, energy_loss = variance_loss(d_outs, ds, p_outs, ps, e_outs, es, ilens)
NOTE
The variance loss is specifically designed for the JETS framework in ESPnet2 and should be used in conjunction with other components of the JETS model.
Initialize JETS variance loss module.
- Parameters:
- use_masking (bool) – Whether to mask out the padded part in the loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in the loss calculation.
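One way to picture weighted masking is sketched below. This is an assumption, not a statement of what the module does: it assumes each non-padded element is weighted by 1 / (sequence length × batch size), so every sequence contributes equally regardless of its length. The helper name `weighted_masked_mse` is hypothetical.

```python
from typing import List


def weighted_masked_mse(outs: List[List[float]], targets: List[List[float]],
                        lengths: List[int]) -> float:
    """Weighted sum of squared errors (assumed weighting scheme).

    Each valid element gets weight 1 / (n * batch_size), where n is the
    length of its own sequence; padded elements get weight 0.
    """
    batch_size = len(outs)
    total = 0.0
    for row_o, row_t, n in zip(outs, targets, lengths):
        weight = 1.0 / (n * batch_size)
        for t in range(n):  # positions >= n are padding and are skipped
            total += weight * (row_o[t] - row_t[t]) ** 2
    return total


# Same toy batch shape as a padded (B=2, T=3) pitch tensor.
p_outs = [[1.0, 2.0, 3.0], [1.0, 9.0, 9.0]]
ps = [[1.0, 2.0, 5.0], [2.0, 0.0, 0.0]]
ilens = [3, 1]

weighted = weighted_masked_mse(p_outs, ps, ilens)
print(weighted)
```

Under this assumed scheme the result is the mean of the per-sequence mean errors, so a short utterance counts as much as a long one. A uniform average over all valid elements would instead favor long utterances.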
forward(d_outs: Tensor, ds: Tensor, p_outs: Tensor, ps: Tensor, e_outs: Tensor, es: Tensor, ilens: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
Calculate forward propagation.
This method computes the variance loss for the duration, pitch, and energy predictors based on the provided outputs and targets. It applies masking to ignore padded parts of the input during loss calculation if specified.
- Parameters:
- d_outs (Tensor) – Batch of outputs of duration predictor (B, T_text).
- ds (LongTensor) – Batch of durations (B, T_text).
- p_outs (Tensor) – Batch of outputs of pitch predictor (B, T_text, 1).
- ps (Tensor) – Batch of target token-averaged pitch (B, T_text, 1).
- e_outs (Tensor) – Batch of outputs of energy predictor (B, T_text, 1).
- es (Tensor) – Batch of target token-averaged energy (B, T_text, 1).
- ilens (LongTensor) – Batch of the lengths of each input (B,).
- Returns: A tuple of three loss values, in order: the duration predictor loss, the pitch predictor loss, and the energy predictor loss.
- Return type: Tuple[Tensor, Tensor, Tensor]
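The duration inputs differ in kind: `ds` holds integer durations, while the examples feed real-valued `d_outs`. A common way to compare the two, assumed here rather than confirmed by this page, is to map the targets to the log domain before taking the squared error. The sketch below uses a hypothetical helper `log_domain_duration_loss` and an assumed offset of 1.0 to keep `log(0)` out of reach.

```python
import math
from typing import List


def log_domain_duration_loss(d_outs: List[List[float]], ds: List[List[int]],
                             offset: float = 1.0) -> float:
    """MSE between predicted log-durations and log(target + offset).

    Assumption: predictions are in the log domain and targets are integer
    durations in the linear domain; `offset` is an illustrative default.
    """
    sq_errs = [
        (o - math.log(d + offset)) ** 2
        for row_o, row_d in zip(d_outs, ds)
        for o, d in zip(row_o, row_d)
    ]
    return sum(sq_errs) / len(sq_errs)


ds = [[1, 3, 0]]  # integer durations, including a zero-length token
perfect = [[math.log(d + 1.0) for d in row] for row in ds]
off_by_one = [[v + 1.0 for v in row] for row in perfect]

zero_loss = log_domain_duration_loss(perfect, ds)
unit_loss = log_domain_duration_loss(off_by_one, ds)
print(zero_loss, unit_loss)
```

Working in the log domain keeps long durations from dominating the loss, since an error of one frame matters more for a short token than for a long one.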
####### Examples
>>> import torch
>>> from espnet2.gan_tts.jets.loss import VarianceLoss
>>> variance_loss = VarianceLoss(use_masking=True)
>>> d_outs = torch.randn(2, 5)
>>> ds = torch.randint(1, 10, (2, 5))
>>> p_outs = torch.randn(2, 5, 1)
>>> ps = torch.randn(2, 5, 1)
>>> e_outs = torch.randn(2, 5, 1)
>>> es = torch.randn(2, 5, 1)
>>> ilens = torch.tensor([5, 5])
>>> losses = variance_loss(d_outs, ds, p_outs, ps, e_outs, es, ilens)
>>> print(losses)