espnet2.enh.layers.dc_crn.GLSTM
class espnet2.enh.layers.dc_crn.GLSTM(hidden_size=1024, groups=2, layers=2, bidirectional=False, rearrange=False)
Bases: Module
Grouped LSTM.
This class implements a Grouped LSTM (GLSTM), a stack of grouped LSTM layers following Gao et al. (2018). Each grouped LSTM layer splits its hidden_size features into groups equal parts and runs a separate, smaller LSTM on each part. This needs fewer parameters and less computation than a single LSTM with the same total hidden size.
Reference: Efficient Sequence Learning with Group Recurrent Networks; Gao et al., 2018
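The grouping idea can be illustrated with a minimal PyTorch sketch of a single grouped LSTM layer. GroupedLSTMLayer and num_params are hypothetical names written for this illustration, not ESPnet's implementation. The sketch assumes a (B, T, hidden_size) input, splits it with torch.chunk, and compares its parameter count with one full-size nn.LSTM.

```python
import torch
import torch.nn as nn


class GroupedLSTMLayer(nn.Module):
    """Hypothetical single grouped LSTM layer (illustration only, not ESPnet code)."""

    def __init__(self, hidden_size: int, groups: int):
        super().__init__()
        assert hidden_size % groups == 0, "hidden_size must be divisible by groups"
        self.groups = groups
        size = hidden_size // groups
        # One independent, smaller LSTM per feature group.
        self.lstms = nn.ModuleList(
            [nn.LSTM(size, size, num_layers=1, batch_first=True) for _ in range(groups)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, hidden_size) -> `groups` chunks of (B, T, hidden_size // groups)
        chunks = torch.chunk(x, self.groups, dim=-1)
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output sequence.
        outs = [lstm(chunk)[0] for lstm, chunk in zip(self.lstms, chunks)]
        # Concatenate the group outputs back to (B, T, hidden_size).
        return torch.cat(outs, dim=-1)


def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


grouped = GroupedLSTMLayer(hidden_size=1024, groups=2)
full = nn.LSTM(1024, 1024, batch_first=True)

x = torch.randn(4, 10, 1024)  # (B, T, hidden_size)
y = grouped(x)
print(y.shape)  # torch.Size([4, 10, 1024])
print(num_params(grouped), num_params(full))  # 4202496 8396800
```

With groups=2 the grouped layer has roughly half the parameters of the full LSTM. Each LSTM's weight matrices shrink quadratically with its size, while the number of LSTMs grows only linearly with groups.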
Args:
- hidden_size (int) – Total hidden size of all LSTMs in each grouped LSTM layer, i.e., the hidden size of each LSTM is hidden_size // groups (split evenly between the two directions when bidirectional=True).
- groups (int) – Number of LSTMs in each grouped LSTM layer.
- layers (int) – Number of grouped LSTM layers.
- bidirectional (bool) – Whether to use bidirectional LSTMs (BLSTM) instead of unidirectional LSTMs.
- rearrange (bool) – Whether to apply the rearrange operation after each grouped LSTM layer.
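The docs do not define the rearrange operation. One plausible reading, assumed here and in the spirit of the representation rearrangement in Gao et al. (2018), is a channel-shuffle-style interleave. The features are viewed as (groups, hidden_size // groups), those two axes are swapped, and the result is flattened again, so that each group in the next layer sees features from every group of the previous one. rearrange is a hypothetical helper, not ESPnet's function.

```python
import torch


def rearrange(x: torch.Tensor, groups: int) -> torch.Tensor:
    """Hypothetical helper: interleave features across groups.

    x has shape (B, T, H); H is viewed as (groups, H // groups), the two
    axes are swapped, and the result is flattened back to (B, T, H).
    """
    B, T, H = x.shape
    return x.reshape(B, T, groups, H // groups).transpose(-1, -2).reshape(B, T, H)


x = torch.arange(6).reshape(1, 1, 6)  # features 0..2 form group 0, 3..5 form group 1
y = rearrange(x, groups=2)
print(y.tolist())  # [[[0, 3, 1, 4, 2, 5]]]

# Splitting y into 2 groups now gives [0, 3, 1] and [4, 2, 5]:
# each new group mixes features that came from both original groups.
new_groups = [c.tolist() for c in torch.chunk(y, 2, dim=-1)]
```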
Raises:
- AssertionError – If hidden_size is not divisible by groups, or if bidirectional=True and the hidden size of each LSTM (hidden_size_t = hidden_size // groups) is odd.
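The two constraints under Raises reduce to simple size arithmetic. The helper below is hypothetical. It raises ValueError rather than the AssertionError that GLSTM itself raises, and it assumes that a bidirectional LSTM gives each direction half of hidden_size // groups, so that the concatenated forward and backward outputs keep that size.

```python
from typing import Tuple


def glstm_lstm_sizes(hidden_size: int, groups: int, bidirectional: bool) -> Tuple[int, int]:
    """Hypothetical helper mirroring the constraints listed under Raises.

    Returns (input size of each LSTM, hidden size per direction of each LSTM).
    """
    if hidden_size % groups != 0:
        raise ValueError(f"hidden_size={hidden_size} is not divisible by groups={groups}")
    hidden_size_t = hidden_size // groups  # size handled by each LSTM in a group
    if bidirectional:
        if hidden_size_t % 2 != 0:
            raise ValueError(f"hidden_size_t={hidden_size_t} must be even for a BLSTM")
        # Assumed: each direction gets half, so the concatenated output is hidden_size_t.
        return hidden_size_t, hidden_size_t // 2
    return hidden_size_t, hidden_size_t


print(glstm_lstm_sizes(1024, 2, bidirectional=False))  # (512, 512)
print(glstm_lstm_sizes(1024, 2, bidirectional=True))   # (512, 256)
```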
Examples:
```python
>>> import torch
>>> from espnet2.enh.layers.dc_crn import GLSTM
>>> glstm = GLSTM(hidden_size=1024, groups=2, layers=2,
...               bidirectional=True, rearrange=True)
>>> input_tensor = torch.randn(16, 1024, 10, 1)  # (B, C, T, D)
>>> output = glstm(input_tensor)
>>> output.shape  # output shape matches input shape
torch.Size([16, 1024, 10, 1])
```
forward(x)
Grouped LSTM forward.
This method transposes the input to (B, T, C, D) and flattens it to a sequence of C * D features per time step. In every grouped LSTM layer, those features are split into groups chunks and each chunk is processed by its own LSTM (with the rearrange operation applied when rearrange=True). The result is reshaped back to the input layout.
- Parameters: x (torch.Tensor) – Input tensor with shape (B, C, T, D), where C * D must equal hidden_size:
- B: Batch size
- C: Number of channels
- T: Temporal dimension
- D: Feature dimension
- Returns: The output tensor, with the same shape (B, C, T, D) as the input.
- Return type: out (torch.Tensor)
Examples
>>> import torch
>>> from espnet2.enh.layers.dc_crn import GLSTM
>>> glstm = GLSTM(hidden_size=1024, groups=2, layers=2)
>>> input_tensor = torch.randn(16, 256, 42, 4)  # (B, C, T, D), C * D = 1024
>>> output = glstm(input_tensor)
>>> print(output.shape)
torch.Size([16, 256, 42, 4])
NOTE
The product C * D must equal hidden_size, because each LSTM in a group expects hidden_size // groups input features. Within DC-CRN, this module processes the encoder output between the convolutional blocks and the transposed convolutional blocks.
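GLSTM returns a tensor with the same (B, C, T, D) layout it receives (as in the class example above), while its LSTMs need one feature vector per time step. Below is a minimal sketch of that round trip, assuming the channel and feature axes are flattened in (B, T, C * D) order. to_sequence and from_sequence are hypothetical helpers, and the grouped LSTM layers themselves are omitted.

```python
import torch


def to_sequence(x: torch.Tensor) -> torch.Tensor:
    """(B, C, T, D) -> (B, T, C * D): one feature vector per time step."""
    B, C, T, D = x.shape
    return x.transpose(1, 2).reshape(B, T, C * D)


def from_sequence(y: torch.Tensor, channels: int) -> torch.Tensor:
    """(B, T, C * D) -> (B, C, T, D): inverse of to_sequence."""
    B, T, H = y.shape
    return y.reshape(B, T, channels, H // channels).transpose(1, 2)


hidden_size = 1024
x = torch.randn(16, 256, 42, 4)  # (B, C, T, D) with C * D == hidden_size
seq = to_sequence(x)
assert seq.shape[-1] == hidden_size  # the feature size each grouped LSTM layer sees
# ... the grouped LSTM layers would process `seq` here (omitted) ...
out = from_sequence(seq, channels=x.size(1))
print(out.shape)            # torch.Size([16, 256, 42, 4])
print(torch.equal(out, x))  # True: the two reshapes are exact inverses
```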