espnet2.enh.layers.wpe.get_correlations
espnet2.enh.layers.wpe.get_correlations(Y: Tensor | ComplexTensor, inverse_power: Tensor, taps, delay) → Tuple[Tensor | ComplexTensor, Tensor | ComplexTensor]
Calculates weighted correlations of a window of length taps.
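This function computes the correlation matrix and correlation vector used to estimate the WPE (weighted prediction error) dereverberation filter from a complex-valued Short-Time Fourier Transform (STFT) signal. Each frame's contribution is weighted by inverse_power, and each target frame is paired with a window of taps past frames whose most recent frame lies delay frames before the target.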
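Written out, the two outputs appear to be the following sums. This is a reading of the vectorized implementation rather than a formula stated in this documentation. Here K = taps, Δ = delay, w_{f,n} = inverse_power[f, n], y_{f,n} ∈ C^C is the multichannel STFT vector at frequency f and frame n, and the stacked past-frame vector is ỹ_{f,n} = [y_{f,n-Δ}; y_{f,n-Δ-1}; ...; y_{f,n-Δ-K+1}] ∈ C^{KC}.

```latex
\mathbf{R}_f = \sum_{n=\Delta+K-1}^{T-1} w_{f,n}\, \tilde{\mathbf{y}}_{f,n}^{*}\, \tilde{\mathbf{y}}_{f,n}^{\mathsf{T}} \in \mathbb{C}^{KC \times KC},
\qquad
\mathbf{P}_{f,k} = \sum_{n=\Delta+K-1}^{T-1} w_{f,n}\, \mathbf{y}_{f,n}\, \mathbf{y}_{f,n-\Delta-k}^{\mathsf{H}} \in \mathbb{C}^{C \times C},
\quad k = 0, \ldots, K-1
```

Under this reading, R_f corresponds to correlation_matrix[f] (row index k·C + d, column index l·C + e) and P_{f,k} to correlation_vector[f, k]; the sums are not normalized by the number of frames.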
Parameters:
- Y – Union[torch.Tensor, ComplexTensor] Complex-valued STFT signal with shape (F, C, T).
- inverse_power – torch.Tensor Weighting factor with shape (F, T).
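- taps – int Length of the correlation window, i.e., the number of past frames used for prediction.
- delay – int Prediction delay in frames; the window for target frame n starts at frame n - delay and extends into the past.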
Returns:
- Correlation matrix of shape (F, taps*C, taps*C).
- Correlation vector of shape (F, taps, C, C).
Return type: Tuple[Union[torch.Tensor, ComplexTensor], Union[torch.Tensor, ComplexTensor]]
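To make the index layout concrete, the loop-based sketch below recomputes both outputs with plain Python complex numbers. The function name weighted_correlations_reference and its nested-list inputs (Y[f][c][t], inverse_power[f][t]) are hypothetical, and the index mapping is an assumption drawn from reading the vectorized implementation, not a documented contract. It is meant for checking shapes and tiny cases by hand, not for speed.

```python
from typing import List, Tuple

# Nested-list stand-ins for the tensors (hypothetical layout, for illustration only).
Signal = List[List[List[complex]]]          # Y[f][c][t], shape (F, C, T)
Weights = List[List[float]]                 # inverse_power[f][t], shape (F, T)
Matrix3D = List[List[List[complex]]]        # (F, taps*C, taps*C)
Vector4D = List[List[List[List[complex]]]]  # (F, taps, C, C)


def weighted_correlations_reference(
    Y: Signal, inverse_power: Weights, taps: int, delay: int
) -> Tuple[Matrix3D, Vector4D]:
    """Loop-based sketch of the weighted WPE correlations (assumed index mapping)."""
    num_freqs, num_channels, num_frames = len(Y), len(Y[0]), len(Y[0][0])
    size = taps * num_channels
    # First target frame whose whole window of past frames lies inside the signal.
    first = delay + taps - 1
    R = [[[0j] * size for _ in range(size)] for _ in range(num_freqs)]
    P = [
        [[[0j] * num_channels for _ in range(num_channels)] for _ in range(taps)]
        for _ in range(num_freqs)
    ]
    for f in range(num_freqs):
        for n in range(first, num_frames):
            w = inverse_power[f][n]
            for k in range(taps):
                for d in range(num_channels):
                    # Weighted conjugate of channel d, taken delay + k frames before n.
                    a = w * Y[f][d][n - delay - k].conjugate()
                    for l in range(taps):
                        for e in range(num_channels):
                            R[f][k * num_channels + d][l * num_channels + e] += (
                                a * Y[f][e][n - delay - l]
                            )
                    for e in range(num_channels):
                        # Correlation with target frame n, stored as [f][k][e][d].
                        P[f][k][e][d] += a * Y[f][e][n]
    return R, P
```

For example, with one frequency and one channel, Y = [[[1, 2, 3, 4]]], unit weights, taps=1 and delay=1, the sketch gives R[0][0][0] == 14 (1 + 4 + 9) and P[0][0][0][0] == 20 (1·2 + 2·3 + 3·4). Because the weights are real, each R[f] comes out Hermitian.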
Raises:
- AssertionError – If inverse_power is not two-dimensional, or if its frequency dimension does not match that of Y.
Examples
>>> import torch
>>> from espnet2.enh.layers.wpe import get_correlations
>>> Y = torch.randn(4, 2, 10, dtype=torch.complex64)  # (F, C, T)
>>> inverse_power = torch.rand(4, 10) + 1e-3  # (F, T), positive weights
>>> taps = 5
>>> delay = 2
>>> correlation_matrix, correlation_vector = get_correlations(
...     Y, inverse_power, taps, delay
... )
>>> print(correlation_matrix.shape)  # (F, taps*C, taps*C)
torch.Size([4, 10, 10])
>>> print(correlation_vector.shape)  # (F, taps, C, C)
torch.Size([4, 5, 2, 2])
NOTE
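This function processes a single utterance: there is no separate batch axis, and the frequency axis F acts as the batch dimension, so all frequency bins are processed in parallel. Only the rank of inverse_power and its frequency dimension are checked; its time dimension is assumed to equal T, the number of frames in Y.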