espnet2.enh.loss.criterions.tf_domain.FrequencyDomainL1
class espnet2.enh.loss.criterions.tf_domain.FrequencyDomainL1(compute_on_mask=False, mask_type='IBM', name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: FrequencyDomainLoss
Computes the time-frequency L1 loss for audio signal enhancement.
The loss is the mean absolute difference between a reference and an estimate in the frequency domain. Depending on compute_on_mask, the two inputs are either time-frequency masks or spectra.
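For real-valued inputs, the per-sample loss can be sketched as the mean absolute difference over all time-frequency bins. This assumes the reduction is a plain mean over every non-batch axis, which is consistent with the (Batch,) output shape shown in the forward example; the symbols R (for ref) and \hat{S} (for inf) are introduced here for illustration only.

```latex
\mathcal{L}_b = \frac{1}{T F} \sum_{t=1}^{T} \sum_{f=1}^{F}
\left| R_{b,t,f} - \hat{S}_{b,t,f} \right|,
\qquad b = 1, \dots, B
```

For (Batch, T, C, F) inputs, the sum would also run over the C channels and the normalizer would become TCF.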
compute_on_mask
Whether the loss is computed on the mask (True) or on the spectrum (False).
- Type: bool
mask_type
The type of mask used when the loss is computed on the mask.
- Type: str
Parameters:
- compute_on_mask (bool, optional) – If True, computes the loss on the mask; otherwise, computes it on the spectrum. Defaults to False.
- mask_type (str, optional) – The type of mask to be used. Defaults to “IBM”.
- name (str, optional) – An optional name for the loss instance. Defaults to None.
- only_for_test (bool, optional) – If True, the loss is only used during testing. Defaults to False.
- is_noise_loss (bool, optional) – If True, the loss is computed on the noise estimate. Defaults to False.
- is_dereverb_loss (bool, optional) – If True, the loss is computed on the dereverberated estimate. Defaults to False.
Raises:
- ValueError – If both is_noise_loss and is_dereverb_loss are True.
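The constraint in the Raises entry amounts to a simple guard on the two flags. The sketch below is a hypothetical stand-alone helper (validate_loss_flags is not an ESPnet function, and its message text is illustrative), shown only to make the rule concrete: a noise-only or dereverberation-only loss is accepted, while the combination is rejected.

```python
def validate_loss_flags(is_noise_loss: bool = False, is_dereverb_loss: bool = False) -> None:
    """Hypothetical guard mirroring the documented rule: a loss instance
    may target noise or dereverberation, but not both at once."""
    if is_noise_loss and is_dereverb_loss:
        raise ValueError("is_noise_loss and is_dereverb_loss cannot both be True")


# Valid combinations pass silently.
validate_loss_flags()
validate_loss_flags(is_noise_loss=True)
validate_loss_flags(is_dereverb_loss=True)

# The conflicting combination is rejected.
try:
    validate_loss_flags(is_noise_loss=True, is_dereverb_loss=True)
except ValueError as err:
    print(f"rejected: {err}")
```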
Examples:
>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1
>>> loss_fn = FrequencyDomainL1(compute_on_mask=True, mask_type="IBM")
>>> ref = torch.rand(4, 100, 256)  # reference mask, (Batch, T, F)
>>> inf = torch.rand(4, 100, 256)  # estimated mask, (Batch, T, F)
>>> loss = loss_fn(ref, inf)
>>> print(loss.shape)  # one L1 value per sample: torch.Size([4])
NOTE
The input tensors ref and inf should have the same shape: either (Batch, T, F) or (Batch, T, C, F).
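The shape contract in the note can be illustrated with a small NumPy sketch of the real-valued case. tf_l1_loss is a hypothetical name, and the reduction (a mean over every non-batch axis) is an assumption consistent with the (Batch,) return shape documented for forward; this is not the ESPnet implementation.

```python
import numpy as np


def tf_l1_loss(ref: np.ndarray, inf: np.ndarray) -> np.ndarray:
    """Per-sample L1 loss for real-valued (B, T, F) or (B, T, C, F) arrays.

    Hypothetical sketch: the absolute difference is averaged over every
    axis except the batch axis, giving one value per sample.
    """
    if ref.shape != inf.shape:
        raise ValueError(f"shape mismatch: ref={ref.shape}, inf={inf.shape}")
    if ref.ndim not in (3, 4):
        raise ValueError(f"expected 3-D or 4-D input, got ref={ref.shape}")
    reduce_axes = tuple(range(1, ref.ndim))  # (1, 2) or (1, 2, 3)
    return np.abs(ref - inf).mean(axis=reduce_axes)


rng = np.random.default_rng(0)
ref = rng.standard_normal((4, 100, 256))
inf = rng.standard_normal((4, 100, 256))
print(tf_l1_loss(ref, inf).shape)  # (4,)

# Multi-channel input (B, T, C, F): the channel axis is averaged as well.
ref_mc = np.ones((2, 3, 2, 5))
inf_mc = np.zeros((2, 3, 2, 5))
print(tf_l1_loss(ref_mc, inf_mc))  # [1. 1.]
```

Reducing over all non-batch axes keeps one value per utterance, so a caller can still weight or mask individual samples before averaging to a scalar.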
property compute_on_mask : bool
forward(ref, inf) → Tensor
Computes the time-frequency L1 loss between the reference ref and the estimate inf.
Depending on how the loss is configured (see compute_on_mask), the inputs are either masks or spectra; the result contains one loss value per sample.
Parameters:
- ref (torch.Tensor) – Reference mask or spectrum, of shape (Batch, T, F) or (Batch, T, C, F).
- inf (torch.Tensor) – Estimated mask or spectrum, with the same shape as ref.
Returns: The per-sample L1 loss, of shape (Batch,).
Return type: torch.Tensor
Raises:
- ValueError – If the shapes of ref and inf do not match, or if the input is neither (Batch, T, F) nor (Batch, T, C, F).
Examples:
>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1
>>> loss_fn = FrequencyDomainL1()
>>> ref = torch.randn(10, 2, 4)  # (Batch, T, F)
>>> inf = torch.randn(10, 2, 4)  # (Batch, T, F)
>>> loss = loss_fn(ref, inf)
>>> print(loss.shape)  # torch.Size([10]), i.e. (Batch,)
NOTE
The input tensors ref and inf should have the same shape. The implementation handles both complex and real tensors.
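The note does not say how complex spectra are compared. One plausible formulation, assumed here rather than confirmed by this page, sums the L1 distances of the real parts, the imaginary parts, and the magnitudes, so that both the complex values and the magnitude spectrum are penalized. complex_tf_l1_loss is a hypothetical NumPy helper, not the ESPnet code.

```python
import numpy as np


def complex_tf_l1_loss(ref: np.ndarray, inf: np.ndarray) -> np.ndarray:
    """Per-sample L1 loss that also accepts complex spectra.

    Assumed formulation: for complex input, add the L1 distances of the
    real parts, imaginary parts, and magnitudes; real input falls back to
    a plain absolute difference. Averaged over all non-batch axes.
    """
    if ref.shape != inf.shape:
        raise ValueError(f"shape mismatch: ref={ref.shape}, inf={inf.shape}")
    if np.iscomplexobj(ref) or np.iscomplexobj(inf):
        diff = (
            np.abs(ref.real - inf.real)
            + np.abs(ref.imag - inf.imag)
            + np.abs(np.abs(ref) - np.abs(inf))
        )
    else:
        diff = np.abs(ref - inf)
    return diff.mean(axis=tuple(range(1, ref.ndim)))


# One sample, one frame, two bins: ref = [3+4j, 1], inf = [0, 1].
ref = np.array([[[3 + 4j, 1 + 0j]]])
inf = np.zeros((1, 1, 2), dtype=complex)
inf[0, 0, 1] = 1 + 0j
print(complex_tf_l1_loss(ref, inf))  # [6.]
```

In the example, the first bin contributes |3| + |4| + |5| = 12 and the second contributes 0, so the mean over the two bins is 6.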
property mask_type : str