espnet2.enh.loss.criterions.time_domain.CISDRLoss
class espnet2.enh.loss.criterions.time_domain.CISDRLoss(filter_length=512, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
CI-SDR loss.
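This class implements the CI-SDR training criterion proposed in “Convolutive Transfer Function Invariant SDR Training Criteria for Multi-Channel Reverberant Speech Separation”. It computes the CI-SDR loss between reference and estimated signals. Before the SDR is measured, the reference is passed through a time-invariant filter that is chosen to best match the estimate, so slight linear distortions of the estimate are not penalized.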
Reference: Convolutive Transfer Function Invariant SDR Training Criteria for Multi-Channel Reverberant Speech Separation; C. Boeddeker et al., 2021; https://arxiv.org/abs/2011.15003
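For orientation, the criterion can be written roughly as follows. This is our paraphrase in the spirit of the BSS Eval SDR with a filtered reference, not a formula copied from the paper: s is the reference, ŝ the estimate, * denotes convolution, and h is an FIR filter with L = filter_length taps.

```latex
\mathbf{h}^\star = \arg\min_{\mathbf{h} \in \mathbb{R}^{L}} \left\lVert \hat{s} - \mathbf{h} * s \right\rVert_2^2,
\qquad
\text{CI-SDR}(s, \hat{s}) = 10 \log_{10} \frac{\left\lVert \mathbf{h}^\star * s \right\rVert_2^2}{\left\lVert \hat{s} - \mathbf{h}^\star * s \right\rVert_2^2}
```

Because the filter is re-fitted for every pair of signals, an estimate that differs from the reference only by a short linear filter (for example a small delay or mild coloration) still scores highly. With L = 1 the filter reduces to a single gain, and the expression becomes the scale-invariant SDR.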
- Parameters:
- filter_length (int) – Length of the time-invariant filter that allows slight distortion of the reference via filtering (default: 512).
- name (str, optional) – Name of the loss function. If None, defaults to “ci_sdr_loss”.
- only_for_test (bool, optional) – Whether the loss is used only for testing (default: False).
- is_noise_loss (bool, optional) – Whether this is a noise-related loss (default: False).
- is_dereverb_loss (bool, optional) – Whether this is a dereverberation-related loss (default: False).
- Returns: The computed loss of shape (Batch,).
- Return type: torch.Tensor
- Raises:
- ValueError – If both is_noise_loss and is_dereverb_loss are True.
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import CISDRLoss
>>> loss_fn = CISDRLoss(filter_length=256)
>>> reference = torch.randn(10, 16000) # batch of 10 signals, 16000 samples each
>>> estimated = torch.randn(10, 16000)
>>> loss = loss_fn(reference, estimated)
>>> print(loss.shape) # Output: torch.Size([10])
forward(ref: Tensor, inf: Tensor) → Tensor
Compute the CI-SDR loss between reference and estimated signals.
CI-SDR (Convolutive Transfer Function Invariant Signal-to-Distortion Ratio) is an SDR variant that tolerates a time-invariant filtering of the reference. SDR-based measures are widely used to assess the quality of separated or enhanced audio; here CI-SDR serves as a differentiable training objective.
- Parameters:
- ref (torch.Tensor) – The reference signal tensor with shape (Batch, samples).
- inf (torch.Tensor) – The estimated signal tensor with shape (Batch, samples).
- Returns: The computed CI-SDR loss tensor with shape (Batch,).
- Return type: torch.Tensor
- Raises: AssertionError – If the shapes of ref and inf do not match.
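To make the (Batch, samples) → (Batch,) contract concrete, here is a minimal NumPy sketch of the underlying score. It assumes that CI-SDR can be computed by fitting an FIR filter of length filter_length to the reference by least squares and then measuring the SDR of the filtered reference against the estimate. `ci_sdr_sketch` is a hypothetical helper, not ESPnet's implementation; it returns the score in dB (higher is better), whereas a training loss would presumably use its negation.

```python
import numpy as np


def ci_sdr_sketch(reference: np.ndarray, estimate: np.ndarray,
                  filter_length: int = 512) -> np.ndarray:
    """Least-squares sketch of CI-SDR in dB for (Batch, samples) inputs.

    Hypothetical helper for illustration only; not the ESPnet implementation.
    """
    # Mirrors the documented shape check of forward().
    assert reference.shape == estimate.shape, (reference.shape, estimate.shape)
    scores = []
    for s, s_hat in zip(reference, estimate):
        n = s.shape[0]
        # Column k holds the reference delayed by k samples (truncated to n),
        # so A @ h equals the first n samples of the convolution h * s.
        A = np.zeros((n, filter_length))
        for k in range(filter_length):
            A[k:, k] = s[: n - k]
        # Optimal FIR filter: h* = argmin_h ||s_hat - A h||^2.
        h, *_ = np.linalg.lstsq(A, s_hat, rcond=None)
        target = A @ h               # filtered reference h* * s
        distortion = s_hat - target  # the part the filter cannot explain
        # Clamp the denominator to avoid division by zero on exact fits.
        ratio = np.sum(target**2) / max(np.sum(distortion**2), 1e-20)
        scores.append(10.0 * np.log10(ratio))
    return np.array(scores)  # shape: (Batch,)


rng = np.random.default_rng(0)
ref = rng.standard_normal((2, 4000))
# Estimate = reference passed through a short FIR filter, so the fitted
# filter explains it almost exactly and the score is very large.
est = np.stack([np.convolve(r, [1.0, 0.5, -0.3])[:4000] for r in ref])
print(ci_sdr_sketch(ref, est, filter_length=16))
```

Building the full delay matrix costs O(samples × filter_length) memory per signal, which is fine for a sketch but not for 16000 samples with the default 512 taps in a training loop; a practical implementation would instead solve the normal equations from auto- and cross-correlations.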
####### Examples
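>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import CISDRLoss
>>> loss_fn = CISDRLoss()
>>> ref = torch.randn(8, 16000) # batch of 8 reference signals, 16000 samples each
>>> inf = torch.randn(8, 16000) # estimated signals
>>> loss = loss_fn(ref, inf) # invokes forward() via nn.Module.__call__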
>>> print(loss.shape) # Output: torch.Size([8])