espnet2.enh.loss.criterions.time_domain.SDRLoss
class espnet2.enh.loss.criterions.time_domain.SDRLoss(filter_length=512, use_cg_iter=None, clamp_db=None, zero_mean=True, load_diag=None, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
SDR loss.
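This class computes the Signal-to-Distortion Ratio (SDR) loss, which is commonly used in speech enhancement and separation tasks. SDR measures how closely an estimated signal matches a reference signal. The reference may first pass through a short distortion filter (of length filter_length) fitted to the estimate, and whatever the filter cannot explain counts as distortion. The loss is the negative SDR, so minimizing it reduces the distortion in the estimate.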
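The SDR described above can be written compactly as follows. The symbols are our own: s is the reference, ŝ is the estimate, and h is an FIR filter with L = filter_length taps. This is the standard BSS-Eval-style definition with a distortion filter, not a transcription of the ESPnet code.

```latex
\begin{aligned}
h^\star &= \arg\min_{h \in \mathbb{R}^{L}} \left\| \hat{s} - h * s \right\|_2^2, \\
s_{\text{target}} &= h^\star * s, \\
\mathrm{SDR} &= 10 \log_{10} \frac{\left\| s_{\text{target}} \right\|_2^2}{\left\| \hat{s} - s_{\text{target}} \right\|_2^2},
\qquad \mathcal{L}_{\text{SDR}} = -\mathrm{SDR}.
\end{aligned}
```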
filter_length
The length of the distortion filter allowed (default: 512).
- Type: int
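To show what filter_length controls, here is a minimal NumPy sketch of a single-channel SDR with a least-squares distortion filter. The function name `filtered_sdr_db` is hypothetical. The sketch assumes that an explicit convolution matrix plus `np.linalg.lstsq` gives the same filter as the loss's solver, which is far slower than a dedicated implementation. A one-sample echo in the estimate counts as distortion when only one tap is allowed, but a 4-tap filter absorbs it.

```python
import numpy as np


def filtered_sdr_db(ref: np.ndarray, est: np.ndarray, filter_length: int = 512,
                    zero_mean: bool = True) -> float:
    """Hypothetical helper: single-channel SDR (dB) with a least-squares distortion filter.

    Illustrative sketch only; it builds the full convolution matrix explicitly.
    """
    ref = np.asarray(ref, dtype=float)
    est = np.asarray(est, dtype=float)
    if zero_mean:
        ref = ref - ref.mean()
        est = est - est.mean()
    n = ref.shape[0]
    # Column k holds the reference delayed by k samples: A[t, k] = ref[t - k].
    A = np.zeros((n, filter_length))
    for k in range(filter_length):
        A[k:, k] = ref[: n - k]
    # Filter taps that make the filtered reference match the estimate best.
    h, *_ = np.linalg.lstsq(A, est, rcond=None)
    target = A @ h
    distortion = est - target
    return float(10 * np.log10(np.sum(target**2) / np.sum(distortion**2)))


rng = np.random.default_rng(0)
ref = rng.standard_normal(4000)
# Estimate = reference plus a one-sample echo plus a little noise.
echo = np.concatenate(([0.0], ref[:-1]))
est = ref + 0.5 * echo + 0.01 * rng.standard_normal(4000)
print(filtered_sdr_db(ref, est, filter_length=1))  # echo counted as distortion
print(filtered_sdr_db(ref, est, filter_length=4))  # echo absorbed by the filter
```

A longer filter forgives more linear distortion, such as short echoes or coloration, at the cost of a larger system to solve.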
use_cg_iter
If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when this loss is used to train neural separation networks.
- Type: int or None
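The trade-off behind use_cg_iter can be illustrated with a plain conjugate gradient (CG) solver run for a fixed number of iterations on the filter's normal equations, compared with a direct solve. This is a sketch under assumptions. The helper name `conjugate_gradient` is hypothetical, and the sketch uses dense matrices, whereas an efficient implementation would exploit the Toeplitz structure of the system. It only shows that a few CG steps can already land very close to the direct solution.

```python
import numpy as np


def conjugate_gradient(R: np.ndarray, b: np.ndarray, n_iter: int) -> np.ndarray:
    """Hypothetical helper: at most n_iter CG steps on R x = b (R symmetric positive definite)."""
    x = np.zeros_like(b)
    r = b - R @ x  # residual
    p = r.copy()  # search direction
    rs_old = r @ r
    for _ in range(n_iter):
        if rs_old == 0.0:  # exact solution already reached
            break
        Rp = R @ p
        alpha = rs_old / (p @ Rp)
        x = x + alpha * p
        r = r - alpha * Rp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
n, filter_length = 4000, 32
ref = rng.standard_normal(n)
est = np.convolve(ref, rng.standard_normal(8))[:n] + 0.1 * rng.standard_normal(n)

# Normal equations R h = b of the distortion-filter least-squares problem.
A = np.zeros((n, filter_length))
for k in range(filter_length):
    A[k:, k] = ref[: n - k]
R, b = A.T @ A, A.T @ est

h_direct = np.linalg.solve(R, b)  # direct solve (Gaussian elimination)
h_cg = conjugate_gradient(R, b, n_iter=10)  # 10 iterations, the value recommended for use_cg_iter
print(np.linalg.norm(h_cg - h_direct) / np.linalg.norm(h_direct))
```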
clamp_db
If provided, the output value (in dB) is clamped to the range [-clamp_db, clamp_db].
- Type: float or None
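A plausible reason for clamp_db is that a near-perfect estimate has almost zero distortion, so its SDR can become arbitrarily large. The sketch below shows the effect of clamping on such values. The helper `sdr_db` and its `eps` guard against division by zero are assumptions of this example and are not taken from the loss implementation.

```python
import numpy as np


def sdr_db(target_power, distortion_power, clamp_db=None, eps=1e-12):
    """Hypothetical helper: SDR in dB from signal powers, optionally clamped.

    eps is an assumption of this sketch to avoid division by zero.
    """
    sdr = 10 * np.log10((target_power + eps) / (distortion_power + eps))
    if clamp_db is not None:
        sdr = np.clip(sdr, -clamp_db, clamp_db)
    return sdr


# A perfect estimate (zero distortion) gives a huge SDR unless it is clamped.
print(sdr_db(1.0, 0.0))               # about 120 dB
print(sdr_db(1.0, 0.0, clamp_db=50))  # 50.0
print(sdr_db(1.0, 0.1, clamp_db=50))  # about 10 dB, unaffected by the clamp
```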
zero_mean
When set to True, the mean of each signal is subtracted before the SDR is computed.
- Type: bool
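The effect of zero_mean shows up most clearly when the estimate carries a constant (DC) offset. The sketch below uses a single-tap filter, which reduces the filtered SDR to a scale-invariant SDR, so the example stays short. The helper `one_tap_sdr_db` is hypothetical. Without mean removal the offset is counted as distortion. With mean removal the offset disappears before the SDR is computed.

```python
import numpy as np


def one_tap_sdr_db(ref: np.ndarray, est: np.ndarray, zero_mean: bool) -> float:
    """Hypothetical helper: SDR (dB) with a single-tap filter (a scale-invariant SDR)."""
    if zero_mean:
        ref = ref - ref.mean()
        est = est - est.mean()
    scale = np.dot(ref, est) / np.dot(ref, ref)  # best gain applied to ref
    target = scale * ref
    distortion = est - target
    return float(10 * np.log10(np.sum(target**2) / np.sum(distortion**2)))


rng = np.random.default_rng(0)
ref = rng.standard_normal(16000)
# A good estimate that carries an unwanted constant offset of 3.0.
est = ref + 3.0 + 0.01 * rng.standard_normal(16000)
print(one_tap_sdr_db(ref, est, zero_mean=False))  # offset counted as distortion
print(one_tap_sdr_db(ref, est, zero_mean=True))   # offset removed first
```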
load_diag
If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric in the case where some of the reference signals may sometimes be zero.
- Type: float or None
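The stabilizing role of load_diag can be sketched with an all-zero reference. In that case the system matrix of the filter's normal equations is all zeros, and a direct solve fails. Adding a small value to the diagonal makes the system solvable. The helper `solve_filter` is hypothetical, and it is a dense-matrix sketch rather than the loss's actual solver.

```python
import numpy as np


def solve_filter(ref, est, filter_length, load_diag=None):
    """Hypothetical helper: solve the normal equations R h = b for the distortion filter h."""
    n = ref.shape[0]
    A = np.zeros((n, filter_length))
    for k in range(filter_length):
        A[k:, k] = ref[: n - k]  # reference delayed by k samples
    R = A.T @ A  # (filter_length, filter_length) system matrix
    b = A.T @ est
    if load_diag is not None:
        R = R + load_diag * np.eye(filter_length)  # diagonal loading
    return np.linalg.solve(R, b)


rng = np.random.default_rng(0)
silent_ref = np.zeros(1000)  # a reference signal that happens to be zero
est = 0.01 * rng.standard_normal(1000)

try:
    solve_filter(silent_ref, est, filter_length=8)
except np.linalg.LinAlgError as err:
    print("without load_diag:", err)  # the all-zero system matrix is singular

h = solve_filter(silent_ref, est, filter_length=8, load_diag=1e-6)
print("with load_diag:", h)  # a well-defined (all-zero) filter
```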
Parameters:
- filter_length (int, optional) – Length of the distortion filter (default: 512).
- use_cg_iter (int or None, optional) – Number of iterations of the iterative (conjugate gradient) solver for the filter coefficients; if None, a direct solver is used (default: None).
- clamp_db (float or None, optional) – If provided, the output is clamped to [-clamp_db, clamp_db] (default: None).
- zero_mean (bool, optional) – Whether to subtract the mean of each signal before computation (default: True).
- load_diag (float or None, optional) – Small value added to the diagonal of the system matrices for stabilization (default: None).
- name (str, optional) – Name of the loss function (default: None).
- only_for_test (bool, optional) – Whether the loss is used only during testing (default: False).
- is_noise_loss (bool, optional) – Whether the loss is applied to the estimated noise signal (default: False).
- is_dereverb_loss (bool, optional) – Whether the loss is used for dereverberation (default: False).
Returns: The computed SDR loss (negative SDR).
Return type: torch.Tensor
Examples
>>> import torch
>>> sdr_loss = SDRLoss()
>>> reference = torch.randn(2, 16000)  # 2 signals, much longer than filter_length
>>> estimated = torch.randn(2, 16000)  # estimated signals, same shape as reference
>>> loss = sdr_loss(reference, estimated)
>>> print(loss)  # tensor of shape (2,): one negative-SDR value per signal
NOTE
Ensure that the reference and estimated tensors have the same shape when passing them to the forward method.
forward(ref: Tensor, est: Tensor) → Tensor
Calculate the SDR loss.
This method computes the SDR (Signal-to-Distortion Ratio) loss between the reference signal and the estimated signal, separately for each signal along the last (time) dimension. The loss is the negative SDR, so lower values indicate better estimates.
- Parameters:
- ref – Tensor of shape (…, n_samples) The reference signal.
- est – Tensor of shape (…, n_samples) The estimated signal.
- Returns: The SDR loss (negative SDR), a tensor of shape (…).
- Return type: torch.Tensor
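The shape contract above, inputs of shape (…, n_samples) and a loss of shape (…), can be sketched with a vectorized NumPy function. It reduces only over the last axis and rejects mismatched shapes. The function `negative_si_sdr` is hypothetical. For brevity it uses a single-tap filter (scale-invariant SDR) rather than the filter_length-tap distortion filter of SDRLoss, and its `eps` term is an assumption of this sketch.

```python
import numpy as np


def negative_si_sdr(ref: np.ndarray, est: np.ndarray, zero_mean: bool = True,
                    eps: float = 1e-8) -> np.ndarray:
    """Hypothetical helper: negative one-tap (scale-invariant) SDR over the last axis.

    Inputs have shape (..., n_samples); the result has shape (...).
    """
    if ref.shape != est.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {est.shape}")
    if zero_mean:
        ref = ref - ref.mean(axis=-1, keepdims=True)
        est = est - est.mean(axis=-1, keepdims=True)
    # Per-signal optimal gain, kept as shape (..., 1) for broadcasting.
    scale = np.sum(ref * est, axis=-1, keepdims=True) / (
        np.sum(ref**2, axis=-1, keepdims=True) + eps
    )
    target = scale * ref
    distortion = est - target
    sdr = 10 * np.log10(
        (np.sum(target**2, axis=-1) + eps) / (np.sum(distortion**2, axis=-1) + eps)
    )
    return -sdr


rng = np.random.default_rng(0)
ref = rng.standard_normal((2, 3, 1000))  # e.g. (batch, speakers, samples)
est = 0.5 * ref + 0.05 * rng.standard_normal((2, 3, 1000))
loss = negative_si_sdr(ref, est)
print(loss.shape)  # (2, 3)
```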
Examples
>>> import torch
>>> sdr_loss = SDRLoss()
>>> ref = torch.randn(2, 1000) # Example reference signals
>>> est = torch.randn(2, 1000) # Example estimated signals
>>> loss = sdr_loss(ref, est)
>>> print(loss.shape) # Output: torch.Size([2])