espnet2.enh.layers.beamformer.get_mwf_vector
espnet2.enh.layers.beamformer.get_mwf_vector(psd_s, psd_n, reference_vector: Tensor | int, diagonal_loading: bool = True, diag_eps: float = 1e-07, eps: float = 1e-08)
Return the MWF (Multi-channel Wiener Filter) vector.
The MWF vector is calculated using the formula:
h = (Npsd^-1 @ Spsd) @ u
where:
- Npsd is the matrix passed as psd_n (described under Parameters as the power-normalized observation covariance matrix).
- Spsd is the speech covariance matrix (psd_s).
- u is the reference vector.
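The formula above can be sketched in plain PyTorch. This is an illustrative reimplementation, not the library's code: the helper name `mwf_vector_sketch` is hypothetical, it assumes native `torch.complex64` inputs, and it uses `torch.linalg.solve` in place of an explicit inverse.

```python
import torch


def mwf_vector_sketch(psd_s, psd_n, u):
    """Compute h = (psd_n^-1 @ psd_s) @ u (hypothetical helper, not ESPnet code).

    psd_s, psd_n: (..., F, C, C) complex tensors; u: (..., C) reference vector.
    """
    # Solve psd_n @ ws = psd_s instead of forming the inverse explicitly.
    ws = torch.linalg.solve(psd_n, psd_s)  # (..., F, C, C)
    # Add a frequency axis and a trailing axis so u acts as a column vector.
    u_col = u.to(ws.dtype)[..., None, :, None]  # (..., 1, C, 1)
    return torch.matmul(ws, u_col).squeeze(-1)  # (..., F, C)


torch.manual_seed(0)
B, F, C = 2, 4, 3
a = torch.randn(B, F, C, C, dtype=torch.complex64)
b = torch.randn(B, F, C, C, dtype=torch.complex64)
# Build Hermitian positive semi-definite matrices so they resemble covariances.
psd_s = a @ a.conj().transpose(-1, -2)
psd_n = b @ b.conj().transpose(-1, -2) + torch.eye(C)  # identity keeps it invertible
u = torch.tensor([1.0, 0.0, 0.0])

h = mwf_vector_sketch(psd_s, psd_n, u)
print(h.shape)
```

Solving the linear system is generally preferable to inverting psd_n when the matrix is poorly conditioned, which is also why diagonal loading is offered as an option.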
- Parameters:
  - psd_s (torch.complex64/ComplexTensor) – Speech covariance matrix with shape (..., F, C, C).
  - psd_n (torch.complex64/ComplexTensor) – Power-normalized observation covariance matrix with shape (..., F, C, C).
  - reference_vector (torch.Tensor or int) – Reference vector with shape (..., C), or a scalar index.
  - diagonal_loading (bool) – Whether to add a tiny term to the diagonal of psd_n to avoid singularities.
  - diag_eps (float) – Regularization term added to the diagonal if diagonal_loading is True.
  - eps (float) – Small constant to prevent division by zero.
- Returns: beamform_vector, the calculated MWF vector with shape (..., F, C).
- Return type: torch.complex64/ComplexTensor
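As a rough illustration of what diagonal_loading guards against, the sketch below adds a small term to the diagonal of a singular matrix so that it becomes invertible. The text above does not say how the term is computed; scaling diag_eps by the trace and adding eps as a floor is an assumption made here for illustration, and `diagonal_loading_sketch` is a hypothetical helper name.

```python
import torch


def diagonal_loading_sketch(psd, diag_eps=1e-7, eps=1e-8):
    """Add a small term to the diagonal of psd with shape (..., C, C).

    Assumption: the term is diag_eps times the trace, plus eps as a floor.
    """
    C = psd.size(-1)
    eye = torch.eye(C, dtype=psd.dtype, device=psd.device)
    # Trace of each matrix; it is real for Hermitian input.
    trace = torch.diagonal(psd, dim1=-2, dim2=-1).sum(-1).real
    # eps keeps the loading nonzero even when the matrix is all zeros.
    loading = trace[..., None, None] * diag_eps + eps
    return psd + loading * eye


# An all-zero matrix is singular; after loading it can be inverted.
singular = torch.zeros(1, 2, 3, 3, dtype=torch.complex64)
loaded = diagonal_loading_sketch(singular)
inverse = torch.linalg.inv(loaded)
print(inverse.shape)
```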
Examples
>>> import torch
>>> from espnet2.enh.layers.beamformer import get_mwf_vector
>>> psd_s = torch.rand(2, 4, 3, 3, dtype=torch.complex64)
>>> psd_n = torch.rand(2, 4, 3, 3, dtype=torch.complex64)
>>> reference_vector = torch.tensor([1.0, 0.0, 0.0])
>>> mwf_vector = get_mwf_vector(psd_s, psd_n, reference_vector)
NOTE
- The function assumes that psd_s and psd_n are complex tensors.
- The reference_vector can be provided either as a tensor or an index.
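The two forms of reference_vector can be related with a small plain-PyTorch sketch. It assumes that an integer index selects one column of Npsd^-1 @ Spsd, which is the same as multiplying by a one-hot vector; that reading follows from the formula but is not confirmed by the text above. The tensor `ws` is a random stand-in, not output from the library.

```python
import torch

torch.manual_seed(0)
# Stand-in for the product Npsd^-1 @ Spsd, shape (B, F, C, C).
ws = torch.randn(2, 4, 3, 3, dtype=torch.complex64)

ref_idx = 1
one_hot = torch.zeros(3, dtype=ws.dtype)
one_hot[ref_idx] = 1

# Integer form: pick column ref_idx of each C x C matrix.
h_from_index = ws[..., ref_idx]  # (B, F, C)
# Vector form: multiply by the one-hot reference as a column vector.
h_from_vector = torch.matmul(ws, one_hot[..., None, :, None]).squeeze(-1)

same = torch.allclose(h_from_index, h_from_vector)
print(same)
```

A non-one-hot reference vector, such as a learned attention weight over channels, would instead produce a weighted combination of the columns.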