espnet2.enh.loss.criterions.tf_domain.FrequencyDomainAbsCoherence
class espnet2.enh.loss.criterions.tf_domain.FrequencyDomainAbsCoherence(compute_on_mask=False, mask_type=None, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: FrequencyDomainLoss
Computes the absolute coherence loss in the frequency domain.
This loss is used to measure the coherence between the reference and inferred complex tensors in the frequency domain, which is useful in tasks such as source separation and enhancement.
Reference: Independent Vector Analysis with Deep Neural Network Source Priors; Li et al. 2020; https://arxiv.org/abs/2008.11273
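The description above can be made concrete with a short PyTorch sketch. It assumes the textbook definition of absolute coherence: the magnitude of the time-averaged cross-spectrum, normalized by the RMS magnitudes of both signals in each frequency bin, then averaged over frequencies and subtracted from 1. The helper name `abs_coherence_loss` and the `eps` constant are hypothetical, and ESPnet's exact normalization may differ; this is an illustration, not the library's source.

```python
import torch


def abs_coherence_loss(
    ref: torch.Tensor, inf: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Hypothetical sketch: 1 - mean absolute coherence for (Batch, T, F) inputs."""
    # Time-averaged cross-spectrum E_t[inf * conj(ref)], shape (Batch, F)
    cross = (inf * ref.conj()).mean(dim=1)
    # Per-bin RMS magnitudes sqrt(E_t[|x|^2]), shape (Batch, F)
    ref_rms = ref.abs().pow(2).mean(dim=1).sqrt()
    inf_rms = inf.abs().pow(2).mean(dim=1).sqrt()
    # Absolute coherence per frequency bin; lies in [0, 1] by Cauchy-Schwarz
    coh = cross.abs() / (ref_rms * inf_rms + eps)
    # Average over frequencies and turn "higher is better" into a loss
    return 1.0 - coh.mean(dim=1)


ref = torch.randn(8, 128, 64, dtype=torch.complex64)
inf = torch.randn(8, 128, 64, dtype=torch.complex64)
print(abs_coherence_loss(ref, inf).shape)  # torch.Size([8])
print(abs_coherence_loss(ref, ref))        # all values close to 0
```

Averaging over time before taking the magnitude is what makes the measure reward a consistent phase relationship between the two signals in each bin, rather than only matching magnitudes.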
compute_on_mask
Indicates whether the loss is computed on the mask. Always returns False for this class.
Type: bool
mask_type
The type of mask. Always returns None for this class.
Type: str
Parameters:
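- compute_on_mask (bool) – Has no effect for this class; the compute_on_mask property always returns False. Default is False.
- mask_type (str) – Has no effect for this class; the mask_type property always returns None. Default is None.
- name (str) – Optional name for the loss instance. Defaults to "Coherence_on_Spec" when None.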
- only_for_test (bool) – If True, indicates the loss is only for testing. Default is False.
- is_noise_loss (bool) – If True, indicates this loss is related to noise. Default is False.
- is_dereverb_loss (bool) – If True, indicates this loss is related to dereverberation. Default is False.
Returns: The computed loss for the batch, with shape (Batch,).
Return type: torch.Tensor
Raises:
- ValueError – If the shapes of ref and inf do not match, or if they are not complex tensors.
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainAbsCoherence
>>> loss_fn = FrequencyDomainAbsCoherence()
>>> ref = torch.randn(8, 128, 64, dtype=torch.complex64)
>>> inf = torch.randn(8, 128, 64, dtype=torch.complex64)
>>> loss = loss_fn(ref, inf)
>>> print(loss.shape)
torch.Size([8])
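One practical consequence of coherence as a training target is that it ignores a fixed complex gain per frequency bin. The sketch below uses a compact, hypothetical helper `abs_coherence_loss` that assumes the RMS-normalized coherence definition (it is not ESPnet's source): rescaling and phase-rotating every bin of the estimate leaves the coherence loss near 0, while a plain mean squared error penalizes the same change.

```python
import math

import torch


def abs_coherence_loss(ref, inf, eps=1e-8):
    # Hypothetical helper, assuming RMS-normalized coherence on (Batch, T, F)
    cross = (inf * ref.conj()).mean(dim=1)
    denom = ref.abs().pow(2).mean(dim=1).sqrt() * inf.abs().pow(2).mean(dim=1).sqrt()
    return 1.0 - (cross.abs() / (denom + eps)).mean(dim=1)


torch.manual_seed(0)
ref = torch.randn(4, 200, 32, dtype=torch.complex64)

# Hypothetical per-frequency complex gain with random magnitude and phase, shape (F,)
gain = torch.polar(torch.rand(32) + 0.5, torch.rand(32) * 2 * math.pi)
scaled = ref * gain  # broadcasts over (Batch, T, F)

coh_loss = abs_coherence_loss(ref, scaled)
mse = (scaled - ref).abs().pow(2).mean(dim=(1, 2))
print(coh_loss)  # values close to 0
print(mse)       # clearly non-zero
```

If the absolute level or phase of the estimate matters for a task, this invariance means the coherence term alone does not constrain it.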
property compute_on_mask : bool
forward(ref, inf) → Tensor
Computes the time-frequency absolute coherence loss.
This loss measures the absolute coherence between the reference and the inferred spectrograms, both of which must be complex tensors of the same shape. It is particularly useful for source separation tasks.
Reference: Independent Vector Analysis with Deep Neural Network Source Priors; Li et al. 2020; https://arxiv.org/abs/2008.11273
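For a three-dimensional input of shape (Batch, T, F), the loss can be written as below, assuming the RMS-normalized form of coherence. Here X denotes ref, X̂ denotes inf, b, t, f index batch, frame, and frequency, and the stabilizing constant ε is an assumption; ESPnet's exact normalization may differ.

```latex
\gamma_{b,f} = \frac{\left| \frac{1}{T} \sum_{t=1}^{T} \hat{X}_{b,t,f} \, X^{*}_{b,t,f} \right|}
{\sqrt{\frac{1}{T} \sum_{t=1}^{T} \lvert X_{b,t,f} \rvert^{2}} \, \sqrt{\frac{1}{T} \sum_{t=1}^{T} \lvert \hat{X}_{b,t,f} \rvert^{2}} + \epsilon},
\qquad
\mathcal{L}_{b} = 1 - \frac{1}{F} \sum_{f=1}^{F} \gamma_{b,f}
```

By the Cauchy-Schwarz inequality each γ_{b,f} lies in [0, 1] (up to ε), so the loss lies in [0, 1] and reaches 0 when the estimate matches the reference up to a per-frequency complex gain.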
compute_on_mask
Indicates if the computation is performed on the mask.
Type: bool
mask_type
The type of mask being used (not applicable for this class).
Type: str
Parameters:
- compute_on_mask (bool, optional) – Flag indicating if the loss is computed on the mask. Defaults to False.
- mask_type (str, optional) – The type of mask to be used. Defaults to None.
- name (str, optional) – Name of the loss. Defaults to "Coherence_on_Spec".
- only_for_test (bool, optional) – Indicates if the loss is only for testing. Defaults to False.
- is_noise_loss (bool, optional) – Indicates if the loss is related to noise. Defaults to False.
- is_dereverb_loss (bool, optional) – Indicates if the loss is related to dereverberation. Defaults to False.
Returns: The computed loss of shape (Batch,).
Return type: torch.Tensor
Raises:
- ValueError – If the input tensors do not have the correct dimensions, or if they are not complex tensors.
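The documented error conditions can be sketched as a standalone check. The function name `check_coherence_inputs` is hypothetical, and the accepted ranks (3-D (Batch, T, F) as in the examples, plus an assumed multichannel 4-D layout) are assumptions rather than ESPnet's confirmed behavior.

```python
import torch


def check_coherence_inputs(ref: torch.Tensor, inf: torch.Tensor) -> None:
    """Hypothetical validation mirroring the documented ValueError conditions."""
    if ref.shape != inf.shape:
        raise ValueError(f"shape mismatch: ref={tuple(ref.shape)}, inf={tuple(inf.shape)}")
    if not (torch.is_complex(ref) and torch.is_complex(inf)):
        raise ValueError("`ref` and `inf` must be complex tensors")
    # Assumed accepted ranks: (Batch, T, F) or a multichannel (Batch, T, C, F)
    if ref.dim() not in (3, 4):
        raise ValueError(f"expected a 3-D or 4-D tensor, got {ref.dim()}-D")


ok = torch.randn(2, 10, 16, dtype=torch.complex64)
check_coherence_inputs(ok, ok)  # passes silently

try:
    check_coherence_inputs(ok.abs(), ok.abs())  # real-valued magnitudes
except ValueError as err:
    print(err)  # `ref` and `inf` must be complex tensors
```

Checking complexity explicitly matters because magnitude-only spectrograms carry no phase, and coherence is defined on the complex cross-spectrum.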
####### Examples
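>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainAbsCoherence
>>> ref = torch.randn(32, 100, 256, dtype=torch.complex64)
>>> inf = torch.randn(32, 100, 256, dtype=torch.complex64)
>>> loss_fn = FrequencyDomainAbsCoherence()
>>> loss = loss_fn(ref, inf)
>>> print(loss.shape)
torch.Size([32])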
property mask_type : str