espnet2.enh.layers.dnn_wpe.DNN_WPE
class espnet2.enh.layers.dnn_wpe.DNN_WPE(wtype: str = 'blstmp', widim: int = 257, wlayers: int = 3, wunits: int = 300, wprojs: int = 320, dropout_rate: float = 0.0, taps: int = 5, delay: int = 3, use_dnn_mask: bool = True, nmask: int = 1, nonlinear: str = 'sigmoid', iterations: int = 1, normalization: bool = False, eps: float = 1e-06, diagonal_loading: bool = True, diag_eps: float = 1e-07, mask_flooring: bool = False, flooring_thres: float = 1e-06, use_torch_solver: bool = True)
Bases: Module
DNN_WPE is a deep neural network-based implementation of Weighted Prediction Error (WPE) dereverberation.
This module utilizes a DNN mask estimator to predict the mask for dereverberation and applies the WPE algorithm iteratively to enhance the input signal.
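For orientation, the textbook WPE formulation can be written as follows. This is a sketch under assumed notation, not a transcription of this module's code: y is the observed multichannel STFT vector, G_k are the prediction filter matrices, K corresponds to taps, Δ to delay, and λ is the per-frame power estimate (which, in the mask-based variant, is assumed to be computed from the mask-weighted observation).

```latex
\hat{\mathbf{x}}_{t,f}
  = \mathbf{y}_{t,f}
  - \sum_{k=0}^{K-1} \mathbf{G}_{k,f}^{\mathsf{H}}\, \mathbf{y}_{t-\Delta-k,\,f},
\qquad
\{\mathbf{G}_{k,f}\}
  = \arg\min_{\mathbf{G}} \sum_{t}
    \frac{\lVert \hat{\mathbf{x}}_{t,f} \rVert^{2}}{\lambda_{t,f}}
```

Each additional iteration re-estimates λ from the current enhanced signal and solves for the filters again.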
iterations
Number of iterations for WPE processing.
- Type: int
taps
Number of taps used in the WPE algorithm.
- Type: int
delay
Delay parameter for the WPE algorithm.
- Type: int
eps
Small value to prevent division by zero.
- Type: float
normalization
Whether to normalize the masks.
- Type: bool
use_dnn_mask
Flag to indicate if DNN mask estimation should be used.
- Type: bool
diagonal_loading
Flag to indicate if diagonal loading is used.
- Type: bool
diag_eps
Small value for diagonal loading.
- Type: float
mask_flooring
Flag to indicate if mask flooring is applied.
- Type: bool
flooring_thres
Threshold for mask flooring.
- Type: float
use_torch_solver
Flag to indicate if PyTorch solver is used.
- Type: bool
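The roles of taps and delay can be pictured with a toy delayed linear predictor. The sketch below is plain Python on a single channel and assumes the textbook convention that the prediction for frame t uses frames t - delay - k for k = 0 .. taps - 1. The function name is hypothetical and the filter coefficients are given by hand rather than estimated, so this is not the module's implementation.

```python
from typing import List, Sequence


def delayed_prediction_residual(
    y: Sequence[complex], g: Sequence[complex], delay: int
) -> List[complex]:
    """Subtract a delayed linear prediction from a single-channel sequence.

    Assumed form: x[t] = y[t] - sum_k conj(g[k]) * y[t - delay - k],
    where len(g) plays the role of `taps`. Frames before t = 0 count as zero.
    """
    out = []
    for t in range(len(y)):
        pred = 0j
        for k, coeff in enumerate(g):
            idx = t - delay - k
            if idx >= 0:
                pred += coeff.conjugate() * y[idx]
        out.append(y[t] - pred)
    return out


# Toy signal: an impulse followed by a half-amplitude echo 3 frames later.
y = [1 + 0j, 0j, 0j, 0.5 + 0j, 0j, 0j]
# taps = 2, delay = 3: the first coefficient targets the frame 3 steps back.
x = delayed_prediction_residual(y, g=[0.5 + 0j, 0j], delay=3)
```

With these hand-picked coefficients the echo at frame 3 is cancelled while the direct impulse at frame 0 is untouched, because the delay keeps the predictor from using the most recent frames.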
Parameters:
- wtype (str) – Type of the network (default: “blstmp”).
- widim (int) – Dimension of the input features (default: 257).
- wlayers (int) – Number of layers in the DNN (default: 3).
- wunits (int) – Number of units in each layer (default: 300).
- wprojs (int) – Number of projections (default: 320).
- dropout_rate (float) – Dropout rate for the DNN (default: 0.0).
- taps (int) – Number of taps for WPE (default: 5).
- delay (int) – Delay parameter for WPE (default: 3).
- use_dnn_mask (bool) – Whether to use DNN mask estimation (default: True).
- nmask (int) – Number of masks to be predicted (default: 1).
- nonlinear (str) – Nonlinearity type for DNN (default: “sigmoid”).
- iterations (int) – Number of iterations for WPE (default: 1).
- normalization (bool) – Whether to normalize the masks (default: False).
- eps (float) – Small value for numerical stability (default: 1e-6).
- diagonal_loading (bool) – Whether to apply diagonal loading (default: True).
- diag_eps (float) – Small value for diagonal loading (default: 1e-7).
- mask_flooring (bool) – Whether to apply mask flooring (default: False).
- flooring_thres (float) – Threshold for mask flooring (default: 1e-6).
- use_torch_solver (bool) – Whether to use PyTorch solver (default: True).
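The default widim of 257 is consistent with a one-sided STFT computed with a 512-point FFT. The source does not state where the features come from, so the snippet below is only a sanity check under that assumption, with a hypothetical helper name.

```python
def one_sided_stft_bins(n_fft: int) -> int:
    """Number of frequency bins in a one-sided STFT with FFT size n_fft."""
    return n_fft // 2 + 1


# Assumption: input features are the bins of a 512-point one-sided STFT.
widim = one_sided_stft_bins(512)
```

If the front end uses a different FFT size, widim would need to be changed to match the feature dimension F of the input.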
Returns: A tuple (enhanced, ilens, masks, power) produced by forward():
- enhanced (torch.Tensor or ComplexTensor): The enhanced signal (shape: (B, T, C, F)).
- ilens (torch.LongTensor): Input lengths (shape: (B,)).
- masks (torch.Tensor or ComplexTensor): Predicted masks (shape: (B, T, C, F)).
- power: Calculated power (shape: (B, F, T)).
Examples:
>>> import torch
>>> from espnet2.enh.layers.dnn_wpe import DNN_WPE
>>> model = DNN_WPE()
>>> input_data = torch.randn(2, 100, 1, 257, dtype=torch.complex64)  # (B, T, C, F), complex STFT
>>> ilens = torch.tensor([100, 100])
>>> enhanced, ilens, masks, power = model(input_data, ilens)
Note: This class requires the PyTorch and torch_complex libraries.
forward(data: Tensor | ComplexTensor, ilens: LongTensor) → Tuple[Tensor | ComplexTensor, LongTensor, Tensor | ComplexTensor]
DNN_WPE forward function.
This method performs the forward pass of the DNN_WPE model, which applies DNN-supported weighted prediction error (WPE) dereverberation to the input signal. The input can be either a standard PyTorch tensor or a ComplexTensor, and the function returns the enhanced signal, the input lengths, the estimated masks, and the power estimates.
Notation:
- B: Batch
- C: Channel
- T: Time or sequence length
- F: Frequency or some dimension of the feature vector
Parameters:
- data – Input audio data of shape (B, T, C, F), where B is the batch size, T is the time length, C is the number of channels, and F is the feature dimension.
- ilens – Input lengths of shape (B,) that indicate the valid time steps for each batch element.
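To illustrate what "valid time steps" means for ilens, the sketch below builds a boolean padding mask from a list of lengths in plain Python. The helper name is hypothetical, and how the module itself consumes ilens is not described here; this only shows the common convention that frames at or beyond a sequence's length are padding.

```python
from typing import List, Sequence


def padding_mask(ilens: Sequence[int], max_len: int) -> List[List[bool]]:
    """Return a (B, T) mask that is True at padded (invalid) time steps."""
    return [[t >= length for t in range(max_len)] for length in ilens]


# Two sequences padded to T = 4; the second has only 2 valid frames.
mask = padding_mask([4, 2], max_len=4)
```

Positions marked True would typically be zeroed or ignored so that padding does not leak into the enhanced output.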
Returns:
- enhanced (torch.Tensor or List[torch.Tensor]) – Enhanced audio data of shape (B, T, C, F).
- ilens (torch.LongTensor) – Input lengths of shape (B,) for the enhanced output.
- masks (torch.Tensor or List[torch.Tensor]) – Masks used in the enhancement process, of shape (B, T, C, F).
- power (List[torch.Tensor]) – Power estimates of shape (B, F, T).
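Because enhanced and masks are documented as either a single tensor or a list of tensors, calling code may want to treat both cases uniformly. A minimal sketch with a hypothetical helper follows; the assumption that a list appears when more than one mask is predicted (nmask > 1) is not confirmed by this page.

```python
from typing import List, TypeVar, Union

T = TypeVar("T")


def as_list(value: Union[T, List[T]]) -> List[T]:
    """Wrap a single item in a list; return an existing list unchanged."""
    return value if isinstance(value, list) else [value]


# Works the same whether forward() returned one item or several.
single = as_list("enhanced_0")
several = as_list(["enhanced_0", "enhanced_1"])
```

Iterating over as_list(enhanced) then handles both return shapes without a separate branch at each call site.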
Examples:
>>> import torch
>>> from espnet2.enh.layers.dnn_wpe import DNN_WPE
>>> model = DNN_WPE()
>>> input_data = torch.randn(2, 100, 1, 257, dtype=torch.complex64)  # (B, T, C, F), complex STFT
>>> input_lengths = torch.tensor([100, 80])  # valid frames per batch element
>>> enhanced_output, lengths, masks, power = model(input_data, input_lengths)
Note: The method runs the WPE update for the configured number of iterations (default: 1) and applies mask estimation, normalization, and flooring according to the model’s parameters.
Raises: ValueError – If the input data shape does not match the expected dimensions.
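The interaction of mask flooring, mask normalization, and the eps floor on the power estimate can be sketched for a single frequency bin in plain Python. Everything here is an assumption made for illustration: the helper name, the order of the steps, normalization as division by the sum over time, and averaging over channels. The parameter names mirror the constructor arguments documented above.

```python
from typing import List


def masked_power(
    spec: List[List[complex]],
    mask: List[List[float]],
    eps: float = 1e-6,
    mask_flooring: bool = False,
    flooring_thres: float = 1e-6,
    normalization: bool = False,
) -> List[float]:
    """Channel-averaged, mask-weighted power for one frequency bin.

    spec and mask are indexed as [channel][time].
    """
    if mask_flooring:
        # Clamp mask values from below for numerical stability.
        mask = [[max(m, flooring_thres) for m in row] for row in mask]
    if normalization:
        # Assumed: each channel's mask is scaled to sum to 1 over time.
        normalized = []
        for row in mask:
            total = sum(row)
            normalized.append([m / total for m in row])
        mask = normalized
    n_ch, n_t = len(spec), len(spec[0])
    power = []
    for t in range(n_t):
        avg = sum(abs(spec[c][t]) ** 2 * mask[c][t] for c in range(n_ch)) / n_ch
        power.append(max(avg, eps))  # floor so later divisions stay finite
    return power


spec = [[3 + 4j, 0j], [0j, 1 + 0j]]  # 2 channels, 2 frames
mask = [[1.0, 0.0], [1.0, 1.0]]
power = masked_power(spec, mask)
```

Note that normalization divides by the per-channel mask sum, so an all-zero mask row would fail unless flooring is enabled first; that is one reason the two options are shown in this order.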
predict_mask(data: Tensor | ComplexTensor, ilens: LongTensor) → Tuple[Tensor, LongTensor]
Predict mask for WPE dereverberation.
This method computes the masks used in the Weighted Prediction Error (WPE) dereverberation process. It utilizes a deep neural network (DNN) to estimate the masks from the input data.
Parameters:
- data (torch.complex64/ComplexTensor) – Input tensor of shape (B, T, C, F), where B is the batch size, T is the time length, C is the number of channels, and F is the frequency dimension. The input should be in double precision.
- ilens (torch.Tensor) – A tensor of shape (B,) representing the lengths of each input sequence in the batch.
Returns: A tuple containing:
- masks (torch.Tensor or List[torch.Tensor]): The predicted masks of shape (B, T, C, F) after transposing from (B, F, C, T).
- ilens (torch.Tensor): The unchanged lengths of each input sequence in the batch, of shape (B,).
Return type: Tuple[torch.Tensor, torch.LongTensor]
Examples:
>>> model = DNN_WPE()
>>> data = torch.randn(10, 100, 2, 257, dtype=torch.complex64)
>>> ilens = torch.tensor([100] * 10)
>>> masks, ilens = model.predict_mask(data, ilens)
>>> print(masks.shape) # Output: torch.Size([10, 100, 2, 257])
Note: This method is only available if use_dnn_mask is set to True during the initialization of the DNN_WPE model.
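Since mask prediction depends on use_dnn_mask, calling code can check that attribute before invoking predict_mask. The wrapper below is a hypothetical helper, and the stand-in object exists only to make the sketch runnable without ESPnet. What predict_mask itself does when use_dnn_mask is False is not specified on this page, so the wrapper simply avoids the call.

```python
from types import SimpleNamespace


def predict_mask_or_none(model, data, ilens):
    """Call model.predict_mask only if the model was built with use_dnn_mask=True."""
    if not getattr(model, "use_dnn_mask", False):
        return None, ilens
    return model.predict_mask(data, ilens)


# Stand-in for illustration only; a real DNN_WPE instance would be passed instead.
stub = SimpleNamespace(use_dnn_mask=False)
masks, lengths = predict_mask_or_none(stub, data=None, ilens=[100, 80])
```

Downstream code can then branch on whether masks is None instead of relying on an exception.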