espnet2.enh.layers.dc_crn.DC_CRN
class espnet2.enh.layers.dc_crn.DC_CRN(input_dim, input_channels: List = [2, 16, 32, 64, 128, 256], enc_hid_channels=8, enc_kernel_size=(1, 3), enc_padding=(0, 1), enc_last_kernel_size=(1, 4), enc_last_stride=(1, 2), enc_last_padding=(0, 1), enc_layers=5, skip_last_kernel_size=(1, 3), skip_last_stride=(1, 1), skip_last_padding=(0, 1), glstm_groups=2, glstm_layers=2, glstm_bidirectional=False, glstm_rearrange=False, output_channels=2)
Bases: Module
Densely-Connected Convolutional Recurrent Network (DC-CRN).
This class implements a Densely-Connected Convolutional Recurrent Network as described in the paper by Tan et al. [1]. The network architecture includes multiple densely connected blocks, grouped LSTM layers, and skip pathways to enhance the speech signals.
Reference: Tan et al. “Deep Learning Based Real-Time Speech Enhancement for Dual-Microphone Mobile Phones”. https://web.cse.ohio-state.edu/~wang.77/papers/TZW.taslp21.pdf
Args:
- input_dim (int): Input feature dimension.
- input_channels (list): Number of input channels for the stacked DenselyConnectedBlock layers. Its length should equal the number of DenselyConnectedBlock layers. It is recommended to use an even number of channels to avoid an AssertionError when glstm_bidirectional=True.
- enc_hid_channels (int): Common number of intermediate channels for all DenselyConnectedBlock of the encoder.
- enc_kernel_size (tuple): Common kernel size for all DenselyConnectedBlock of the encoder.
- enc_padding (tuple): Common padding for all DenselyConnectedBlock of the encoder.
- enc_last_kernel_size (tuple): Common kernel size for the last Conv layer in all DenselyConnectedBlock of the encoder.
- enc_last_stride (tuple): Common stride for the last Conv layer in all DenselyConnectedBlock of the encoder.
- enc_last_padding (tuple): Common padding for the last Conv layer in all DenselyConnectedBlock of the encoder.
- enc_layers (int): Common total number of Conv layers for all DenselyConnectedBlock layers of the encoder.
- skip_last_kernel_size (tuple): Common kernel size for the last Conv layer in all DenselyConnectedBlock of the skip pathways.
- skip_last_stride (tuple): Common stride for the last Conv layer in all DenselyConnectedBlock of the skip pathways.
- skip_last_padding (tuple): Common padding for the last Conv layer in all DenselyConnectedBlock of the skip pathways.
- glstm_groups (int): Number of groups in each Grouped LSTM layer.
- glstm_layers (int): Number of Grouped LSTM layers.
- glstm_bidirectional (bool): Whether to use BLSTM or unidirectional LSTM in Grouped LSTM layers.
- glstm_rearrange (bool): Whether to apply the rearrange operation after each grouped LSTM layer.
- output_channels (int): Number of output channels (must be an even number to recover both real and imaginary parts).
Raises: AssertionError: If the number of output channels is not even.
Examples:

```python
model = DC_CRN(input_dim=128, output_channels=2)
input_tensor = torch.randn(1, 2, 42, 128)
output = model(input_tensor)
print(output.shape)  # Output shape: (1, 2, output_channels, 42, 128)
```
Note: Ensure that input_channels is set appropriately based on the architecture requirements.
[1]: Tan et al. “Deep Learning Based Real-Time Speech Enhancement for Dual-Microphone Mobile Phones”. https://web.cse.ohio-state.edu/~wang.77/papers/TZW.taslp21.pdf
Densely-Connected Convolutional Recurrent Network (DC-CRN).
Reference: Fig. 3 and Section III-B in [1]
- Parameters:
- input_dim (int) – input feature dimension
- input_channels (list) – number of input channels for the stacked DenselyConnectedBlock layers. Its length should equal the number of DenselyConnectedBlock layers. It is recommended to use an even number of channels to avoid an AssertionError when glstm_bidirectional=True.
- enc_hid_channels (int) – common number of intermediate channels for all DenselyConnectedBlock of the encoder
- enc_kernel_size (tuple) – common kernel size for all DenselyConnectedBlock of the encoder
- enc_padding (tuple) – common padding for all DenselyConnectedBlock of the encoder
- enc_last_kernel_size (tuple) – common kernel size for the last Conv layer in all DenselyConnectedBlock of the encoder
- enc_last_stride (tuple) – common stride for the last Conv layer in all DenselyConnectedBlock of the encoder
- enc_last_padding (tuple) – common padding for the last Conv layer in all DenselyConnectedBlock of the encoder
- enc_layers (int) – common total number of Conv layers for all DenselyConnectedBlock layers of the encoder
- skip_last_kernel_size (tuple) – common kernel size for the last Conv layer in all DenselyConnectedBlock of the skip pathways
- skip_last_stride (tuple) – common stride for the last Conv layer in all DenselyConnectedBlock of the skip pathways
- skip_last_padding (tuple) – common padding for the last Conv layer in all DenselyConnectedBlock of the skip pathways
- glstm_groups (int) – number of groups in each Grouped LSTM layer
- glstm_layers (int) – number of Grouped LSTM layers
- glstm_bidirectional (bool) – whether to use BLSTM or unidirectional LSTM in Grouped LSTM layers
- glstm_rearrange (bool) – whether to apply the rearrange operation after each grouped LSTM layer
- output_channels (int) – number of output channels (must be an even number to recover both real and imaginary parts)
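The following is a minimal, hypothetical construction sketch based on the parameter list above. The channel stack, input_dim, and tensor shapes are illustrative only, and the shapes assume that the frequency dimension F of the input equals input_dim (as in the forward example below).

```python
import torch
from espnet2.enh.layers.dc_crn import DC_CRN

# Hypothetical configuration: a smaller encoder than the default
# [2, 16, 32, 64, 128, 256] channel stack, with all channel counts even.
model = DC_CRN(
    input_dim=256,                   # frequency dimension of the input spectrum
    input_channels=[2, 16, 32, 64],  # one entry per stacked DenselyConnectedBlock stage
    glstm_groups=2,
    glstm_layers=2,
    output_channels=2,               # even: real and imaginary parts
)

x = torch.randn(4, 2, 100, 256)      # (B, input_channels[0], T, F), assuming F == input_dim
out = model(x)                       # (B, 2, output_channels, T, F)
print(out.shape)                     # expected: torch.Size([4, 2, 2, 100, 256])
```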
forward(x)
DC-CRN forward.
This method defines the forward pass of the Densely-Connected Convolutional Recurrent Network (DC-CRN). It processes the input tensor through several convolutional layers followed by a grouped LSTM layer and then through deconvolutional layers to produce the final output. The input should consist of concatenated real and imaginary spectrum features.
- Parameters: x (torch.Tensor) – Concatenated real and imaginary spectrum features of shape (B, input_channels[0], T, F), where:
- B is the batch size
- input_channels[0] is the number of input channels
- T is the temporal dimension
- F is the frequency dimension
- Returns: Output tensor of shape (B, 2, output_channels, T, F), where:
- 2 corresponds to the real and imaginary parts
- output_channels is the number of output channels
- Return type: out (torch.Tensor)
Examples
>>> model = DC_CRN(input_dim=256)
>>> input_tensor = torch.randn(8, 2, 42, 256)  # (B, C, T, F) with F == input_dim
>>> output = model(input_tensor)
>>> print(output.shape)  # Output shape: (8, 2, output_channels, 42, 256)
NOTE
The number of output channels must be an even number to recover both real and imaginary parts.
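As a post-processing sketch (not part of this class), the size-2 axis of the output can be interpreted as the real and imaginary parts, following the return description above, and recombined into a complex-valued tensor; whether the result is used as a complex mask or a complex spectrum depends on the surrounding model and is not specified here.

```python
import torch
from espnet2.enh.layers.dc_crn import DC_CRN

model = DC_CRN(input_dim=256)
x = torch.randn(8, 2, 42, 256)            # (B, input_channels[0], T, F)
out = model(x)                            # (B, 2, output_channels, T, F)

# Split dim 1 into real/imaginary parts and build a complex-valued tensor.
real, imag = out[:, 0], out[:, 1]         # each: (B, output_channels, T, F)
complex_out = torch.complex(real, imag)   # (B, output_channels, T, F), complex dtype
print(complex_out.shape, complex_out.dtype)
```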