espnet2.enh.layers.beamformer.get_sdw_mwf_vector
espnet2.enh.layers.beamformer.get_sdw_mwf_vector(psd_speech, psd_noise, reference_vector: Tensor | int, denoising_weight: float = 1.0, approx_low_rank_psd_speech: bool = False, iterations: int = 3, diagonal_loading: bool = True, diag_eps: float = 1e-07, eps: float = 1e-08)
Return the SDW-MWF (Speech Distortion Weighted Multi-channel Wiener Filter) vector.
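The formula for the SDW-MWF is given by:

h = (Spsd + mu * Npsd)^-1 @ Spsd @ u

where Spsd and Npsd are the speech and noise covariance matrices (psd_speech and psd_noise), mu is denoising_weight, and u is reference_vector.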
This filter emphasizes the preservation of speech while reducing noise.
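The formula above can be sketched directly in PyTorch. This is a minimal illustration, not ESPnet's implementation: the helper names `sdw_mwf_sketch` and `random_psd` are hypothetical, the inputs are synthetic Hermitian matrices, and diagonal loading and the low-rank approximation are left out.

```python
import torch

torch.manual_seed(0)


def random_psd(F, C, T=50):
    """Hypothetical helper: synthetic Hermitian PSD matrices of shape (F, C, C)."""
    x = torch.complex(torch.randn(F, C, T), torch.randn(F, C, T))
    return x @ x.conj().transpose(-1, -2) / T


def sdw_mwf_sketch(psd_speech, psd_noise, u, mu=1.0):
    """Hypothetical helper: h = (Spsd + mu * Npsd)^-1 @ Spsd @ u."""
    # Solve (Spsd + mu * Npsd) X = Spsd rather than forming an explicit inverse.
    ws = torch.linalg.solve(psd_speech + mu * psd_noise, psd_speech)
    # Apply the reference vector: (F, C, C) @ (C, 1) -> (F, C, 1) -> (F, C)
    return (ws @ u[..., None]).squeeze(-1)


F, C = 8, 4
psd_s, psd_n = random_psd(F, C), random_psd(F, C)

# One-hot reference vector selecting channel 0 (an assumption for this sketch).
u = torch.zeros(C, dtype=torch.complex64)
u[0] = 1

h = sdw_mwf_sketch(psd_s, psd_n, u, mu=1.0)

# Residual of the defining linear system (Spsd + mu * Npsd) h = Spsd u.
lhs = ((psd_s + 1.0 * psd_n) @ h[..., None]).squeeze(-1)
rhs = (psd_s @ u[..., None]).squeeze(-1)
residual = (lhs - rhs).abs().max().item()
print(h.shape, residual)
```

Solving the linear system is generally preferable to computing the inverse explicitly, since it is cheaper and numerically better behaved.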
References
[1] Spatially pre-processed speech distortion weighted multi-channel Wiener filtering for noise reduction; A. Spriet et al., 2004. https://dl.acm.org/doi/abs/10.1016/j.sigpro.2004.07.028

[2] Rank-1 constrained multichannel Wiener filter for speech recognition in noisy environments; Z. Wang et al., 2018. https://hal.inria.fr/hal-01634449/document

[3] Low-rank approximation based multichannel Wiener filter algorithms for noise reduction with application in cochlear implants; R. Serizel, 2014. https://ieeexplore.ieee.org/document/6730918
- Parameters:
- psd_speech (torch.complex64/ComplexTensor) – Speech covariance matrix with shape (…, F, C, C).
- psd_noise (torch.complex64/ComplexTensor) – Noise covariance matrix with shape (…, F, C, C).
- reference_vector (torch.Tensor or int) – Reference vector with shape (…, C) or scalar.
- denoising_weight (float) – Trade-off parameter between noise reduction and speech distortion. A larger value leads to more noise reduction at the expense of more speech distortion. The plain MWF is obtained with denoising_weight = 1 (default).
- approx_low_rank_psd_speech (bool) – Whether to replace original input psd_speech with its low-rank approximation as in [2].
- iterations (int) – Number of iterations in power method, only used when approx_low_rank_psd_speech = True.
- diagonal_loading (bool) – Whether to apply diagonal loading, i.e. add a tiny term to the diagonal of the covariance matrix before it is inverted.
- diag_eps (float) – Regularization factor for diagonal loading.
- eps (float) – Small constant to prevent division by zero.
- Returns: The computed beamforming vector with shape (…, F, C).
- Return type: beamform_vector (torch.complex64/ComplexTensor)
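The trade-off controlled by denoising_weight can be checked numerically. The sketch below is not ESPnet code: it re-implements the formula with hypothetical helpers (`random_psd`, `sdw_mwf_sketch`, `quad`) on synthetic Hermitian matrices, and it assumes the filter output is formed as h^H x, so that the residual noise power is h^H Npsd h and the speech distortion is (u - h)^H Spsd (u - h).

```python
import torch

torch.manual_seed(0)


def random_psd(F, C, T=50):
    """Hypothetical helper: synthetic Hermitian PSD matrices of shape (F, C, C)."""
    x = torch.complex(torch.randn(F, C, T), torch.randn(F, C, T))
    return x @ x.conj().transpose(-1, -2) / T


def sdw_mwf_sketch(psd_speech, psd_noise, u, mu):
    """Hypothetical helper: h = (Spsd + mu * Npsd)^-1 @ Spsd @ u."""
    ws = torch.linalg.solve(psd_speech + mu * psd_noise, psd_speech)
    return (ws @ u[..., None]).squeeze(-1)


def quad(v, psd):
    """Quadratic form v^H @ psd @ v, summed over all frequency bins."""
    return torch.einsum("fc,fcd,fd->", v.conj(), psd, v).real.item()


F, C = 8, 4
psd_s, psd_n = random_psd(F, C), random_psd(F, C)
u = torch.zeros(C, dtype=torch.complex64)
u[0] = 1  # reference channel 0 (an assumption for this sketch)

weights = [0.1, 1.0, 10.0]
noise_power, distortion = [], []
for mu in weights:
    h = sdw_mwf_sketch(psd_s, psd_n, u, mu)
    noise_power.append(quad(h, psd_n))       # residual noise power
    distortion.append(quad(u - h, psd_s))    # speech distortion

for mu, n, d in zip(weights, noise_power, distortion):
    print(f"mu={mu:5.1f}  noise={n:.4f}  distortion={d:.4f}")
```

Under these assumptions, the residual noise power should shrink and the speech distortion should grow as the weight increases, which matches the description of denoising_weight above.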
Examples
>>> import torch
>>> from espnet2.enh.layers.beamformer import get_sdw_mwf_vector
>>> psd_s = torch.randn(1, 8, 2, 2, dtype=torch.complex64)
>>> psd_n = torch.randn(1, 8, 2, 2, dtype=torch.complex64)
>>> ref_vector = torch.tensor([1, 0], dtype=torch.complex64)
>>> vector = get_sdw_mwf_vector(psd_s, psd_n, ref_vector)
>>> print(vector.shape)
torch.Size([1, 8, 2])
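The random inputs above only illustrate shapes; real covariance matrices are Hermitian and positive semi-definite. As a rough sketch, matrices of the expected (…, F, C, C) shape could be estimated from a multichannel STFT with a time-frequency mask. The helper `masked_covariance`, the tensor layout (B, F, C, T), and the random mask are all assumptions made for illustration, not part of the ESPnet API.

```python
import torch

torch.manual_seed(0)

B, F, C, T = 1, 8, 2, 100
# Hypothetical multichannel STFT (B, F, C, T) and speech mask (B, F, T) in [0, 1].
stft = torch.complex(torch.randn(B, F, C, T), torch.randn(B, F, C, T))
mask_speech = torch.rand(B, F, T)


def masked_covariance(stft, mask, eps=1e-8):
    """Hypothetical helper: mask-weighted spatial covariance, shape (B, F, C, C)."""
    # Weight each time frame by the mask value, broadcasting over channels.
    weighted = stft * mask[:, :, None, :].to(stft.dtype)
    # Sum over time of m(t, f) * x(t, f) x(t, f)^H.
    psd = weighted @ stft.conj().transpose(-1, -2)
    # Normalize by the total mask weight per frequency bin.
    return psd / (mask.sum(dim=-1)[..., None, None] + eps)


psd_s = masked_covariance(stft, mask_speech)
psd_n = masked_covariance(stft, 1.0 - mask_speech)

# How far each matrix is from being exactly Hermitian.
herm_err = (psd_s - psd_s.conj().transpose(-1, -2)).abs().max().item()
# Smallest eigenvalue across all bins (should not be meaningfully negative).
min_eig = torch.linalg.eigvalsh(psd_s).min().item()
print(psd_s.shape, psd_n.shape, herm_err, min_eig)
```

Matrices built this way can then be passed as psd_speech and psd_noise in place of the random tensors used in the example.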