espnet2.enh.layers.bsrnn.BSRNN
class espnet2.enh.layers.bsrnn.BSRNN(input_dim=481, num_channel=16, num_layer=6, target_fs=48000, causal=True, num_spk=1, norm_type='GN')
Bases: Module
Band-Split RNN (BSRNN) for high fidelity speech enhancement.
This model implements a Band-Split RNN architecture for monaural speech enhancement. It uses RNNs to model the signal along both the time and frequency axes, with the aim of improving speech quality in noisy conditions.
References
[1] J. Yu, H. Chen, Y. Luo, R. Gu, and C. Weng, “High fidelity speech enhancement with band-split RNN,” in Proc. ISCA Interspeech, 2023. https://isca-speech.org/archive/interspeech_2023/yu23b_interspeech.html
[2] J. Yu and Y. Luo, “Efficient monaural speech enhancement with universal sample rate band-split RNN,” in Proc. ICASSP, 2023. https://ieeexplore.ieee.org/document/10096020
- Parameters:
- input_dim (int) – Maximum number of frequency bins corresponding to target_fs.
- num_channel (int) – Embedding dimension of each time-frequency bin.
- num_layer (int) – Number of time and frequency RNN layers.
- target_fs (int) – Maximum sampling frequency supported by the model.
- causal (bool) – Whether to adopt causal processing. If True, LSTM will be used instead of BLSTM for time modeling.
- num_spk (int) – Number of outputs to be generated.
- norm_type (str) – Type of normalization layer (cfLN / cLN / BN / GN).
- Returns: Output tensor of shape (B, num_spk, T, F, 2).
- Return type: out (torch.Tensor)
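The relation between input_dim and target_fs can be illustrated with a small calculation. The sketch below assumes a one-sided STFT whose FFT size equals a 20 ms analysis window; the window length and the helper name `num_freq_bins` are illustrative assumptions, not part of the BSRNN API.

```python
def num_freq_bins(fs: int, window_ms: float = 20.0) -> int:
    """Number of one-sided STFT bins when n_fft equals the window length in samples.

    The 20 ms default is an assumption made for this sketch.
    """
    n_fft = int(round(fs * window_ms / 1000))
    return n_fft // 2 + 1


# With target_fs=48000 this reproduces the default input_dim of 481.
print(num_freq_bins(48000))  # 481
# A 16 kHz signal analysed with the same window duration yields fewer bins.
print(num_freq_bins(16000))  # 161
```

Under this assumption, a model built with the defaults (input_dim=481, target_fs=48000) corresponds to an FFT size of 960 samples.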
####### Examples
>>> import torch
>>> from espnet2.enh.layers.bsrnn import BSRNN
>>> model = BSRNN(input_dim=481, num_channel=16, num_layer=6)
>>> input_tensor = torch.randn(8, 100, 481, 2)  # (B, T, F, 2) with batch size 8
>>> output = model(input_tensor)
>>> print(output.shape)  # (B, num_spk, T, F, 2) with num_spk=1
torch.Size([8, 1, 100, 481, 2])
NOTE
The input tensor shape is expected to be (B, T, F, 2) where:
- B is the batch size
- T is the time dimension
- F is the frequency dimension
- 2 represents the real and imaginary parts of the complex signal.
- Raises: ValueError – If an unsupported normalization type is provided.
forward(x, fs=None)
BSRNN forward pass.
This method runs the forward pass of the Band-Split RNN (BSRNN) on the input tensor and returns the output tensor. The input is expected to have shape (B, T, F, 2); if a sampling rate is given, the input is truncated so that only the effective frequency subbands are processed.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (B, T, F, 2), where B is the batch size, T is the time dimension, F is the frequency dimension, and 2 represents the real and imaginary parts of the complex signal.
- fs (int, optional) – Sampling rate of the input signal. If not None, the input signal will be truncated to only process the effective frequency subbands. If None, the input signal is assumed to be already truncated to only contain effective frequency subbands.
- Returns: Output tensor of shape (B, num_spk, T, F, 2), where num_spk is the number of speakers to be generated.
- Return type: out (torch.Tensor)
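To make the idea of "effective frequency subbands" concrete, the sketch below estimates how many frequency bins lie at or below the Nyquist frequency of a given fs. It assumes the bins are spaced uniformly from 0 Hz to target_fs / 2; the helper `effective_bins` is hypothetical and is not how the library necessarily performs the truncation.

```python
def effective_bins(fs: int, target_fs: int = 48000, input_dim: int = 481) -> int:
    """Count the bins at or below fs / 2, assuming uniform spacing up to target_fs / 2.

    Hypothetical helper for illustration only.
    """
    if fs <= 0 or fs > target_fs:
        raise ValueError(f"fs must be in (0, {target_fs}], got {fs}")
    # Bin spacing is (target_fs / 2) / (input_dim - 1), i.e. 50 Hz for the defaults.
    # Integer arithmetic avoids floating-point rounding.
    return fs * (input_dim - 1) // target_fs + 1


for rate in (8000, 16000, 44100, 48000):
    print(rate, effective_bins(rate))
# 8000 81
# 16000 161
# 44100 442
# 48000 481
```

With the default configuration, a 16 kHz input would therefore occupy only the first 161 of the 481 bins under this assumption.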
####### Examples
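>>> import torch
>>> from espnet2.enh.layers.bsrnn import BSRNN
>>> model = BSRNN()
>>> input_tensor = torch.randn(8, 100, 481, 2)  # Example input of shape (B, T, F, 2)
>>> output = model(input_tensor, fs=48000)
>>> print(output.shape)  # Example output shape (B, num_spk, T, F, 2)
torch.Size([8, 1, 100, 481, 2])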
NOTE
The input tensor must have shape (B, T, F, 2). The forward pass applies several layers of normalization and RNN processing, followed by a mask decoding step that generates the output.
- Raises: ValueError – If the input tensor shape does not match the expected dimensions.
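The shape contract described above, input (B, T, F, 2) and output (B, num_spk, T, F, 2), can be sketched as a plain-Python pre-check that a caller might run before invoking the model. The function `expected_output_shape` is a hypothetical helper written for this illustration; it does not reproduce the checks performed inside BSRNN.forward.

```python
from typing import Sequence, Tuple


def expected_output_shape(
    input_shape: Sequence[int], num_spk: int = 1
) -> Tuple[int, int, int, int, int]:
    """Map an input shape (B, T, F, 2) to the documented output shape.

    Hypothetical helper: raises ValueError when the input is not 4-D
    or its last dimension is not 2 (real and imaginary parts).
    """
    if len(input_shape) != 4 or input_shape[-1] != 2:
        raise ValueError(
            f"expected input of shape (B, T, F, 2), got {tuple(input_shape)}"
        )
    batch, frames, bins, _ = input_shape
    return (batch, num_spk, frames, bins, 2)


print(expected_output_shape((8, 100, 481, 2)))  # (8, 1, 100, 481, 2)
print(expected_output_shape((8, 100, 481, 2), num_spk=2))  # (8, 2, 100, 481, 2)
```

With a real tensor, the same check could be applied to `tuple(x.shape)` before calling the model.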