espnet2.enh.layers.mask_estimator.MaskEstimator
class espnet2.enh.layers.mask_estimator.MaskEstimator(type, idim, layers, units, projs, dropout, nmask=1, nonlinear='sigmoid')
Bases: Module
MaskEstimator is a neural network module for estimating masks in audio enhancement tasks. It uses recurrent neural networks (RNNs) for temporal feature extraction and produces one or more masks from the input spectrogram.
type
The RNN architecture string (e.g. ‘blstm’, ‘blstmp’, or ‘vggblstm’); a ‘vgg’ prefix prepends a VGG2L frontend and a trailing ‘p’ selects the projected RNN variant.
- Type: str
nmask
The number of masks to estimate.
- Type: int
nonlinear
The type of nonlinearity applied to the output masks (‘sigmoid’, ‘relu’, ‘tanh’, or ‘crelu’).
- Type: str
brnn
The recurrent neural network module.
- Type: torch.nn.Module
linears
A list of linear layers for mask estimation.
- Type: torch.nn.ModuleList
Parameters:
- type (str) – The type of RNN architecture (e.g. ‘blstm’, ‘blstmp’, ‘vggblstm’).
- idim (int) – Input dimension (number of features).
- layers (int) – Number of RNN layers.
- units (int) – Number of units in each RNN layer.
- projs (int) – Number of projected features after the RNN.
- dropout (float) – Dropout rate for the RNN.
- nmask (int, optional) – Number of masks to estimate (default is 1).
- nonlinear (str, optional) – Nonlinearity to apply to the output masks (default is ‘sigmoid’).
Returns: A tuple containing the estimated masks and the input lengths.
Return type: Tuple[Tuple[torch.Tensor, …], torch.LongTensor]
Raises: ValueError – If the specified nonlinear activation is not supported.
Examples
>>> import torch
>>> mask_estimator = MaskEstimator(type='vggblstm', idim=64, layers=3,
...                                units=128, projs=64, dropout=0.1)
>>> xs = torch.randn(8, 64, 2, 100) # Example input (B, F, C, T)
>>> ilens = torch.tensor([100] * 8) # Example input lengths
>>> masks, lengths = mask_estimator(xs, ilens)
>>> print(masks)  # Tuple of estimated masks, each of shape (B, F, C, T)
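As a plain-Python sketch (not the actual ESPnet implementation), the mapping from the nonlinear argument to an activation, and the ValueError for unsupported names, might look like this; ‘crelu’ is omitted here because it requires tensor operations:

```python
import math

# Hypothetical helper mirroring how the ``nonlinear`` argument could be
# resolved; the real module applies these to tensors, not scalar floats.
def resolve_nonlinear(name):
    table = {
        "sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
        "relu": lambda x: max(0.0, x),
        "tanh": math.tanh,
    }
    if name not in table:
        # Matches the documented behavior: unsupported names raise ValueError.
        raise ValueError(f"Not supporting nonlinear={name}")
    return table[name]

sigmoid = resolve_nonlinear("sigmoid")
print(sigmoid(0.0))  # sigmoid maps 0 to 0.5
```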
NOTE
The input tensor xs should have dimensions (B, F, C, T), where B is the batch size, F is the number of frequency bins, C is the number of channels, and T is the number of time frames.
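To make the layout concrete, here is a small stand-in (plain Python lists, hypothetical shapes) showing how a (B, F, C, T) input decomposes into B × C time-major sequences of F-dimensional frames, the per-channel view an RNN over time would consume:

```python
# Hypothetical small shapes: batch, frequency bins, channels, time frames.
B, F, C, T = 2, 4, 2, 3

# xs[b][f][c][t]: a nested-list stand-in for the (B, F, C, T) tensor.
xs = [[[[0.0 for _ in range(T)] for _ in range(C)]
       for _ in range(F)] for _ in range(B)]

# One (T, F) time-major sequence per (batch, channel) pair.
sequences = [
    [[xs[b][f][c][t] for f in range(F)] for t in range(T)]
    for b in range(B) for c in range(C)
]
print(len(sequences), len(sequences[0]), len(sequences[0][0]))  # 4 3 4
```

That is, B × C = 4 independent sequences, each with T = 3 steps of F = 4 features.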
forward(xs: Tensor | ComplexTensor, ilens: LongTensor) → Tuple[Tuple[Tensor, ...], LongTensor]
Mask estimator forward function.
- Parameters:
- xs – (B, F, C, T)
- ilens – (B,)
- Returns:
  - masks – A tuple of the estimated masks, each of shape (B, F, C, T).
  - ilens – (B,)
- Return type: Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]
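A hedged usage sketch (plain floats, no torch; values are illustrative): applying one of the returned masks to the input spectrogram is an elementwise product, shown here for a single (batch, frequency, channel) slice over time:

```python
# Hypothetical values for one (b, f, c) slice across T = 3 frames.
spec = [0.8, 0.1, 0.5]   # input magnitudes
mask = [0.9, 0.2, 1.0]   # sigmoid masks lie in [0, 1]

# The elementwise product suppresses frames the mask scores as noise.
enhanced = [s * m for s, m in zip(spec, mask)]
print(enhanced)
```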