espnet2.enh.loss.criterions.tf_domain.FrequencyDomainMSE
class espnet2.enh.loss.criterions.tf_domain.FrequencyDomainMSE(compute_on_mask=False, mask_type='IBM', name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: FrequencyDomainLoss
FrequencyDomainMSE computes the Mean Squared Error (MSE) loss in the frequency
domain for speech enhancement tasks. This loss can be computed either on the mask or directly on the spectrum.
compute_on_mask
Indicates whether to compute the loss on the mask.
- Type: bool
mask_type
The type of mask to be used in the loss calculation.
- Type: str
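To illustrate what `mask_type` refers to, the sketch below builds the textbook ideal binary mask (IBM) and ideal ratio mask (IRM) from speech and noise magnitude spectra. It then computes a per-example MSE against an estimated mask, which is the comparison made when the loss is computed on the mask. These are generic definitions assumed for illustration: ESPnet builds its own mask labels elsewhere, and its exact formulas (exponents, multi-source handling, ties) may differ. `ideal_masks` is a hypothetical helper, not part of ESPnet.

```python
import torch


def ideal_masks(speech_mag: torch.Tensor, noise_mag: torch.Tensor, eps: float = 1e-8):
    """Textbook IBM and IRM targets from magnitude spectra (hypothetical helper)."""
    # IBM: 1 where speech dominates noise (local SNR above 0 dB), else 0.
    ibm = (speech_mag > noise_mag).float()
    # IRM: square root of the speech-to-total energy ratio, always in [0, 1].
    irm = torch.sqrt(speech_mag**2 / (speech_mag**2 + noise_mag**2 + eps))
    return ibm, irm


speech = torch.tensor([[[3.0, 0.0], [1.0, 4.0]]])  # (Batch=1, T=2, F=2)
noise = torch.tensor([[[4.0, 1.0], [1.0, 3.0]]])
ibm, irm = ideal_masks(speech, noise)

# Mask-domain MSE: compare the target mask with an (assumed) estimated mask.
est_mask = torch.full_like(ibm, 0.5)
per_example = ((ibm - est_mask) ** 2).mean(dim=(1, 2))  # shape (Batch,)
print(ibm)
print(per_example)
```

Here only the bin where speech (4.0) exceeds noise (3.0) gets an IBM value of 1. Because every entry of the estimated mask differs from the target by 0.5, the per-example loss is 0.25.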
Parameters:
- compute_on_mask (bool) – If True, the loss is computed on the mask; otherwise, it is computed on the spectrum.
- mask_type (str) – The type of mask to use (default is "IBM").
- name (str, optional) – Name of the loss instance.
- only_for_test (bool) – Indicates if the loss is only for testing (default False).
- is_noise_loss (bool) – Indicates if this loss is for noise (default False).
- is_dereverb_loss (bool) – Indicates if this loss is for dereverberation (default False).
Raises: ValueError – If is_noise_loss and is_dereverb_loss are both True.
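The documented constraint between `is_noise_loss` and `is_dereverb_loss` can be mirrored in a small configuration object. This is a minimal sketch under stated assumptions: `MSEConfigSketch` is a hypothetical dataclass, not ESPnet's class. It reproduces only the defaults listed above and the mutual-exclusion check described under Raises, and it omits any other validation the library may perform.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MSEConfigSketch:
    """Hypothetical stand-in for the constructor arguments (not ESPnet's class)."""

    compute_on_mask: bool = False
    mask_type: str = "IBM"
    name: Optional[str] = None
    only_for_test: bool = False
    is_noise_loss: bool = False
    is_dereverb_loss: bool = False

    def __post_init__(self) -> None:
        # A single loss targets either the noise estimate or the dereverberated
        # estimate, so enabling both flags is rejected.
        if self.is_noise_loss and self.is_dereverb_loss:
            raise ValueError("is_noise_loss and is_dereverb_loss cannot both be True")


cfg = MSEConfigSketch(compute_on_mask=True, mask_type="IRM")  # accepted
try:
    MSEConfigSketch(is_noise_loss=True, is_dereverb_loss=True)
except ValueError as err:
    print(err)  # is_noise_loss and is_dereverb_loss cannot both be True
```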
####### Examples
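>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE
>>> loss = FrequencyDomainMSE(compute_on_mask=True, mask_type="IRM")
>>> ref = torch.randn(2, 100, 256) # Reference (target) mask
>>> inf = torch.randn(2, 100, 256) # Inferred (estimated) mask
>>> loss_value = loss(ref, inf) # Per-example MSE loss, shape (2,)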
>>> loss = FrequencyDomainMSE()
>>> ref = torch.randn(2, 100, 1, 256) # Reference tensor with channels
>>> inf = torch.randn(2, 100, 1, 256) # Inferred tensor with channels
>>> loss_value = loss(ref, inf) # Compute the MSE loss
- Returns: The computed MSE loss for each example in the batch, with shape (Batch,).
- Return type: torch.Tensor
NOTE
The input tensors ref and inf must have the same shape.
property compute_on_mask : bool
forward(ref, inf) → Tensor
Compute the time-frequency Mean Squared Error (MSE) loss.
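This method calculates the MSE loss between the reference and the inferred signals in the frequency domain, averaging the squared error over all non-batch axes. If the inputs are complex, the squared error at each time-frequency bin is the squared magnitude of the difference, i.e. the sum of the squared real and imaginary parts.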
- Parameters:
- ref – A tensor of shape (Batch, T, F) or (Batch, T, C, F) representing the reference signal.
- inf – A tensor of shape (Batch, T, F) or (Batch, T, C, F) representing the inferred signal.
- Returns: A tensor of shape (Batch,) containing the computed MSE loss for each element in the batch.
- Raises: ValueError – If ref is neither 3-D nor 4-D. Mismatched shapes of ref and inf are rejected by an assertion.
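The computation described above can be sketched as a standalone function that runs without ESPnet. This sketch assumes that the squared error is averaged over every non-batch axis and that complex errors are measured by their squared magnitude. `tf_mse_sketch` is a hypothetical name. It raises `ValueError` on a shape mismatch for simplicity, which is not necessarily how the library reports that case.

```python
import torch


def tf_mse_sketch(ref: torch.Tensor, inf: torch.Tensor) -> torch.Tensor:
    """Per-example time-frequency MSE (illustrative, not ESPnet's code).

    ref, inf: (Batch, T, F) or (Batch, T, C, F), real or complex.
    Returns a real tensor of shape (Batch,).
    """
    if ref.shape != inf.shape:
        raise ValueError(f"shape mismatch: {tuple(ref.shape)} vs {tuple(inf.shape)}")
    if ref.dim() not in (3, 4):
        raise ValueError(f"expected a 3-D or 4-D input, got {ref.dim()}-D")
    diff = ref - inf
    if torch.is_complex(diff):
        # Squared magnitude of the complex error: real^2 + imag^2.
        sq_err = diff.real**2 + diff.imag**2
    else:
        sq_err = diff**2
    # Average over T, F, and C (if present); keep only the batch axis.
    return sq_err.mean(dim=tuple(range(1, ref.dim())))


real_loss = tf_mse_sketch(torch.zeros(2, 4, 3), torch.ones(2, 4, 3))
c_ref = torch.zeros(1, 2, 2, dtype=torch.complex64)
c_inf = torch.complex(torch.ones(1, 2, 2), torch.ones(1, 2, 2))  # 1 + 1j everywhere
complex_loss = tf_mse_sketch(c_ref, c_inf)
print(real_loss)     # tensor([1., 1.])
print(complex_loss)  # tensor([2.])
```

In the complex case each bin's error is -1 - 1j, whose squared magnitude is 1 + 1 = 2, so the real and imaginary parts contribute jointly rather than as two separate losses.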
####### Examples
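>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE
>>> loss_fn = FrequencyDomainMSE()
>>> ref = torch.rand(8, 100, 256) # Example reference spectrum
>>> inf = torch.rand(8, 100, 256) # Example inferred spectrum
>>> loss = loss_fn(ref, inf) # Calls forward() via nn.Module.__call__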
>>> print(loss.shape) # Output: torch.Size([8])
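Because `forward` returns one loss value per example, a training step typically reduces it to a scalar before backpropagation. The check below uses plain real-valued tensors and re-computes the per-example loss by hand rather than calling ESPnet. When every example has the same shape, the mean of the per-example losses equals the element-wise MSE over the whole batch, which is compared here with `torch.nn.functional.mse_loss`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
ref = torch.rand(8, 100, 256)
inf = torch.rand(8, 100, 256, requires_grad=True)

# Per-example loss with shape (8,), matching the (Batch,) output described above.
per_example = ((ref - inf) ** 2).mean(dim=(1, 2))
scalar = per_example.mean()  # reduce to a scalar for backpropagation
scalar.backward()

# Equal-sized examples: mean of per-example MSEs equals the global element-wise MSE.
print(torch.allclose(scalar.detach(), F.mse_loss(inf.detach(), ref)))  # True
```

Weighting or masking individual examples (for instance, padded utterances) would break this equality, which is one reason a per-example output can be more convenient than a single scalar.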
property mask_type : str