espnet2.tts.prodiff.loss.ProDiffLoss
class espnet2.tts.prodiff.loss.ProDiffLoss(use_masking: bool = True, use_weighted_masking: bool = False)
Bases: Module
Loss function module for ProDiff.
This module implements the loss used to train ProDiff models in the ESPnet2 text-to-speech framework. It returns five separate loss terms: an L1 loss and an SSIM loss between the predicted and target acoustic features, plus duration, pitch, and energy prediction losses.
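Unlike L1, the SSIM term scores structural similarity (means, variances, and covariance) rather than per-element differences. The sketch below is a simplified, single-window (global) SSIM on flat Python lists, turned into a loss as 1 - SSIM. It is an illustration of the formula only: ESPnet's criterion operates on tensors with a local sliding window, and the constants c1 = 0.01**2 and c2 = 0.03**2 are the conventional values for data with a dynamic range of 1, an assumption here.

```python
from typing import Sequence


def global_ssim(x: Sequence[float], y: Sequence[float],
                c1: float = 0.01 ** 2, c2: float = 0.03 ** 2) -> float:
    """Single-window (global) SSIM between two equal-length sequences.

    Simplification: statistics are taken over the whole input instead of
    over a sliding local window as in a full SSIM implementation.
    """
    if len(x) != len(y) or len(x) == 0:
        raise ValueError("x and y must be non-empty and have the same length")
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    # Luminance/contrast/structure terms combined in the standard SSIM form.
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return num / den


def ssim_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    # SSIM is 1.0 for identical inputs, so 1 - SSIM is 0.0 for a perfect match.
    return 1.0 - global_ssim(pred, target)
```

Identical inputs give a loss of (numerically) zero, while inputs that move in opposite directions have negative covariance and push the loss above 1.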
use_masking
Whether to apply masking for padded parts in loss calculation.
- Type: bool
use_weighted_masking
Whether to apply weighted masking in loss calculation.
- Type: bool
Parameters:
- use_masking (bool) – Whether to apply masking for padded parts in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
Raises: AssertionError – If both use_masking and use_weighted_masking are True.
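The constructor check leaves three valid flag combinations. The hypothetical helper below (not part of ESPnet) maps the two flags to the reduction mode they select and raises AssertionError for the rejected combination, matching the documented behaviour.

```python
def masking_mode(use_masking: bool, use_weighted_masking: bool) -> str:
    """Hypothetical helper: map the two constructor flags to a reduction mode."""
    # Mirrors the documented constructor check: both flags may not be True.
    # An explicit raise is used so the check also survives `python -O`.
    if use_masking and use_weighted_masking:
        raise AssertionError(
            "use_masking and use_weighted_masking cannot both be True"
        )
    if use_weighted_masking:
        return "weighted masking"
    if use_masking:
        return "masking"
    return "no masking"
```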
####### Examples
>>> from espnet2.tts.prodiff.loss import ProDiffLoss
>>> loss_fn = ProDiffLoss(use_masking=True, use_weighted_masking=False)
>>> l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss = loss_fn(
... after_outs, before_outs, d_outs, p_outs, e_outs, ys, ds, ps, es,
... ilens, olens)
Initialize the ProDiff loss module.
- Parameters:
- use_masking (bool) – Whether to apply masking for padded parts in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
forward(after_outs: Tensor, before_outs: Tensor, d_outs: Tensor, p_outs: Tensor, e_outs: Tensor, ys: Tensor, ds: Tensor, ps: Tensor, es: Tensor, ilens: Tensor, olens: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
Calculate forward propagation.
This method compares the model outputs with their targets and returns five loss values: L1 and SSIM losses on the acoustic features, and duration, pitch, and energy prediction losses. Which positions contribute to each loss is controlled by the use_masking and use_weighted_masking settings chosen at initialization.
- Parameters:
- after_outs (torch.Tensor) – Batch of outputs after postnets (B, T_feats, odim).
- before_outs (torch.Tensor) – Batch of outputs before postnets (B, T_feats, odim).
- d_outs (torch.LongTensor) – Batch of outputs of duration predictor (B, T_text).
- p_outs (torch.Tensor) – Batch of outputs of pitch predictor (B, T_text, 1).
- e_outs (torch.Tensor) – Batch of outputs of energy predictor (B, T_text, 1).
- ys (torch.Tensor) – Batch of target features (B, T_feats, odim).
- ds (torch.LongTensor) – Batch of durations (B, T_text).
- ps (torch.Tensor) – Batch of target token-averaged pitch (B, T_text, 1).
- es (torch.Tensor) – Batch of target token-averaged energy (B, T_text, 1).
- ilens (torch.LongTensor) – Batch of the lengths of each input (B,).
- olens (torch.LongTensor) – Batch of the lengths of each target (B,).
- Returns:
- L1 loss value.
- SSIM loss value.
- Duration predictor loss value.
- Pitch predictor loss value.
- Energy predictor loss value.
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
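The five returned terms are usually reduced to a single scalar before backpropagation. The sketch below uses a hypothetical helper with an optional per-term weight dictionary; the weighting scheme is an illustrative assumption, and an unweighted sum is only one reasonable default, not necessarily what a given ProDiff model uses. Plain floats are used here, but the same arithmetic works on scalar tensors.

```python
from typing import Dict, Optional, Sequence

# Order matches the documented return value of forward().
LOSS_NAMES = ("l1", "ssim", "duration", "pitch", "energy")


def total_loss(losses: Sequence[float],
               weights: Optional[Dict[str, float]] = None) -> float:
    """Combine the five returned terms into one scalar (hypothetical helper)."""
    if len(losses) != len(LOSS_NAMES):
        raise ValueError(
            f"expected {len(LOSS_NAMES)} loss terms, got {len(losses)}"
        )
    weights = weights or {}
    # Terms without an explicit weight default to 1.0, i.e. a plain sum.
    return sum(weights.get(name, 1.0) * value
               for name, value in zip(LOSS_NAMES, losses))
```

Passing a four-element tuple raises ValueError, which catches code written against the earlier four-tensor return annotation.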
####### Examples
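Example usage of the forward method, where loss_module is a ProDiffLoss instance (calling the module invokes forward and runs any registered hooks):
>>> l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss = loss_module(
... after_outs, before_outs, d_outs, p_outs, e_outs, ys, ds, ps, es,
... ilens, olens)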
NOTE
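Whether padded positions are excluded, and how the remaining positions are weighted, depends on the use_masking and use_weighted_masking values set at initialization. ilens marks the valid token-level positions (durations, pitch, energy) and olens the valid frame-level positions (acoustic features).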
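The difference between the two masking modes can be illustrated on a toy batch. The plain-Python sketch below builds a non-padding mask from lengths (ESPnet provides a tensor helper named make_non_pad_mask in espnet.nets.pytorch_backend.nets_utils; make_non_pad_mask_py here is a hypothetical stand-in) and then applies two reductions. The weighted variant assumes the per-utterance normalization used in ESPnet's FastSpeech2-style losses: each utterance is averaged over its own valid length before averaging over the batch.

```python
from typing import List


def make_non_pad_mask_py(lengths: List[int]) -> List[List[bool]]:
    """Hypothetical stand-in: True for the first lengths[b] positions of row b."""
    max_len = max(lengths)
    return [[t < n for t in range(max_len)] for n in lengths]


def masked_mean(errors: List[List[float]], lengths: List[int]) -> float:
    """use_masking-style reduction: mean over every valid element in the batch.

    Longer utterances contribute more elements, so they carry more weight.
    """
    mask = make_non_pad_mask_py(lengths)
    valid = [e for row, m_row in zip(errors, mask)
             for e, m in zip(row, m_row) if m]
    return sum(valid) / len(valid)


def weighted_masked_mean(errors: List[List[float]], lengths: List[int]) -> float:
    """Assumed use_weighted_masking-style reduction: per-utterance mean, then batch mean.

    Each utterance counts equally regardless of its length.
    """
    mask = make_non_pad_mask_py(lengths)
    per_utt = [sum(e for e, m in zip(row, m_row) if m) / n
               for row, m_row, n in zip(errors, mask, lengths)]
    return sum(per_utt) / len(per_utt)
```

For errors = [[1.0, 1.0, 1.0, 1.0], [3.0, 99.0, 99.0, 99.0]] with lengths [4, 1], the padding values (99.0) never enter either reduction: masked_mean returns 1.4 because the 4-frame utterance supplies four of the five valid elements, while weighted_masked_mean returns 2.0, the average of the per-utterance means 1.0 and 3.0.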