espnet2.enh.loss.criterions.tf_domain.FrequencyDomainCrossEntropy
class espnet2.enh.loss.criterions.tf_domain.FrequencyDomainCrossEntropy(compute_on_mask=False, mask_type=None, ignore_id=-100, name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: FrequencyDomainLoss
FrequencyDomainCrossEntropy computes a cross-entropy classification loss in the frequency domain for audio signal processing tasks.
The reference holds integer class labels and the estimate holds unnormalized logits over nclass classes. The loss measures how well the predicted class distribution matches the reference labels, which is useful in applications such as speech enhancement and separation.
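Written out for the 2-D layout, the per-utterance loss can be sketched as follows. This assumes the per-position values come from PyTorch's `torch.nn.CrossEntropyLoss(ignore_index=ignore_id, reduction="none")` and are then averaged over all T positions. The symbols are introduced for illustration only: $z_{b,t,k}$ stands for `inf[b, t, k]`, $y_{b,t}$ for `ref[b, t]`, and $K$ for nclass.

```latex
\ell_{b,t} =
\begin{cases}
  \log \sum_{k=0}^{K-1} \exp\left(z_{b,t,k}\right) - z_{b,t,y_{b,t}}, & y_{b,t} \neq \mathrm{ignore\_id} \\
  0, & y_{b,t} = \mathrm{ignore\_id}
\end{cases}
\qquad
\mathcal{L}_b = \frac{1}{T} \sum_{t=1}^{T} \ell_{b,t}
```

Under this assumption, ignored positions add nothing to the sum but still count in the 1/T normalization. For the 3-D layout, the average would run over both T and C.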
compute_on_mask
Indicates whether the loss is computed on the mask.
- Type: bool
mask_type
Type of the mask being used for computation.
- Type: str
ignore_id
ID to ignore in the cross-entropy calculation.
- Type: int
Parameters:
- compute_on_mask (bool , optional) – If True, the loss is computed on the mask. Defaults to False.
- mask_type (str , optional) – Type of mask. Defaults to None.
- ignore_id (int , optional) – ID to ignore in the loss computation. Defaults to -100.
- name (str , optional) – Name of the loss instance. Defaults to None.
- only_for_test (bool , optional) – If True, the loss is only for testing. Defaults to False.
- is_noise_loss (bool , optional) – If True, the loss is computed on the noise estimate. Defaults to False.
- is_dereverb_loss (bool , optional) – If True, the loss is computed on the dereverberated signal. Defaults to False.
Returns: Loss for each batch element, shape (Batch,).
Return type: torch.Tensor
####### Examples
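>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainCrossEntropy
>>> loss_fn = FrequencyDomainCrossEntropy()
>>> ref = torch.tensor([[0, 1, 2], [1, 2, -100]]) # Reference labels (Batch, T); -100 is ignored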
>>> inf = torch.rand(2, 3, 4) # Predicted logits (Batch, T, nclass)
>>> loss = loss_fn(ref, inf)
>>> print(loss)
Raises: ValueError – If ref is neither 2-D (Batch, T) nor 3-D (Batch, T, C). Mismatched batch or time dimensions between ref and inf are caught by an assertion.
NOTE
- The input ref should be of shape (Batch, T) or (Batch, T, C), where T is the time dimension.
- The input inf should be of shape (Batch, T, nclass) or (Batch, T, C, nclass), where nclass is the number of classes.
- The ignore_id can be used to exclude certain labels from the loss computation.
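The framework-free sketch below follows the shapes in the NOTE above. It is a reference computation, not ESPnet's code, and it makes three assumptions:

- Each position's cross-entropy is the log-sum-exp of its logits minus the target logit.
- Positions labeled ignore_id contribute zero.
- Results are averaged over T (or over T and C).

The helper names `frame_ce` and `reference_ce_loss` are hypothetical.

```python
import math
from typing import List, Sequence


def frame_ce(logits: Sequence[float], label: int) -> float:
    """Cross-entropy of one position: log-sum-exp(logits) - logits[label]."""
    m = max(logits)  # subtract the max before exp for numerical stability
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]


def reference_ce_loss(ref, inf, ignore_id: int = -100) -> List[float]:
    """Hypothetical reference: per-utterance loss for nested-list inputs.

    ref: (Batch, T) or (Batch, T, C) integer labels
    inf: (Batch, T, nclass) or (Batch, T, C, nclass) logits
    Returns one loss per batch element (assumed mean over T, or over T and C).
    """
    losses = []
    for ref_b, inf_b in zip(ref, inf):
        per_position = []
        for ref_t, inf_t in zip(ref_b, inf_b):
            if isinstance(ref_t, int):
                # 2-D layout: one label and one logit vector per frame
                pairs = [(ref_t, inf_t)]
            else:
                # 3-D layout: one label and one logit vector per C entry
                pairs = list(zip(ref_t, inf_t))
            for label, logits in pairs:
                # ignored positions add 0 but still count in the mean
                per_position.append(
                    0.0 if label == ignore_id else frame_ce(logits, label)
                )
        losses.append(sum(per_position) / len(per_position))
    return losses
```

Consider all-zero logits over 4 classes. Every non-ignored position then costs log 4. Under these assumptions, the labels from the class example, [[0, 1, 2], [1, 2, -100]], give [log 4, 2 * log 4 / 3].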
property compute_on_mask : bool
forward(ref, inf) → Tensor
Computes the time-frequency cross-entropy loss.
ref contains an integer class label for every time (and C) position, and inf contains the corresponding unnormalized logits over nclass classes. The per-position cross-entropy is averaged over all non-batch dimensions, giving one loss value per batch element.
Parameters:
- ref (torch.Tensor) – Reference integer class labels of shape (Batch, T) or (Batch, T, C). Positions labeled ignore_id are ignored.
- inf (torch.Tensor) – Estimated logits of shape (Batch, T, nclass) or (Batch, T, C, nclass).
Returns: The loss for each batch element, shape (Batch,).
Return type: torch.Tensor
Raises: ValueError – If ref is neither 2-D (Batch, T) nor 3-D (Batch, T, C). Mismatched batch or time dimensions between ref and inf are caught by an assertion.
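The sketch below is a hypothetical validator for the ref/inf layouts given in the class NOTE. `check_shapes` is illustrative and is not an ESPnet function. It may be stricter than the class itself: it requires every leading dimension of inf to match ref, and it raises ValueError for every violation.

```python
from typing import Sequence, Tuple


def check_shapes(ref_shape: Sequence[int], inf_shape: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical validator for the documented ref/inf layouts.

    Accepts ref (Batch, T) with inf (Batch, T, nclass), or
    ref (Batch, T, C) with inf (Batch, T, C, nclass).
    Returns (batch, nclass); raises ValueError otherwise.
    """
    ref_shape, inf_shape = tuple(ref_shape), tuple(inf_shape)
    if len(ref_shape) not in (2, 3):
        raise ValueError(f"ref must be 2-D or 3-D, got shape {ref_shape}")
    if len(inf_shape) != len(ref_shape) + 1:
        raise ValueError(
            f"inf must have one more dim than ref, got {inf_shape} vs {ref_shape}"
        )
    if inf_shape[:-1] != ref_shape:
        # every leading dimension (Batch, T[, C]) must agree
        raise ValueError(
            f"leading dims of inf {inf_shape[:-1]} must equal ref {ref_shape}"
        )
    return ref_shape[0], inf_shape[-1]
```

With PyTorch tensors, `check_shapes(ref.shape, inf.shape)` works because `torch.Size` is a tuple subclass.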
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainCrossEntropy
>>> loss_fn = FrequencyDomainCrossEntropy()
>>> ref = torch.tensor([[1, 0], [0, 1]]) # Class labels (Batch=2, T=2)
>>> inf = torch.tensor([[[0.1, 0.9], [0.8, 0.2]],
...                     [[0.7, 0.3], [0.2, 0.8]]]) # Logits (Batch=2, T=2, nclass=2)
>>> loss = loss_fn(ref, inf) # Per-utterance loss
>>> print(loss) # One value per batch element
NOTE
ref and inf must have compatible shapes: inf has the same leading dimensions as ref plus a trailing nclass dimension (see Parameters).
property mask_type : str