espnet2.svs.naive_rnn.naive_rnn.NaiveRNNLoss
class espnet2.svs.naive_rnn.naive_rnn.NaiveRNNLoss(use_masking=True, use_weighted_masking=False)
Bases: Module
NaiveRNNLoss is a loss function module for Tacotron2 that computes the L1 and mean square error (MSE) losses between the model’s outputs and the target features. Padded parts can optionally be masked out of the loss calculation so that padding does not contribute to the loss values.
use_masking
Whether to apply masking for padded parts in loss calculation.
- Type: bool
use_weighted_masking
Whether to apply weighted masking in loss calculation.
- Type: bool
l1_criterion
L1 loss criterion.
- Type: torch.nn.L1Loss
mse_criterion
MSE loss criterion.
- Type: torch.nn.MSELoss
Parameters:
- use_masking (bool) – Flag to indicate if masking should be applied to the padded parts of the inputs.
- use_weighted_masking (bool) – Flag to indicate if weighted masking should be applied during loss calculation.
Returns: A tuple containing:
- L1 loss value (Tensor).
- Mean square error loss value (Tensor).
Return type: Tuple[Tensor, Tensor]
####### Examples
>>> loss_fn = NaiveRNNLoss(use_masking=True, use_weighted_masking=False)
>>> l1_loss, mse_loss = loss_fn(after_outs, before_outs, ys, olens)
- Raises: ValueError – If both use_masking and use_weighted_masking are True, or if they are both False.
Initialize Tacotron2 loss module.
- Parameters:
- use_masking (bool) – Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
forward(after_outs, before_outs, ys, olens)
Calculate forward propagation.
This method computes the L1 and mean square error (MSE) losses between the predicted outputs (after and before postnet) and the target features, applying optional masking for padded elements based on the specified parameters.
- Parameters:
- after_outs (Tensor) – Batch of outputs after postnets (B, Lmax, odim).
- before_outs (Tensor) – Batch of outputs before postnets (B, Lmax, odim).
- ys (Tensor) – Batch of padded target features (B, Lmax, odim).
- olens (LongTensor) – Batch of the lengths of each target (B,).
- Returns: A tuple of the L1 loss value (Tensor) and the mean square error loss value (Tensor).
- Return type: Tuple[Tensor, Tensor]
####### Examples
>>> import torch
>>> from espnet2.svs.naive_rnn.naive_rnn import NaiveRNNLoss
>>> loss_module = NaiveRNNLoss(use_masking=True)
>>> after_outs = torch.randn(4, 10, 80) # Example output
>>> before_outs = torch.randn(4, 10, 80) # Example output
>>> ys = torch.randn(4, 10, 80) # Example target
>>> olens = torch.tensor([10, 10, 8, 6]) # Example lengths
>>> l1_loss, mse_loss = loss_module(after_outs, before_outs, ys, olens)
>>> print(l1_loss, mse_loss)
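To make the effect of masking concrete, the following is a minimal plain-PyTorch sketch of a masked L1/MSE computation over the same tensor shapes. It is not the ESPnet implementation: the helper `non_pad_mask`, the function `masked_l1_mse`, and the choice to sum the before-postnet and after-postnet terms are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def non_pad_mask(olens: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: True for real frames, False for padded frames."""
    positions = torch.arange(max_len, device=olens.device).unsqueeze(0)  # (1, Lmax)
    return positions < olens.unsqueeze(1)  # (B, Lmax)


def masked_l1_mse(after_outs, before_outs, ys, olens):
    """Illustrative masked L1/MSE; not the ESPnet implementation."""
    # Add a trailing axis so the mask broadcasts over the feature dimension.
    mask = non_pad_mask(olens, ys.size(1)).unsqueeze(-1)  # (B, Lmax, 1)
    # masked_select keeps only non-padded elements and flattens them to 1-D.
    ys_sel = ys.masked_select(mask)
    after_sel = after_outs.masked_select(mask)
    before_sel = before_outs.masked_select(mask)
    # Assumption: the before- and after-postnet terms are summed.
    l1 = F.l1_loss(after_sel, ys_sel) + F.l1_loss(before_sel, ys_sel)
    mse = F.mse_loss(after_sel, ys_sel) + F.mse_loss(before_sel, ys_sel)
    return l1, mse


torch.manual_seed(0)
after_outs = torch.randn(4, 10, 80)
before_outs = torch.randn(4, 10, 80)
ys = torch.randn(4, 10, 80)
olens = torch.tensor([10, 10, 8, 6])
l1_loss, mse_loss = masked_l1_mse(after_outs, before_outs, ys, olens)

# Overwrite the padded target frames with extreme values and recompute.
ys_corrupt = ys.clone()
ys_corrupt[2, 8:] = 1e3
ys_corrupt[3, 6:] = -1e3
l1_same, mse_same = masked_l1_mse(after_outs, before_outs, ys_corrupt, olens)
```

In this sketch both results are scalar tensors, and `l1_same` and `mse_same` match `l1_loss` and `mse_loss` because the overwritten frames lie beyond the lengths given in `olens` and are never selected.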
NOTE
The masking process is only applied if use_masking is set to True. If use_weighted_masking is True, the losses are weighted based on the mask.
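The note above does not spell out how the weights are derived from the mask. One plausible scheme, shown here purely as an assumption, gives each non-padded element of utterance b the weight 1 / (olens[b] * B * odim), so that every utterance contributes equally regardless of its length. The function name `weighted_masked_l1` is hypothetical and the sketch covers a single L1 term only.

```python
import torch
import torch.nn.functional as F


def weighted_masked_l1(outs, ys, olens):
    """Illustrative length-normalized L1 loss; an assumed weighting scheme."""
    batch, max_len, odim = ys.shape
    positions = torch.arange(max_len, device=olens.device).unsqueeze(0)
    mask = (positions < olens.unsqueeze(1)).unsqueeze(-1)  # (B, Lmax, 1)
    # Per-frame weight 1 / olens[b] on real frames, 0 on padded frames.
    # Assumes every entry of olens is greater than zero.
    weights = mask.float() / mask.sum(dim=1, keepdim=True).float()
    # Normalize over batch and feature axes so all weights sum to 1.
    weights = weights / (batch * odim)
    # Keep the element-wise losses so they can be weighted individually.
    elementwise = F.l1_loss(outs, ys, reduction="none")  # (B, Lmax, odim)
    return (elementwise * weights).masked_select(mask).sum()


torch.manual_seed(0)
ys = torch.randn(4, 10, 80)
olens = torch.tensor([10, 10, 8, 6])
# A constant absolute error of 2 on every element.
loss_const = weighted_masked_l1(ys + 2.0, ys, olens)
```

Because the weights in this sketch sum to 1 over all non-padded elements, a constant absolute error of 2 on every element yields a loss of approximately 2, whatever the individual lengths in `olens` are.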
- Raises:
- AssertionError – If olens is not properly defined or if use_masking …