espnet2.enh.loss.criterions.time_domain.TimeDomainMSE
class espnet2.enh.loss.criterions.time_domain.TimeDomainMSE(name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
Time-domain Mean Squared Error (MSE) loss.
This loss computes the mean squared error between the reference signal and the estimated signal in the time domain. It accepts either 2D inputs of shape (Batch, T) or 3D inputs of shape (Batch, T, C) and returns one loss value per batch item.
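The per-item reduction described above can be sketched in plain PyTorch. The function `td_mse_sketch` below is a hypothetical illustration written for this page, not ESPnet's implementation. It assumes the mean is taken over every non-batch axis, which is consistent with the documented (Batch,) output shape, and it mirrors the documented ValueError on mismatched shapes.

```python
import torch


def td_mse_sketch(ref: torch.Tensor, inf: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-item time-domain MSE (illustration only, not ESPnet code).

    ref, inf: tensors of shape (Batch, T) or (Batch, T, C).
    Returns a tensor of shape (Batch,).
    """
    if ref.shape != inf.shape:
        raise ValueError(
            f"Shape mismatch: ref={tuple(ref.shape)}, inf={tuple(inf.shape)}"
        )
    # Element-wise squared error between reference and estimate.
    sq_err = (ref - inf).pow(2)
    if sq_err.dim() == 2:
        # (Batch, T): average over the time axis only.
        return sq_err.mean(dim=1)
    if sq_err.dim() == 3:
        # (Batch, T, C): average over both the time and channel axes.
        return sq_err.mean(dim=(1, 2))
    raise ValueError(f"Expected a 2D or 3D input, got {sq_err.dim()}D")
```

For example, `td_mse_sketch(torch.zeros(2, 4), torch.ones(2, 4))` returns `tensor([1., 1.])`.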
- Parameters:
- name (str, optional) – The name of the loss function. Default is “TD_MSE_loss”.
- only_for_test (bool, optional) – Whether the loss is used only for testing. Default is False.
- is_noise_loss (bool, optional) – Whether this loss is applied to the noise signal. Default is False.
- is_dereverb_loss (bool, optional) – Whether this loss is used for dereverberation. Default is False.
- Raises: ValueError – If the shapes of the reference and estimated signals do not match.
- Returns: The computed loss, shape (Batch,).
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import TimeDomainMSE
>>> loss_fn = TimeDomainMSE()
>>> ref = torch.randn(4, 16000)  # batch of 4 signals, 16000 time steps each
>>> inf = torch.randn(4, 16000)
>>> loss = loss_fn(ref, inf)
>>> print(loss.shape)
torch.Size([4])
forward(ref, inf) → Tensor
Time-domain MSE loss forward.
Computes the mean squared error (MSE) between the reference signal and the estimated signal in the time domain. For each item in the batch, the squared sample-wise differences are averaged over the time axis (and over the channel axis for 3D inputs).
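Written out per batch item $b$, with reference $s$ and estimate $\hat{s}$ (symbols introduced here for illustration), the average described above is as follows, assuming it runs over every time step $t$ and, for 3D inputs, every channel $c$:

```latex
% 2D input, shape (Batch, T)
\mathcal{L}_b = \frac{1}{T} \sum_{t=1}^{T} \left( s_{b,t} - \hat{s}_{b,t} \right)^2

% 3D input, shape (Batch, T, C)
\mathcal{L}_b = \frac{1}{T C} \sum_{t=1}^{T} \sum_{c=1}^{C} \left( s_{b,t,c} - \hat{s}_{b,t,c} \right)^2
```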
- Parameters:
- ref (torch.Tensor) – Reference signal(s), shape (Batch, T) or (Batch, T, C).
- inf (torch.Tensor) – Estimated signal(s) to be evaluated against the reference; must have the same shape as ref.
- Returns: The MSE loss for each item in the batch, shape (Batch,).
- Return type: torch.Tensor
- Raises: ValueError – If the shapes of ref and inf do not match.
####### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import TimeDomainMSE
>>> ref = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> inf = torch.tensor([[1.5, 2.5], [3.5, 4.5]])
>>> loss = TimeDomainMSE()(ref, inf)  # calling the module invokes forward()
>>> print(loss)
tensor([0.2500, 0.2500])
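Since forward returns one value per batch item, a training step usually reduces it to a scalar before calling backward(). The sketch below assumes a plain batch mean for that reduction (a choice made here, not something this page specifies) and recomputes the per-item MSE inline with plain PyTorch so it runs without ESPnet installed. Because every item has the same number of elements, the batch mean equals the mean over all B·T·C elements, so the gradient with respect to the estimate is 2(inf − ref)/(B·T·C).

```python
import torch

torch.manual_seed(0)

# Hypothetical multi-channel batch: 2 items, 8000 time steps, 2 channels.
ref = torch.randn(2, 8000, 2)
# Stand-in for a model output; requires_grad so autograd tracks it.
inf = torch.randn(2, 8000, 2, requires_grad=True)

# Per-item MSE over the time and channel axes -> shape (Batch,).
per_item = (ref - inf).pow(2).mean(dim=(1, 2))

# Reduce to a scalar (batch mean) so backward() can be called.
loss = per_item.mean()
loss.backward()

# Analytic gradient of the overall mean: 2 * (inf - ref) / number of elements.
expected_grad = 2 * (inf.detach() - ref) / ref.numel()
print(per_item.shape)  # torch.Size([2])
print(torch.allclose(inf.grad, expected_grad))  # True
```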