espnet2.enh.layers.bsrnn.MaskDecoder
class espnet2.enh.layers.bsrnn.MaskDecoder(freq_dim, subbands, channels=128, num_spk=1, norm_type='GN')
Bases: Module
Mask Decoder for band-split RNN-based speech enhancement.
This class implements a mask decoder that maps per-subband features to output masks and residuals, using a separate MLP for each subband. Within the BSRNN architecture, the model uses these estimated masks and residuals to enhance the speech signal.
subbands
Number of frequency bins in each subband, one entry per subband.
- Type: tuple
freq_dim
Total frequency dimension; must equal the sum of subbands.
- Type: int
num_spk
Number of speakers to generate outputs for.
- Type: int
mlp_mask
List of MLPs for generating masks for each subband.
- Type: nn.ModuleList
mlp_residual
List of MLPs for generating residuals for each subband.
- Type: nn.ModuleList
Parameters:
- freq_dim (int) – Total frequency dimension; must equal the sum of subbands.
- subbands (tuple) – Number of frequency bins in each subband, one entry per subband.
- channels (int) – Number of channels in the input tensor (default is 128).
- num_spk (int) – Number of speakers to generate masks and residuals for (default is 1).
- norm_type (str) – Type of normalization layer to be used (default is 'GN').
Returns: None
####### Examples
>>> import torch
>>> from espnet2.enh.layers.bsrnn import MaskDecoder
>>> # freq_dim must equal sum(subbands): 5 + 4 + 4 + 4 + 4 = 21
>>> decoder = MaskDecoder(freq_dim=21, subbands=(5, 4, 4, 4, 4),
...                       channels=128, num_spk=1)
>>> input_tensor = torch.randn(10, 128, 20, 5)  # (batch, channels, time, subbands)
>>> masks, residuals = decoder(input_tensor)
>>> masks.shape  # Should be (10, 1, 20, 21, 2)
>>> residuals.shape  # Should be (10, 1, 20, 21, 2)
- Raises: AssertionError – If freq_dim does not equal the sum of subbands.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
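The constraint that freq_dim equal the sum of subbands means the subband widths tile the frequency axis without gaps or overlap. The sketch below illustrates that reading in plain Python; `subband_slices` is a hypothetical helper written for this page, not part of ESPnet, and it only mirrors the documented assertion.

```python
from itertools import accumulate


def subband_slices(freq_dim, subbands):
    """Return the (start, stop) frequency-bin range covered by each subband.

    Hypothetical helper for illustration only; not an ESPnet function.
    """
    # Same check the class documents: the widths must cover the whole axis.
    assert freq_dim == sum(subbands), (freq_dim, subbands)
    stops = list(accumulate(subbands))  # running end index of each band
    starts = [0] + stops[:-1]           # each band starts where the last ended
    return list(zip(starts, stops))


print(subband_slices(21, (5, 4, 4, 4, 4)))
# [(0, 5), (5, 9), (9, 13), (13, 17), (17, 21)]
```

Under this reading, passing freq_dim=481 with the same five widths would fail the assertion, because the widths sum to 21.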
forward(x)
MaskDecoder forward.
This method performs the forward pass of the mask decoder: the features of each subband are passed through the corresponding MLPs in mlp_mask and mlp_residual, and the per-subband outputs are combined along the frequency axis to produce the masks and residuals.
- Parameters: x (torch.Tensor) – Input tensor of shape (B, N, T, K), where B is the batch size, N is the number of channels, T is the time dimension, and K is the number of subbands.
- Returns: A tuple (m, r), where m is the output mask and r is the output residual. Each has shape (B, num_spk, T, F, 2), where num_spk is the number of speakers, T is the time dimension, F is the frequency dimension, and 2 represents the real and imaginary parts of the complex output.
- Return type: tuple of (torch.Tensor, torch.Tensor)
####### Examples
>>> decoder = MaskDecoder(freq_dim=21, subbands=(5, 4, 4, 4, 4),
...                       channels=16, num_spk=2)
>>> input_tensor = torch.randn(8, 16, 100, 5)  # Example input
>>> masks, residuals = decoder(input_tensor)
>>> print(masks.shape)  # Output shape will be (8, 2, 100, 21, 2)
>>> print(residuals.shape)  # Output shape will be (8, 2, 100, 21, 2)
NOTE
The input tensor must have the correct shape: its channel dimension (N) must match the channels argument given at construction. This method takes only x. Sampling-rate handling (the fs argument, which truncates the input to the effective frequency subbands) belongs to the forward method of the enclosing BSRNN model, not to this method.
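This page says the decoder estimates a mask and a residual, each with a trailing dimension of size 2 for the real and imaginary parts, but it does not say how they are combined with the input spectrum. A common formulation, assumed here and not confirmed by this page, is a complex multiplication by the mask followed by adding the residual. The sketch below shows that rule for a single time-frequency bin in plain Python; `apply_mask_and_residual` is a hypothetical name.

```python
def apply_mask_and_residual(spec, mask, residual):
    """Combine one time-frequency bin with its mask and residual.

    Each argument is a (real, imag) pair, matching the trailing dimension
    of size 2. Hypothetical helper; the rule out = mask * spec + residual
    is an assumption, not taken from this page.
    """
    s, m, r = complex(*spec), complex(*mask), complex(*residual)
    out = m * s + r  # complex multiply, then add the residual
    return (out.real, out.imag)


print(apply_mask_and_residual((1.0, 2.0), (0.5, 0.0), (0.0, -1.0)))
# (0.5, 0.0)
```

On full tensors the same rule would be applied elementwise over every speaker, frame, and frequency bin.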