espnet2.enh.loss.criterions.time_domain.SISNRLoss
class espnet2.enh.loss.criterions.time_domain.SISNRLoss(clamp_db=None, zero_mean=True, eps=None, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
SI-SNR (or SI-SDR) loss.
A numerically more stable SI-SNR loss that uses the clamped implementation from fast_bss_eval.
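For orientation, the commonly used SI-SDR definition is written out below. The symbols s (reference, ref), ŝ (estimate, est), α and c (standing for clamp_db) are notation introduced here. Clamping the dB value directly is a simplifying assumption; fast_bss_eval may apply clamp_db differently internally. When zero_mean=True, s and ŝ are mean-subtracted first. With clamp_db=None the min/max is dropped.

```latex
\alpha = \frac{\langle \hat{s}, s \rangle}{\lVert s \rVert^2}, \qquad
\mathrm{SI\text{-}SDR}(s, \hat{s}) = 10 \log_{10} \frac{\lVert \alpha s \rVert^2}{\lVert \hat{s} - \alpha s \rVert^2}, \qquad
\mathcal{L}(s, \hat{s}) = -\min\!\bigl(\max(\mathrm{SI\text{-}SDR}(s, \hat{s}),\, -c),\, c\bigr)
```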
- Attributes:
- clamp_db – (float) The output value is clamped to [-clamp_db, clamp_db].
- zero_mean – (bool) When set to True, the mean of all signals is subtracted before the loss is computed.
- eps – (float) Deprecated. Kept for compatibility.
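To make the effect of zero_mean and clamp_db concrete, here is a minimal from-scratch sketch in plain PyTorch. It is not the ESPnet or fast_bss_eval code. The function name si_sdr_loss_sketch and the small constant tiny are hypothetical. Clamping is done with a plain torch.clamp on the dB value, which may differ from how fast_bss_eval applies clamp_db internally.

```python
import torch


def si_sdr_loss_sketch(ref, est, zero_mean=True, clamp_db=None, tiny=1e-8):
    """Hypothetical helper: negative SI-SDR per signal, samples on the last dim."""
    if zero_mean:
        # Remove the DC offset of every signal before projecting.
        ref = ref - ref.mean(dim=-1, keepdim=True)
        est = est - est.mean(dim=-1, keepdim=True)
    # Optimal scaling alpha = <est, ref> / ||ref||^2 (the "scale-invariant" part).
    alpha = (est * ref).sum(dim=-1, keepdim=True) / (
        ref.pow(2).sum(dim=-1, keepdim=True) + tiny
    )
    target = alpha * ref
    noise = est - target
    # SI-SDR in dB; `tiny` only guards against division by zero and log(0).
    ratio = target.pow(2).sum(dim=-1) / (noise.pow(2).sum(dim=-1) + tiny)
    si_sdr = 10 * torch.log10(ratio + tiny)
    if clamp_db is not None:
        # Simplified clamping of the dB value (assumption, see lead-in).
        si_sdr = torch.clamp(si_sdr, min=-clamp_db, max=clamp_db)
    return -si_sdr


ref = torch.randn(4, 16000)  # (Batch, samples)
est = ref + 0.1 * torch.randn(4, 16000)
print(si_sdr_loss_sketch(ref, est).shape)  # torch.Size([4])
```

In this sketch, scaling est by a nonzero constant leaves the value essentially unchanged. With clamp_db=10, a near-perfect estimate such as est = 2 * ref yields -10 rather than a very large negative number.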
- Parameters:
- clamp_db – (float, optional) Clamp the output value in [-clamp_db, clamp_db]. Default is None.
- zero_mean – (bool, optional) When set to True, the mean of all signals is subtracted before the loss is computed. Default is True.
- eps – (float, optional) Deprecated parameter for compatibility. Default is None.
- name – (str, optional) Name of the loss function. Default is “si_snr_loss”.
- only_for_test – (bool, optional) If True, the loss is only used during testing. Default is False.
- is_noise_loss – (bool, optional) If True, the loss is related to noise. Default is False.
- is_dereverb_loss – (bool, optional) If True, the loss is related to dereverberation. Default is False.
- Returns: loss – The SI-SDR loss (negative SI-SDR).
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import SISNRLoss
>>> loss_fn = SISNRLoss(clamp_db=10)
>>> reference = torch.randn(8, 16000)  # (Batch, samples)
>>> estimated = torch.randn(8, 16000)  # (Batch, samples)
>>> loss = loss_fn(reference, estimated)
>>> print(loss.shape)  # Should output: torch.Size([8])
NOTE
The eps parameter is deprecated and will be removed in future versions. It is recommended to use clamp_db instead.
forward(ref: Tensor, est: Tensor) → Tensor
Compute the SI-SNR (SI-SDR) loss between the reference ref and the estimate est, using the clamped implementation from fast_bss_eval with the clamp_db and zero_mean settings given at construction.
- Parameters:
- ref – (torch.Tensor) Reference signal, e.g. of shape (Batch, samples).
- est – (torch.Tensor) Estimated signal, with the same shape as ref.
- Returns: loss – The SI-SDR loss (negative SI-SDR), one value per signal (e.g. shape (Batch,)).
- Return type: torch.Tensor
####### Examples
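>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import SISNRLoss
>>> loss_function = SISNRLoss(clamp_db=10.0)
>>> reference_signal = torch.randn(1, 16000)  # (Batch, samples)
>>> estimated_signal = torch.randn(1, 16000)  # (Batch, samples)
>>> loss = loss_function(reference_signal, estimated_signal)
>>> print(loss)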