espnet2.tts.prodiff.loss.SSimLoss
class espnet2.tts.prodiff.loss.SSimLoss(bias: float = 6.0, window_size: int = 11, channels: int = 1, reduction: str = 'none')
Bases: Module
SSimLoss.
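This is an implementation of the structural similarity (SSIM) loss. SSIM compares two images in terms of luminance, contrast, and structure and is widely used in image processing. The loss is computed as 1 − SSIM, so identical inputs give a loss of 0. In ProDiff, each mel-spectrogram in the batch is treated as a single-channel image.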
This code is modified from https://github.com/Po-Hsun-Su/pytorch-ssim.
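For reference, the standard SSIM definition is shown below. Here μ, σ², and σ_xy are Gaussian-weighted local means, variances, and the covariance, computed over a window of size window_size, and b denotes bias, which is added to both inputs before they are compared. The constants C₁ = 0.01² and C₂ = 0.03² are the pytorch-ssim defaults. This sketch assumes, but does not confirm, that the same constants apply here.

```latex
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)},
\qquad
\mathcal{L}_{\mathrm{SSIM}} = 1 - \mathrm{SSIM}(x + b,\; y + b)
```

When x = y, the numerator and denominator coincide, so SSIM equals 1 and the loss is 0.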
bias
Value of the bias added to the outputs and target.
- Type: float
win_len
Size of the Gaussian window used in the SSIM calculation.
- Type: int
channels
Number of channels in the input tensors.
- Type: int
average
Whether the SSIM map is averaged into a single value. Set to True when reduction is “mean”.
- Type: bool
Parameters:
- bias (float, optional) – Value of the bias added to the outputs and target. Defaults to 6.0.
- window_size (int, optional) – Size of the Gaussian window used in the SSIM calculation. Defaults to 11.
- channels (int, optional) – Number of channels in the input. Defaults to 1.
- reduction (str, optional) – Type of reduction during the loss calculation. Either “none” or “mean”. Defaults to “none”.
Returns: Loss scalar value.
Return type: Tensor
Examples
>>> import torch
>>> from espnet2.tts.prodiff.loss import SSimLoss
>>> ssim_loss = SSimLoss()
>>> outputs = torch.randn(2, 100, 80)  # (B, T_feats, odim)
>>> target = torch.randn(2, 100, 80)
>>> loss = ssim_loss(outputs, target)
>>> print(loss)
- Raises: ValueError – If the input tensors do not have the same shape.
Initialization.
- Parameters:
- bias (float, optional) – Value of the bias. Defaults to 6.0.
- window_size (int, optional) – Window size. Defaults to 11.
- channels (int, optional) – Number of channels. Defaults to 1.
- reduction (str, optional) – Type of reduction during the loss calculation. Either “none” or “mean”. Defaults to “none”.
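The initializer builds a fixed 2-D Gaussian window of size window_size and maps reduction onto the average flag. The following pure-Python sketch mirrors that logic without PyTorch. It makes two assumptions. First, the standard deviation of 1.5 follows pytorch-ssim. Second, the helper names gaussian_1d, gaussian_2d, and reduction_to_average are hypothetical and not part of ESPnet.

```python
import math
from typing import List


def gaussian_1d(window_size: int, sigma: float = 1.5) -> List[float]:
    """Normalized 1-D Gaussian centred on index window_size // 2 (sigma=1.5 assumed)."""
    center = window_size // 2
    weights = [
        math.exp(-((i - center) ** 2) / (2.0 * sigma**2)) for i in range(window_size)
    ]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_2d(window_size: int, sigma: float = 1.5) -> List[List[float]]:
    """2-D window: outer product of the 1-D Gaussian with itself (entries sum to 1)."""
    g = gaussian_1d(window_size, sigma)
    return [[a * b for b in g] for a in g]


def reduction_to_average(reduction: str) -> bool:
    """Hypothetical helper: only "mean" turns averaging on; anything else leaves it off."""
    return reduction == "mean"


# Usage: an 11 x 11 window, matching the default window_size.
window = gaussian_2d(11)
print(len(window), len(window[0]))  # 11 11
print(reduction_to_average("mean"), reduction_to_average("none"))  # True False
```

Building the 2-D window as an outer product keeps it separable and normalized, so a constant input is passed through unchanged away from the borders.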
forward(outputs: Tensor, target: Tensor)
Calculate forward propagation.
This method computes the SSIM loss between a batch of model outputs and the corresponding targets. Both tensors are shifted by bias, and each (T_feats, odim) feature map is treated as a single-channel image. The per-position loss 1 − SSIM is then averaged over the positions where the target is non-zero, so zero-padded frames do not contribute.
- Parameters:
- outputs (Tensor) – Batch of output sequences generated by the model (B, T_feats, odim).
- target (Tensor) – Batch of target sequences (B, T_feats, odim).
- Returns: Loss scalar value.
- Return type: Tensor
Examples
>>> import torch
>>> from espnet2.tts.prodiff.loss import SSimLoss
>>> ssim_loss = SSimLoss()
>>> outputs = torch.rand(2, 100, 80)
>>> target = torch.rand(2, 100, 80)
>>> loss = ssim_loss(outputs, target)
NOTE
This method assumes that outputs and target have the same shape and are on the same device.
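SSimLoss.forward shifts both inputs by bias, computes 1 − SSIM at every position, and averages that value over the positions where the target is non-zero. A plain-Python sketch of the masking and averaging steps follows. It works on one (T_feats, odim) feature map stored as a list of rows and is not the ESPnet code. The batch dimension is dropped, and the SSIM computation is passed in as a callable (ssim_fn, a hypothetical parameter). Raising ValueError for an all-zero target is also a choice made only in this sketch.

```python
from typing import Callable, List

Matrix = List[List[float]]


def masked_ssim_loss(
    outputs: Matrix,
    target: Matrix,
    ssim_fn: Callable[[Matrix, Matrix], Matrix],
    bias: float = 6.0,
) -> float:
    """Sketch of the forward pass for a single (T_feats, odim) feature map."""
    # The mask comes from the unshifted target: zero entries (e.g. padding) are ignored.
    mask = [[1.0 if v != 0.0 else 0.0 for v in row] for row in target]
    mask_total = sum(map(sum, mask))
    if mask_total == 0.0:
        # Sketch-only choice: refuse to divide by zero.
        raise ValueError("target has no non-zero entries to average over")
    # Shift both inputs by the bias before comparing them.
    shifted_outputs = [[v + bias for v in row] for row in outputs]
    shifted_target = [[v + bias for v in row] for row in target]
    ssim_values = ssim_fn(shifted_outputs, shifted_target)
    # Per-position loss is 1 - SSIM, averaged over the unmasked positions only.
    weighted = sum(
        m * (1.0 - s)
        for m_row, s_row in zip(mask, ssim_values)
        for m, s in zip(m_row, s_row)
    )
    return weighted / mask_total


# Usage with a stand-in SSIM function that reports 0.5 everywhere.
def constant_half(a: Matrix, b: Matrix) -> Matrix:
    return [[0.5 for _ in row] for row in a]


print(masked_ssim_loss([[1.0, 2.0]], [[3.0, 0.0]], constant_half))  # 0.5
```

In the usage line, only the first position is unmasked. Its loss is 1 − 0.5 = 0.5, and that value is also the masked average.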
ssim(tensor1: Tensor, tensor2: Tensor)
Calculate the SSIM between two tensors.
Both inputs are filtered with the Gaussian window built at initialization. The resulting local means, variances, and covariance are then combined into an SSIM map. This code is modified from https://github.com/Po-Hsun-Su/pytorch-ssim.
- Parameters:
- tensor1 (Tensor) – Generated output (B, channels, H, W).
- tensor2 (Tensor) – Ground-truth output (B, channels, H, W).
- Returns: SSIM values (a scalar when reduction is “mean”).
- Return type: Tensor
Examples
>>> import torch
>>> from espnet2.tts.prodiff.loss import SSimLoss
>>> ssim_loss = SSimLoss(reduction="mean")
>>> x = torch.rand(1, 1, 64, 64)
>>> similarity = ssim_loss.ssim(x, x)  # identical inputs give an SSIM of 1
NOTE
This loss is particularly useful for tasks involving image quality assessment, such as image denoising and super-resolution.
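The core of ssim can be sketched in plain Python for one single-channel image stored as a list of rows. This sketch makes several assumptions:
- The function names filter2d and ssim_map are hypothetical.
- Zero-padded "same" filtering stands in for the PyTorch convolution.
- C1 = 0.01² and C2 = 0.03² are taken from pytorch-ssim.
- The window is passed in as an argument. For brevity, the usage line uses a uniform 3 × 3 window rather than the Gaussian window that SSimLoss builds.

```python
from typing import List

Matrix = List[List[float]]


def filter2d(x: Matrix, window: Matrix) -> Matrix:
    """Zero-padded 'same' 2-D cross-correlation with an odd-sized square window."""
    rows, cols, k = len(x), len(x[0]), len(window)
    pad = k // 2
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            acc = 0.0
            for u in range(k):
                ii = i + u - pad
                if not 0 <= ii < rows:
                    continue  # this window row falls in the zero padding
                for v in range(k):
                    jj = j + v - pad
                    if 0 <= jj < cols:
                        acc += window[u][v] * x[ii][jj]
            out[i][j] = acc
    return out


def ssim_map(
    x: Matrix, y: Matrix, window: Matrix, c1: float = 0.01**2, c2: float = 0.03**2
) -> Matrix:
    """Per-position SSIM between two equally sized single-channel images."""
    mu_x = filter2d(x, window)
    mu_y = filter2d(y, window)
    e_xx = filter2d([[a * a for a in row] for row in x], window)
    e_yy = filter2d([[b * b for b in row] for row in y], window)
    e_xy = filter2d([[a * b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)], window)
    result = []
    for i in range(len(x)):
        row = []
        for j in range(len(x[0])):
            mx, my = mu_x[i][j], mu_y[i][j]
            var_x = e_xx[i][j] - mx * mx  # local variance of x
            var_y = e_yy[i][j] - my * my  # local variance of y
            cov = e_xy[i][j] - mx * my  # local covariance of x and y
            num = (2 * mx * my + c1) * (2 * cov + c2)
            den = (mx * mx + my * my + c1) * (var_x + var_y + c2)
            row.append(num / den)
        result.append(row)
    return result


# Usage: a uniform 3 x 3 window stands in for the Gaussian window.
uniform = [[1.0 / 9.0] * 3 for _ in range(3)]
image = [[0.0, 1.0], [2.0, 3.0]]
print(ssim_map(image, image, uniform))  # [[1.0, 1.0], [1.0, 1.0]]
```

Each filtering pass in these nested loops costs O(H·W·k²). That is fine for checking small inputs, but it is far slower than a vectorized convolution.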