espnet2.enh.loss.criterions.time_domain.MultiResL1SpecLoss
class espnet2.enh.loss.criterions.time_domain.MultiResL1SpecLoss(window_sz=[512], hop_sz=None, eps=1e-08, time_domain_weight=0.5, normalize_variance=False, reduction='sum', name=None, only_for_test=False, is_noise_loss=False, is_dereverb_loss=False)
Bases: TimeDomainLoss
Multi-Resolution L1 time-domain + STFT magnitude loss.
This loss combines an L1 loss on the time-domain waveform with L1 losses on STFT magnitudes computed at one or more resolutions (one per entry in window_sz). Because errors are penalized in both the time and the frequency domain, the loss is intended to improve speech enhancement quality.
Reference: Lu, Y. J., Cornell, S., Chang, X., Zhang, W., Li, C., Ni, Z., … & Watanabe, S. Towards Low-Distortion Multi-Channel Speech Enhancement: The ESPnet-SE Submission to the L3DAS22 Challenge. ICASSP 2022, pp. 9201-9205.
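Under simplifying assumptions, the combined objective for one batch element can be written as below. Here α stands for time_domain_weight, R for the number of entries in window_sz, and STFT_r for the STFT computed with the r-th window and hop size. The convex combination, the averaging over resolutions, and ignoring any internal rescaling of the estimate are assumptions of this sketch, not a transcription of the implementation. With reduction="mean", the sums would become means.

```latex
\mathcal{L}(s, \hat{s}) = \alpha \sum_{t} \lvert \hat{s}_t - s_t \rvert
  + \frac{1 - \alpha}{R} \sum_{r=1}^{R} \sum_{f, n}
    \Bigl\lvert \, \lvert \mathrm{STFT}_r(\hat{s})_{f,n} \rvert
      - \lvert \mathrm{STFT}_r(s)_{f,n} \rvert \, \Bigr\rvert
```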
window_sz
A list of STFT window sizes.
- Type: list
hop_sz
A list of hop sizes; defaults to half of each window size (window_sz // 2).
- Type: list, optional
eps
Stability epsilon to prevent division by zero.
- Type: float
time_domain_weight
Weight for time domain loss.
- Type: float
normalize_variance
Whether to normalize the variance when calculating the loss.
- Type: bool
reduction
Method for reducing the loss, select from “sum” and “mean”.
- Type: str
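The normalize_variance and reduction options described above can be illustrated with a dependency-free sketch. The helper names are hypothetical, and two details are assumptions rather than documented behaviour: variance normalization is taken to mean dividing each signal by its unbiased standard deviation without subtracting the mean, and the reduction is applied over the time axis of each example so that one value per batch element remains.

```python
import math
from typing import List


def normalize_variance(signal: List[float]) -> List[float]:
    """Divide a signal by its standard deviation (assumed meaning of normalize_variance=True)."""
    n = len(signal)
    mean = sum(signal) / n
    # Unbiased (n - 1) estimate, like torch.std's default; the mean is NOT subtracted
    # from the returned signal (assumption).
    std = math.sqrt(sum((x - mean) ** 2 for x in signal) / (n - 1))
    return [x / std for x in signal]


def reduce_abs_error(target: List[float], estimate: List[float], reduction: str = "sum") -> float:
    """Reduce |estimate - target| over time to a single value for one example."""
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")
    errors = [abs(e - t) for t, e in zip(target, estimate)]
    return sum(errors) if reduction == "sum" else sum(errors) / len(errors)
```

With target [0.0, 1.0, 2.0, 3.0] and estimate [0.5, 1.0, 1.0, 3.0], "sum" gives 1.5 and "mean" gives 0.375, so "sum" scales with signal length while "mean" does not.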
Parameters:
- window_sz (list) – List of STFT window sizes, one per resolution.
- hop_sz (list, optional) – List of hop sizes; defaults to half of each window size (window_sz // 2).
- eps (float, optional) – Stability epsilon (default: 1e-8).
- time_domain_weight (float, optional) – Weight for the time-domain loss (default: 0.5).
- normalize_variance (bool, optional) – Whether to normalize the variance of the signals before computing the loss (default: False).
- reduction (str, optional) – Reduction method, either “sum” or “mean” (default: “sum”).
- name (str, optional) – Name of the loss function (default: None).
- only_for_test (bool, optional) – If True, the loss is only used for evaluation, not for training (default: False).
- is_noise_loss (bool, optional) – If True, the loss is computed on the noise estimate (default: False).
- is_dereverb_loss (bool, optional) – If True, the loss is computed on the dereverberated estimate (default: False).
Returns: The computed loss value with shape (Batch,).
Return type: torch.Tensor
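The documented hop-size default can be sketched as a small helper that pairs each window size with its hop size. resolve_resolutions is hypothetical, and the length check is an assumption added for this sketch; the class itself may not validate hop_sz this way.

```python
from typing import List, Optional, Tuple


def resolve_resolutions(
    window_sz: List[int], hop_sz: Optional[List[int]] = None
) -> List[Tuple[int, int]]:
    """Pair each STFT window size with its hop size (hypothetical helper)."""
    if hop_sz is None:
        # Documented default: half of each window size.
        hop_sz = [w // 2 for w in window_sz]
    if len(hop_sz) != len(window_sz):
        # Assumption of this sketch: exactly one hop size per window size.
        raise ValueError("hop_sz must have one entry per window_sz entry")
    return list(zip(window_sz, hop_sz))
```

For example, resolve_resolutions([256, 512]) pairs the windows with hops 128 and 256, giving two STFT resolutions.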
######### Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import MultiResL1SpecLoss
>>> loss_fn = MultiResL1SpecLoss(window_sz=[256, 512])
>>> target = torch.randn(10, 16000)  # (Batch, T)
>>> estimate = torch.randn(10, 16000)  # (Batch, T)
>>> loss = loss_fn(target, estimate)
>>> print(loss.shape)
torch.Size([10])
forward(target: Tensor, estimate: Tensor)
Compute the multi-resolution loss between a reference signal and an estimated signal.
The time-domain L1 term and the STFT magnitude terms (one per window size in window_sz) are combined according to time_domain_weight and reduced over time with the configured reduction, giving one loss value per batch element.
Parameters:
- target (torch.Tensor) – Reference signal with shape (Batch, T).
- estimate (torch.Tensor) – Estimated signal with shape (Batch, T); must have the same shape as target.
Returns: The computed loss with shape (Batch,).
Return type: torch.Tensor
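To make the data flow of forward() concrete, here is a minimal, dependency-free sketch operating on nested Python lists of shape (Batch, T). It is not the ESPnet implementation: the function names are hypothetical, and it assumes a rectangular window without padding, a naive DFT, the "sum" reduction, spectral terms averaged over resolutions, and no rescaling of the estimate. The shape check mirrors the documented AssertionError.

```python
import cmath
import math
from typing import List, Optional

Matrix = List[List[float]]


def stft_magnitude(x: List[float], win: int, hop: int, eps: float = 1e-6) -> Matrix:
    """One-sided STFT magnitudes computed frame by frame with a naive DFT.

    Assumptions: rectangular window, no padding; frames that would run past
    the end of the signal are dropped.
    """
    frames = []
    for start in range(0, len(x) - win + 1, hop):
        frame = x[start:start + win]
        bins = []
        for k in range(win // 2 + 1):  # one-sided spectrum
            coeff = sum(
                frame[n] * cmath.exp(-2j * math.pi * k * n / win) for n in range(win)
            )
            # eps inside the square root keeps silent bins away from zero.
            bins.append(math.sqrt(coeff.real ** 2 + coeff.imag ** 2 + eps))
        frames.append(bins)
    return frames


def multires_loss_sketch(
    target: Matrix,
    estimate: Matrix,
    window_sz: List[int],
    hop_sz: Optional[List[int]] = None,
    time_domain_weight: float = 0.5,
) -> List[float]:
    """Return one loss value per batch element (hypothetical, "sum" reduction)."""
    # Mirrors the documented AssertionError for mismatched input shapes.
    assert len(target) == len(estimate) and all(
        len(t) == len(e) for t, e in zip(target, estimate)
    ), "target and estimate must have the same shape"
    if hop_sz is None:
        hop_sz = [w // 2 for w in window_sz]  # documented default
    losses = []
    for t, e in zip(target, estimate):
        # Time-domain L1 term.
        time_loss = sum(abs(ei - ti) for ti, ei in zip(t, e))
        if not window_sz:
            losses.append(time_loss)
            continue
        # L1 distance between STFT magnitudes, summed over all resolutions.
        spec_loss = 0.0
        for win, hop in zip(window_sz, hop_sz):
            mag_t = stft_magnitude(t, win, hop)
            mag_e = stft_magnitude(e, win, hop)
            spec_loss += sum(
                abs(be - bt)
                for frame_t, frame_e in zip(mag_t, mag_e)
                for bt, be in zip(frame_t, frame_e)
            )
        # Convex combination; spectral terms are averaged over resolutions.
        losses.append(
            time_domain_weight * time_loss
            + (1 - time_domain_weight) * spec_loss / len(window_sz)
        )
    return losses
```

For identical target and estimate the sketch returns 0.0 for every batch element, and with time_domain_weight=1.0 it reduces to the plain time-domain L1 distance.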
######### Examples
>>> loss_fn = MultiResL1SpecLoss(window_sz=[256, 512],
... hop_sz=[128, 256])
>>> target = torch.randn(8, 512)
>>> estimate = torch.randn(8, 512)
>>> loss = loss_fn(target, estimate)
>>> print(loss.shape)
torch.Size([8])
get_magnitude(stft, eps=1e-06)
Compute the magnitude of an STFT.
Parameters:
- stft – STFT of a signal.
- eps (float, optional) – Stability epsilon (default: 1e-06).
Returns: The magnitude of stft.
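A minimal sketch of a magnitude computation with a stability epsilon, using plain Python complex numbers in place of tensors. get_magnitude_sketch is a hypothetical helper, not the ESPnet method, and placing eps inside the square root is an assumption of this sketch.

```python
import math
from typing import List


def get_magnitude_sketch(stft_bins: List[complex], eps: float = 1e-6) -> List[float]:
    """Magnitude of complex STFT bins (hypothetical helper, not the ESPnet method).

    Assumption: eps is added to the squared magnitude inside the square root.
    """
    # The derivative of sqrt(x) is unbounded at x = 0, so eps keeps both the
    # value and its gradient finite for silent (all-zero) bins.
    return [math.sqrt(c.real ** 2 + c.imag ** 2 + eps) for c in stft_bins]
```

For a bin 3 + 4j the result is very close to 5, while an all-zero bin yields sqrt(eps) instead of 0.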
######### Examples
>>> loss = MultiResL1SpecLoss(window_sz=[256, 512], time_domain_weight=0.7)
>>> target = torch.randn(10, 512)
>>> estimate = torch.randn(10, 512)
>>> output = loss(target, estimate)
- Raises: AssertionError – If the shapes of target and estimate passed to forward() do not match.
property name : str