espnet2.enh.layers.wpe.wpe
espnet2.enh.layers.wpe.wpe(Y: Tensor | ComplexTensor, taps=10, delay=3, iterations=3) → Tensor | ComplexTensor
WPE (Weighted Prediction Error) dereverberation of complex-valued STFT signals.
This implementation is a PyTorch port of WPE from https://github.com/fgnt/nara_wpe. In each iteration, it estimates the time-varying power of the desired signal, computes power-weighted correlations between each frame and a stack of delayed past frames, solves for a multichannel linear prediction filter, and subtracts the predicted late reverberation from the observed signal.
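The iteration described above can be sketched in NumPy for illustration. This is a minimal sketch of the standard WPE update applied independently to each frequency bin, following the formulation used in nara_wpe; it is not the ESPnet code. The helper names `build_delayed_stack`, `wpe_one_bin`, and `wpe_numpy`, as well as the `eps` power floor, are hypothetical, and the sketch loops over frequencies instead of batching them.

```python
import numpy as np

# Hypothetical NumPy sketch of WPE; not the ESPnet implementation.


def build_delayed_stack(Y, taps, delay):
    """Stack `taps` past frames, starting `delay` frames back.

    Y: (C, T) complex STFT of one frequency bin.
    Returns (taps * C, T): rows k*C:(k+1)*C hold Y shifted right by
    delay + k frames, zero-padded at the start.
    """
    C, T = Y.shape
    stack = np.zeros((taps * C, T), dtype=Y.dtype)
    for k in range(taps):
        shift = delay + k
        if shift < T:
            stack[k * C:(k + 1) * C, shift:] = Y[:, :T - shift]
    return stack


def wpe_one_bin(Y, taps=10, delay=3, iterations=3, eps=1e-10):
    """Dereverberate one frequency bin. Y: (C, T) complex."""
    Y_tilde = build_delayed_stack(Y, taps, delay)  # (taps*C, T)
    X = Y
    for _ in range(iterations):
        # Time-varying power of the current estimate, averaged over channels.
        power = np.maximum(np.mean(np.abs(X) ** 2, axis=0), eps)  # (T,)
        weighted = Y_tilde / power  # divides column t by power[t]
        R = weighted @ Y_tilde.conj().T  # weighted correlation matrix (taps*C, taps*C)
        P = weighted @ Y.conj().T  # weighted correlation vector (taps*C, C)
        G = np.linalg.solve(R, P)  # prediction filter (taps*C, C)
        X = Y - G.conj().T @ Y_tilde  # subtract predicted late reverberation
    return X


def wpe_numpy(Y, taps=10, delay=3, iterations=3):
    """Apply WPE to each frequency bin. Y: (F, C, T) complex."""
    return np.stack(
        [wpe_one_bin(Y[f], taps, delay, iterations) for f in range(Y.shape[0])]
    )
```

Solving the normal equations `R G = P` gives the filter that minimizes the power-weighted prediction error; re-estimating the power from the latest output is what makes the procedure iterative.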
espnet2.enh.layers.wpe.is_torch_1_9_plus
Indicates whether the installed PyTorch version is 1.9 or later.
Type: bool
Parameters:
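- Y (Union[torch.Tensor, ComplexTensor]) – Complex-valued STFT signal with shape (F, C, T), i.e., frequency bins, channels, and time frames.
- taps (int) – Number of filter taps, i.e., past frames used by the linear prediction filter. Default is 10.
- delay (int) – Prediction delay in frames. It acts as a guard interval so that the signal does not become zero: frames within the delay are not used for prediction. Default is 3.
- iterations (int) – Number of WPE iterations; each one re-estimates the signal power and the prediction filter. Default is 3.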
Returns: Enhanced signal with shape (F, C, T) after applying WPE.
Return type: Union[torch.Tensor, ComplexTensor]
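The interplay of `taps` and `delay` can be illustrated with a small helper. `prediction_frames` is hypothetical (not part of ESPnet) and assumes the usual WPE convention that frames t - delay down to t - delay - taps + 1 predict frame t, with frames before the start of the signal treated as zeros.

```python
def prediction_frames(t, taps=10, delay=3):
    """Hypothetical helper: frame indices used to predict frame t.

    WPE subtracts a filtered combination of frames t - delay, ...,
    t - delay - taps + 1. Negative indices correspond to zero padding
    and are omitted from the returned list.
    """
    return [t - delay - k for k in range(taps) if t - delay - k >= 0]


print(prediction_frames(20))  # frames 17 down to 8
print(prediction_frames(5))   # only frames 2, 1, 0 exist
```

With the defaults, frames 19 and 18 fall inside the guard interval and never enter the prediction of frame 20, so the direct path and early reflections are not predicted away.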
Examples
>>> import torch
>>> from espnet2.enh.layers.wpe import wpe
>>> Y = torch.randn(64, 2, 100, dtype=torch.complex64)  # complex STFT with shape (F, C, T)
>>> enhanced_signal = wpe(Y, taps=10, delay=3, iterations=3)
>>> print(enhanced_signal.shape)  # same shape as the input
torch.Size([64, 2, 100])
NOTE
The input must be a complex-valued STFT (a complex torch.Tensor or a ComplexTensor) with shape (F, C, T); real-valued tensors are not supported.