espnet2.enh.loss.criterions.time_domain.SNRLoss
class espnet2.enh.loss.criterions.time_domain.SNRLoss(eps=1.1920928955078125e-07, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
SNR (Signal-to-Noise Ratio) loss for time-domain enhancement.
This loss returns the negative signal-to-noise ratio (SNR), in decibels, between the reference signal and the estimated signal. A higher SNR means the estimate is closer to the reference, so minimizing the loss maximizes the SNR.
- Parameters:
- eps (float) – A small constant that keeps the computation numerically stable when a signal norm is zero (default: float32 machine epsilon, about 1.19e-07).
- name (str, optional) – Name of the loss function (default: “snr_loss”).
- only_for_test (bool, optional) – If True, this loss is only used during testing (default: False).
- is_noise_loss (bool, optional) – If True, this loss is used for noise-related tasks (default: False).
- is_dereverb_loss (bool, optional) – If True, this loss is used for dereverberation tasks (default: False).
- Returns: The computed SNR loss, with shape (Batch,).
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import SNRLoss
>>> snr_loss = SNRLoss()
>>> reference = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
>>> estimate = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
>>> loss = snr_loss(reference, estimate)
>>> print(loss)
tensor([-4.4716, -8.4510])
- Raises: AssertionError – If the shapes of the reference and estimated tensors do not match.
forward(ref: Tensor, inf: Tensor) → Tensor
Signal-to-Noise Ratio (SNR) Loss.
This method returns the negated SNR, in decibels, between the reference ref and the estimate inf, treating the residual inf - ref as noise. The SNR is the ratio of signal power to noise power expressed in decibels, and it is a standard metric for the quality of audio signals, especially in speech enhancement tasks.
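To make the definition concrete, here is a minimal dependency-free sketch of a negated SNR in decibels. It is not ESPnet's implementation: the helper names `negative_snr_db` and `batch_negative_snr_db` are hypothetical, the noise is assumed to be the residual between estimate and reference, and `eps` is applied as a floor on both norms, which is one common choice and may differ from how the library handles it.

```python
import math

EPS = 1.1920928955078125e-07  # float32 machine epsilon, the documented default


def negative_snr_db(ref, est, eps=EPS):
    """Return -SNR in dB for one pair of equal-length sample sequences."""
    if len(ref) != len(est):
        raise ValueError("ref and est must have the same length")
    # Treat the residual (estimate minus reference) as the noise.
    noise = [e - r for r, e in zip(ref, est)]
    # Assumption: eps acts as a floor on both L2 norms so log10 stays finite.
    ref_norm = max(math.sqrt(sum(r * r for r in ref)), eps)
    noise_norm = max(math.sqrt(sum(n * n for n in noise)), eps)
    snr_db = 20.0 * (math.log10(ref_norm) - math.log10(noise_norm))
    return -snr_db


def batch_negative_snr_db(refs, ests, eps=EPS):
    """Apply negative_snr_db to each (reference, estimate) pair in a batch."""
    return [negative_snr_db(r, e, eps) for r, e in zip(refs, ests)]


losses = batch_negative_snr_db(
    [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
)
print([round(x, 4) for x in losses])  # [-4.4716, -8.451]
```

The second estimate has the smaller residual, so its SNR is higher and its loss is lower. One value is returned per batch item, matching the documented (Batch,) output shape.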
eps
A small value to prevent division by zero.
Type: float
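The following sketch illustrates why a small constant such as eps is needed. It uses a hypothetical helper, `log10_norm`, and assumes eps is applied as a lower bound on the norm; the library may guard the computation differently. In pure Python `math.log10(0.0)` raises an error, while tensor libraries typically return an infinity for the same operation; either way a perfect estimate would break the loss without a guard.

```python
import math

EPS = 1.1920928955078125e-07  # float32 machine epsilon, the documented default


def log10_norm(x, eps=None):
    """log10 of the L2 norm of x, optionally floored at eps (hypothetical helper)."""
    norm = math.sqrt(sum(v * v for v in x))
    if eps is not None:
        norm = max(norm, eps)  # assumption: eps is used as a floor on the norm
    return math.log10(norm)


ref = [1.0, 0.0]
noise = [0.0, 0.0]  # a perfect estimate leaves no residual

try:
    log10_norm(noise)  # no floor: log10(0.0) is undefined
    unguarded_ok = True
except ValueError:
    unguarded_ok = False

# With the floor, the SNR is large but finite.
snr_db = 20.0 * (log10_norm(ref, EPS) - log10_norm(noise, EPS))
print(unguarded_ok, math.isfinite(snr_db))  # False True
```

Under this assumption the SNR of a perfect estimate saturates at roughly 138 dB for the default eps instead of diverging.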
Parameters:
- eps (float) – A small value to avoid numerical instability (default: machine epsilon).
- name (str, optional) – Name of the loss function (default: “snr_loss”).
- only_for_test (bool, optional) – Flag indicating if the loss is only for testing (default: False).
- is_noise_loss (bool, optional) – Flag indicating if this loss is for noise estimation (default: False).
- is_dereverb_loss (bool, optional) – Flag indicating if this loss is for dereverberation (default: False).
Raises: ValueError – If both is_noise_loss and is_dereverb_loss are True.
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import SNRLoss
>>> snr_loss = SNRLoss()
>>> reference = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> estimated = torch.tensor([[1.0, 2.5], [2.5, 4.5]])
>>> loss = snr_loss(reference, estimated)
>>> print(loss)
tensor([-13.0103, -16.9897])
Notes
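The returned value is the negated SNR, so minimizing the loss during training maximizes the SNR. The loss is negative whenever the SNR is above 0 dB and positive when the residual noise is stronger than the reference signal.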
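The sign behaviour can be checked with a small standalone sketch. The helper `neg_snr_db` is hypothetical and assumes the noise is the residual between estimate and reference; it omits any eps guard, so it only works for a non-zero residual.

```python
import math


def neg_snr_db(ref, est):
    """Negated SNR in dB; hypothetical helper, assumes a non-zero residual."""
    signal_power = sum(r * r for r in ref)
    noise_power = sum((e - r) ** 2 for r, e in zip(ref, est))
    return -10.0 * math.log10(signal_power / noise_power)


ref = [1.0, 0.0]
close_estimate = neg_snr_db(ref, [1.0, 0.1])  # residual much smaller than the reference
zero_estimate = neg_snr_db(ref, [0.0, 0.0])   # residual as large as the reference
bad_estimate = neg_snr_db(ref, [-1.0, 0.0])   # residual larger than the reference

print(close_estimate < 0, abs(zero_estimate) < 1e-9, bad_estimate > 0)
```

A close estimate gives a loss near -20 dB, an all-zero estimate gives a loss of about 0 dB, and a sign-flipped estimate gives a positive loss of about 6 dB.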