espnet2.enh.loss.criterions.tf_domain.FrequencyDomainDPCL
class espnet2.enh.loss.criterions.tf_domain.FrequencyDomainDPCL(compute_on_mask=False, mask_type='IBM', loss_type='dpcl', name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: FrequencyDomainLoss
FrequencyDomainDPCL implements the time-frequency Deep Clustering loss for speech enhancement and separation. The loss trains a network to emit one embedding per time-frequency bin so that bins dominated by the same source lie close together and bins dominated by different sources lie far apart.
This class inherits from the FrequencyDomainLoss base class and supports two variants of the loss, selected with loss_type: “dpcl” [1] and “mdc” [2].
compute_on_mask
Indicates whether the loss is computed on the mask.
- Type: bool
mask_type
The type of mask used for loss computation.
- Type: str
loss_type
Specifies the type of loss, can be “dpcl” or “mdc”.
- Type: str
Parameters:
- compute_on_mask (bool) – Flag to indicate if loss is computed on the mask.
- mask_type (str) – Type of mask to be used (default is “IBM”).
- loss_type (str) – Type of loss to be used (“dpcl” or “mdc”, default is “dpcl”).
- name (str) – Optional name for the loss instance.
- only_for_test (bool) – Indicates if the loss is only for testing (default is False).
- is_noise_loss (bool) – Indicates if the loss is for noise-related tasks (default is False).
- is_dereverb_loss (bool) – Indicates if the loss is for dereverberation tasks (default is False).
Returns: A tensor representing the computed loss for the batch.
Return type: torch.Tensor
Raises: ValueError – If an invalid loss type is provided.
Examples
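>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
>>> loss = FrequencyDomainDPCL(compute_on_mask=True, mask_type="IRM")
>>> ref = [torch.rand(2, 100, 256) for _ in range(3)]  # Simulated references: 3 sources, each (Batch, T, F)
>>> inf = torch.rand(2, 100 * 256, 3)  # Simulated predictions: (Batch, T * F, D)
>>> output = loss(ref, inf)
>>> print(output.shape)  # One loss value per batch item: torch.Size([2])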
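For the “dpcl” loss type, Deep Clustering [1] trains against a binary assignment of each time-frequency bin to its dominant source, which is the information an ideal binary mask (IBM) encodes. The following is a minimal sketch of that assignment in plain Python, using a flattened list of magnitudes per source. The function name `one_hot_dominant` is hypothetical, and breaking ties toward the lowest source index is an assumption rather than a documented behavior of ESPnet.

```python
def one_hot_dominant(mags):
    """Map per-source magnitudes to one-hot rows, one row per bin.

    mags: list of sources, each a flat list of non-negative magnitudes
          with one entry per time-frequency bin.
    """
    num_src = len(mags)
    num_bins = len(mags[0])
    rows = []
    for b in range(num_bins):
        # Index of the loudest source in this bin; max() keeps the first on ties.
        dominant = max(range(num_src), key=lambda s: mags[s][b])
        rows.append([1 if s == dominant else 0 for s in range(num_src)])
    return rows


mags = [
    [0.9, 0.1, 0.4],  # source 0
    [0.2, 0.7, 0.4],  # source 1
]
targets = one_hot_dominant(mags)  # one row per bin, one column per source
```

Each row of `targets` plays the role of the target vector for one time-frequency bin when the loss is evaluated.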
References
[1] Deep clustering: Discriminative embeddings for segmentation and separation; Hershey, J. R. et al., 2016; https://ieeexplore.ieee.org/document/7471631
[2] Manifold-Aware Deep Clustering: Maximizing Angles Between Embedding Vectors Based on Regular Simplex; Tanaka, K. et al., 2021; https://www.isca-speech.org/archive/interspeech_2021/tanaka21_interspeech.html
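For the “mdc” loss type, reference [2] replaces one-hot targets with the vertices of a regular simplex, so that the target vectors of any two sources are separated by the same, largest possible angle. Below is a minimal sketch of such vertices, assuming the usual construction (one-hot vectors centered on their mean and rescaled to unit length). The function names are hypothetical and the code is not taken from the ESPnet source.

```python
import math


def simplex_vertices(num_src):
    """Return num_src unit vectors with pairwise inner product -1 / (num_src - 1)."""
    if num_src < 2:
        raise ValueError("a regular simplex needs at least two sources")
    scale = math.sqrt(num_src / (num_src - 1))
    on = (num_src - 1) / num_src * scale  # entry at the vector's own source index
    off = -1.0 / num_src * scale          # entry at every other source index
    return [[on if i == j else off for j in range(num_src)] for i in range(num_src)]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


vertices = simplex_vertices(3)
self_product = dot(vertices[0], vertices[0])   # squared length of one vertex
cross_product = dot(vertices[0], vertices[1])  # cosine between two vertices
```

With three sources the vertices are unit vectors whose pairwise cosine is -1/2, that is, 120 degrees apart, whereas one-hot targets are only 90 degrees apart.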
Initialize internal Module state, shared by both nn.Module and ScriptModule.
property compute_on_mask : bool
forward(ref, inf) → Tensor
Time-frequency Deep Clustering loss.
Computes the Deep Clustering loss between the embeddings predicted by the network and the source assignments derived from the reference spectra. Both the “dpcl” and “mdc” loss types are supported.
- Parameters:
- ref – List of reference spectra, one tensor per source, each of shape (Batch, T, F).
- inf – Embeddings predicted by the network, of shape (Batch, T * F, D).
- Returns: A tensor of shape (Batch,) containing the computed loss for each item in the batch.
- Return type: torch.Tensor
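The value computed per utterance can be illustrated with a minimal sketch. It assumes the standard Deep Clustering objective from [1], the squared Frobenius norm of V Vᵀ − Y Yᵀ, where V holds one D-dimensional embedding per time-frequency bin and Y holds the target vector of each bin. The helper names are hypothetical, and the code is plain Python for a single utterance rather than the ESPnet implementation. The low-rank expansion avoids building the (T·F) × (T·F) affinity matrices.

```python
def gram(a, b):
    """Return a^T b for row-major matrices a (N x P) and b (N x Q)."""
    return [
        [sum(ra[p] * rb[q] for ra, rb in zip(a, b)) for q in range(len(b[0]))]
        for p in range(len(a[0]))
    ]


def frob2(m):
    """Squared Frobenius norm of a matrix stored as nested lists."""
    return sum(x * x for row in m for x in row)


def dpcl_loss(v, y):
    """Low-rank form: |V^T V|^2 + |Y^T Y|^2 - 2 |V^T Y|^2."""
    return frob2(gram(v, v)) + frob2(gram(y, y)) - 2.0 * frob2(gram(v, y))


def dpcl_loss_naive(v, y):
    """Direct form |V V^T - Y Y^T|^2, quadratic in the number of bins."""
    total = 0.0
    for vi, yi in zip(v, y):
        for vj, yj in zip(v, y):
            vv = sum(a * b for a, b in zip(vi, vj))
            yy = sum(a * b for a, b in zip(yi, yj))
            total += (vv - yy) ** 2
    return total


# Four time-frequency bins, 2-dimensional embeddings, two sources.
V = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [0.6, 0.8]]
Y = [[1, 0], [1, 0], [0, 1], [0, 1]]
low_rank = dpcl_loss(V, Y)
naive = dpcl_loss_naive(V, Y)  # agrees with low_rank up to rounding
```

The loss is zero when the embeddings reproduce the targets exactly, and it grows as bins of the same source drift apart or bins of different sources move together.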
property mask_type : str