espnet2.gan_codec.shared.loss.freq_loss.MultiScaleMelSpectrogramLoss
class espnet2.gan_codec.shared.loss.freq_loss.MultiScaleMelSpectrogramLoss(fs: int = 22050, range_start: int = 6, range_end: int = 11, window: str = 'hann', n_mels: int = 80, fmin: int | None = 0, fmax: int | None = None, center: bool = True, normalized: bool = False, onesided: bool = True, log_base: float | None = 10.0, alphas: bool = True)
Bases: Module
Multi-Scale Mel Spectrogram Loss.
This class implements a multi-scale spectrogram loss for evaluating the quality of generated audio waveforms compared to ground truth. It utilizes the Mel spectrogram representation and can be configured with various parameters to adjust the loss calculation across multiple scales.
alphas
Coefficients for each scale.
- Type: List[float]
total
Total weight for normalizing the loss.
- Type: float
normalized
Indicates whether the loss is normalized.
- Type: bool
Parameters:
- fs (int) – Sampling rate. Default is 22050.
- range_start (int) – Power of 2 to use for the first scale. Default is 6.
- range_end (int) – Power of 2 to use for the last scale. Default is 11.
- window (str) – Window type for the FFT. Default is “hann”.
- n_mels (int) – Number of mel bins. Default is 80.
- fmin (Optional *[*int ]) – Minimum frequency for Mel. Default is 0.
- fmax (Optional *[*int ]) – Maximum frequency for Mel. Default is None.
- center (bool) – Whether to use a centered window. Default is True.
- normalized (bool) – Whether to use normalized spectrograms. Default is False.
- onesided (bool) – Whether to use one-sided FFT. Default is True.
- log_base (Optional *[*float ]) – Log base value. Default is 10.0.
- alphas (bool) – Whether to use alphas as coefficients. Default is True.
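The relationship between range_start, range_end and the per-scale settings can be sketched in plain Python. This is a minimal illustration, not the ESPnet implementation: the helper name `scale_settings`, the exclusive upper bound, the hop length of one quarter of the FFT size, and the square-root form of the alpha coefficient are all assumptions that this page does not confirm.

```python
import math


def scale_settings(range_start=6, range_end=11, use_alphas=True):
    """List hypothetical per-scale settings for a multi-scale mel loss.

    Assumptions (not confirmed by this page):
    - scales run over range(range_start, range_end), so range_end is exclusive;
    - the hop length is n_fft // 4;
    - alpha is sqrt(n_fft - 1) when use_alphas is True, else 1.0.
    """
    if range_end <= range_start:
        raise ValueError("range_end must be greater than range_start")
    settings = []
    for i in range(range_start, range_end):
        n_fft = 2 ** i  # FFT size is a power of two for each scale
        alpha = math.sqrt(n_fft - 1) if use_alphas else 1.0
        settings.append({"n_fft": n_fft, "hop_length": n_fft // 4, "alpha": alpha})
    return settings


for setting in scale_settings():
    print(setting)
```

Under these assumptions the default arguments give five scales, with FFT sizes from 64 to 1024. Check the module source before relying on any of these values.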
forward(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Calculates the Mel-spectrogram loss between generated and ground truth waveforms.
- Raises:
- AssertionError – If range_end is not greater than range_start, or if range_start ...
####### Examples
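>>> import torch
>>> from espnet2.gan_codec.shared.loss.freq_loss import MultiScaleMelSpectrogramLoss
>>> loss_fn = MultiScaleMelSpectrogramLoss()
>>> generated = torch.randn(1, 1, 1024)  # Example generated waveform, shape (B, 1, T)
>>> ground_truth = torch.randn(1, 1, 1024)  # Example ground truth, same shape
>>> loss = loss_fn(generated, ground_truth)
>>> print(loss)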
forward(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor
Calculate Mel-spectrogram loss between generated and ground-truth waveforms.
This method computes the multi-scale Mel-spectrogram loss by applying MelSpectrogramLoss to the generated and ground-truth waveforms at each scale. The per-scale results combine L1 and MSE losses, with the weighting set by the predefined alpha coefficients.
- Parameters:
- y_hat (torch.Tensor) – Generated waveform tensor of shape (B, 1, T), where B is the batch size and T is the number of time steps.
- y (torch.Tensor) – Ground-truth waveform tensor of shape (B, 1, T).
- Returns: A scalar tensor representing the calculated Mel-spectrogram loss value.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.loss.freq_loss import MultiScaleMelSpectrogramLoss
>>> loss_fn = MultiScaleMelSpectrogramLoss()
>>> generated_waveform = torch.randn(2, 1, 1024)  # Shape (B, 1, T)
>>> groundtruth_waveform = torch.randn(2, 1, 1024)  # Shape (B, 1, T)
>>> loss = loss_fn(generated_waveform, groundtruth_waveform)
>>> print(loss)
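The way forward combines L1 and MSE terms across scales can be illustrated with a small, dependency-free sketch. It is not the ESPnet code: the helper names, the use of flat Python lists in place of per-scale log-mel features, the choice to apply alpha to the MSE term only, and the normalization by the sum of `alpha + 1` are all assumptions made for illustration.

```python
def l1(a, b):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def multi_scale_loss(pred_scales, target_scales, alphas, normalized=False):
    """Sum l1 + alpha * mse over scales (hypothetical weighting scheme).

    pred_scales / target_scales: one flat feature list per scale, standing
    in for the mel spectrograms computed at that scale.
    """
    loss = 0.0
    for pred, target, alpha in zip(pred_scales, target_scales, alphas):
        loss += l1(pred, target) + alpha * mse(pred, target)
    if normalized:
        # Assumed normalizer: total weight of all L1 (1.0) and MSE (alpha) terms.
        loss /= sum(alpha + 1.0 for alpha in alphas)
    return loss


pred = [[0.0, 1.0], [2.0, 2.0]]
target = [[0.0, 3.0], [2.0, 4.0]]
print(multi_scale_loss(pred, target, alphas=[1.0, 2.0]))  # 8.0
```

With normalized=True the same inputs give 8.0 / 5.0 = 1.6, which keeps the loss magnitude comparable when the number of scales changes.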