espnet2.enh.loss.criterions.time_domain.TimeDomainL1
class espnet2.enh.loss.criterions.time_domain.TimeDomainL1(name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
Time-domain L1 loss.
This loss function computes the L1 loss (Mean Absolute Error) between the reference and estimated signals in the time domain. It is often used in tasks such as speech enhancement where preserving the structure of the waveform is essential.
- Parameters:
- name (str, optional) – The name of the loss function. Defaults to “TD_L1_loss”.
- only_for_test (bool, optional) – If True, the loss is only used for testing. Defaults to False.
- is_noise_loss (bool, optional) – If True, indicates that this loss is specifically for noise-related tasks. Defaults to False.
- is_dereverb_loss (bool, optional) – If True, indicates that this loss is specifically for dereverberation tasks. Defaults to False.
- Returns: The computed L1 loss with shape (Batch,), returned when the instance is called on a reference and an estimated signal.
- Return type: torch.Tensor
- Raises: ValueError – If the shapes of the reference and estimated signals do not match.
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import TimeDomainL1
>>> loss_fn = TimeDomainL1()
>>> ref = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> inf = torch.tensor([[1.0, 2.5, 3.0], [3.5, 5.0, 6.0]])
>>> loss = loss_fn(ref, inf)
>>> print(loss)
tensor([0.1667, 0.1667])
forward(ref, inf) → Tensor
Time-domain L1 loss forward.
This method computes the L1 loss between the reference and estimated signals in the time domain. The L1 loss is defined as the mean absolute difference between the two signals.
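Written out, the definition above amounts to the following per-sample quantity. This is a sketch under the assumption that the mean runs over every non-batch dimension, which is consistent with the documented (Batch,) output shape; the symbols b, t, c, T and C are introduced here for illustration only.

```latex
% 2D input, shape (Batch, T)
\mathcal{L}_b = \frac{1}{T} \sum_{t=1}^{T} \left| \mathrm{ref}_{b,t} - \mathrm{inf}_{b,t} \right|

% 3D input, shape (Batch, T, C)
\mathcal{L}_b = \frac{1}{T\,C} \sum_{t=1}^{T} \sum_{c=1}^{C} \left| \mathrm{ref}_{b,t,c} - \mathrm{inf}_{b,t,c} \right|
```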
- Parameters:
- ref (torch.Tensor) – The reference signal, of shape (Batch, T) or (Batch, T, C).
- inf (torch.Tensor) – The estimated signal, of shape (Batch, T) or (Batch, T, C). It must have the same shape as ref.
- Returns: The computed L1 loss for each sample in the batch, with shape (Batch,).
- Return type: torch.Tensor
- Raises: ValueError – If the shapes of ref and inf do not match.
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import TimeDomainL1
>>> loss_fn = TimeDomainL1()
>>> ref = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> inf = torch.tensor([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])
>>> loss = loss_fn(ref, inf)
>>> print(loss)
tensor([0.5000, 0.5000])
NOTE
The method ensures that the input tensors are of compatible shapes before proceeding with the loss computation.
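The shape check and the reduction described above can be approximated in plain PyTorch, which may be handy for sanity-checking values without installing ESPnet. This is a minimal sketch, not the ESPnet implementation: `l1_per_sample` is a hypothetical helper name, and the error behaviour simply mirrors what this page documents.

```python
import torch


def l1_per_sample(ref: torch.Tensor, inf: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean absolute error for each item in the batch."""
    # Mirror the documented behaviour: mismatched shapes are rejected.
    if ref.shape != inf.shape:
        raise ValueError(
            f"Shape mismatch: ref={tuple(ref.shape)}, inf={tuple(inf.shape)}"
        )
    # Only (Batch, T) and (Batch, T, C) inputs are documented.
    if ref.dim() not in (2, 3):
        raise ValueError(f"Expected a 2D or 3D tensor, got {ref.dim()}D")
    # Average |ref - inf| over every dimension except the batch dimension.
    reduce_dims = tuple(range(1, ref.dim()))
    return (ref - inf).abs().mean(dim=reduce_dims)


# 2D case: (Batch, T)
ref = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
inf = torch.tensor([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])
loss_2d = l1_per_sample(ref, inf)  # shape (2,)

# 3D case: (Batch, T, C), built here by duplicating the signal on two channels
ref_3d = torch.stack([ref, ref], dim=-1)
inf_3d = torch.stack([inf, inf], dim=-1)
loss_3d = l1_per_sample(ref_3d, inf_3d)  # shape (2,)
```

Because every element differs by 0.5 in this example, both calls yield 0.5 per sample, matching the `TimeDomainL1` example output shown above for the 2D case.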