espnet2.enh.layers.wpe.wpe_one_iteration
espnet2.enh.layers.wpe.wpe_one_iteration(Y: Tensor | ComplexTensor, power: Tensor, taps: int = 10, delay: int = 3, eps: float = 1e-10, inverse_power: bool = True) → Tensor | ComplexTensor
WPE for one iteration.
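This function performs one iteration of the Weighted Prediction Error (WPE) dereverberation algorithm on a complex-valued short-time Fourier transform (STFT) signal. From power-weighted correlations of the observed signal, it estimates a multichannel linear prediction filter, uses it to predict the late reverberation in each frame from earlier frames, and subtracts that prediction from the input.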
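The iteration described above can be sketched in plain Python for a single frequency bin, following the standard WPE formulation: weights w_t = 1 / max(power_t, eps), correlation matrix R = Σ w_t ỹ_t ỹ_tᴴ, correlation vector P = Σ w_t ỹ_t y_tᴴ, filter G = R⁻¹P, and output x_t = y_t − Gᴴ ỹ_t, where ỹ_t stacks the `taps` frames starting `delay` frames in the past. This is a minimal illustration, not ESPnet's implementation: the names `wpe_one_iteration_sketch` and `solve` are hypothetical, nested lists of complex numbers stand in for tensors, frames before the start of the signal are assumed to be zero, and no regularization is applied to R, so a singular correlation matrix is not handled.

```python
from typing import List

Matrix = List[List[complex]]


def solve(a: Matrix, b: Matrix) -> Matrix:
    """Solve a @ x = b with Gauss-Jordan elimination and partial pivoting."""
    n, m = len(a), len(b[0])
    # Work on an augmented copy [a | b] so the inputs are left unchanged.
    aug = [list(a[i]) + list(b[i]) for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        for r in range(n):
            if r != col:
                f = aug[r][col] / p
                for k in range(col, n + m):
                    aug[r][k] -= f * aug[col][k]
    # Only the diagonal remains in the left block; divide it out.
    return [[aug[i][n + j] / aug[i][i] for j in range(m)] for i in range(n)]


def wpe_one_iteration_sketch(
    Y: Matrix, power: List[float], taps: int = 10, delay: int = 3, eps: float = 1e-10
) -> Matrix:
    """Hypothetical single-bin WPE iteration. Y has shape (C, T), power has length T."""
    C, T = len(Y), len(Y[0])
    D = C * taps
    # Inverse-power weights, with power clamped from below by eps.
    w = [1.0 / max(p, eps) for p in power]

    def stacked(t: int) -> List[complex]:
        # Frames t - delay, ..., t - delay - taps + 1 for every channel;
        # frames before the start of the signal are taken as zero (assumption).
        vec = []
        for k in range(taps):
            s = t - delay - k
            for c in range(C):
                vec.append(Y[c][s] if s >= 0 else 0j)
        return vec

    # Accumulate R (D x D) and P (D x C).
    R = [[0j] * D for _ in range(D)]
    P = [[0j] * C for _ in range(D)]
    for t in range(T):
        yt = stacked(t)
        for i in range(D):
            wi = w[t] * yt[i]
            for j in range(D):
                R[i][j] += wi * yt[j].conjugate()
            for c in range(C):
                P[i][c] += wi * Y[c][t].conjugate()

    G = solve(R, P)  # prediction filter, shape (D, C)

    # Subtract the predicted late reverberation: x_t = y_t - G^H y~_t.
    X = []
    for c in range(C):
        row = []
        for t in range(T):
            yt = stacked(t)
            pred = sum(G[d][c].conjugate() * yt[d] for d in range(D))
            row.append(Y[c][t] - pred)
        X.append(row)
    return X
```

Because G minimizes the weighted prediction error Σ w_t |y_t − Gᴴ ỹ_t|² separately for each channel, the weighted energy of the output never exceeds that of the input. For a purely autoregressive signal such as y_t = 0.9 y_{t−1} with `taps=1, delay=1`, every frame after the first is predicted exactly and cancelled.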
- Parameters:
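- Y – Complex-valued STFT signal with shape (…, C, T), where C is the number of channels and T the number of frames (the leading dimensions are typically batch and frequency).
- power – Time-varying power estimate of the signal with shape (…, T), typically averaged over channels. Its leading dimensions (…) must match those of Y.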
- taps – Number of filter taps (default: 10).
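- delay – Prediction delay in frames (default: 3). The filter only uses frames at least `delay` frames in the past; this guard interval keeps the enhanced signal from being trivially predicted to zero and keeps the direct sound and early reflections from being predicted away.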
- eps – Small value to prevent division by zero in the inverse power calculation (default: 1e-10).
- inverse_power (bool) – If True, the correlation statistics are weighted by 1 / max(power, eps); otherwise, power is used directly as the weights (default: True).
- Returns: enhanced – Enhanced (dereverberated) signal with shape (…, C, T).
- Return type: Tensor | ComplexTensor
- Raises:
- ValueError – If the leading dimensions (…) of Y and power do not match.
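To make the roles of `taps` and `delay` concrete, the helper below lists which past frames feed the prediction for frame t. It is a hypothetical illustration assuming the usual WPE convention (frames t − delay down to t − delay − taps + 1); the name `prediction_frames` is not part of ESPnet.

```python
from typing import List


def prediction_frames(t: int, taps: int = 10, delay: int = 3) -> List[int]:
    """Hypothetical helper: past frame indices used to predict reverberation at frame t.

    The delay - 1 frames immediately before t are never used (the guard
    interval). Indices before the start of the signal are dropped.
    """
    return [t - delay - k for k in range(taps) if t - delay - k >= 0]
```

With the defaults, `prediction_frames(20)` returns `[17, 16, 15, 14, 13, 12, 11, 10, 9, 8]`: ten taps, starting three frames back.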
Examples
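>>> import torch
>>> from espnet2.enh.layers.wpe import wpe_one_iteration
>>> Y = torch.randn(1, 2, 100, dtype=torch.complex64)  # shape (B, C, T)
>>> power = Y.abs().pow(2).mean(dim=-2)  # channel-averaged power, shape (B, T)
>>> enhanced_signal = wpe_one_iteration(Y, power)
>>> enhanced_signal.shape
torch.Size([1, 2, 100])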