espnet2.enh.layers.skim.SegLSTM
class espnet2.enh.layers.skim.SegLSTM(input_size, hidden_size, dropout=0.0, bidirectional=False, norm_type='cLN')
Bases: Module
The Seg-LSTM of SkiM.
This class implements the segment-level LSTM (Seg-LSTM) of the SkiM architecture for low-latency, real-time continuous speech separation. It processes input features in segments. The LSTM's hidden and cell states are passed in through the hc argument of forward and returned along with the output.
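Here is a concrete picture of "processing in segments". SkiM-style models split a long (batch, T, N) feature sequence into S segments of length K and fold the segment axis into the batch axis. This is where the B*S in the state shapes of forward comes from. The sketch below uses plain Python lists instead of tensors. The helper fold_segments is hypothetical, and non-overlapping segments with zero padding of the last one are assumptions, not a description of ESPnet's code.

```python
from typing import List

Frame = List[float]  # one frame with N features


def fold_segments(batch: List[List[Frame]], seg_len: int) -> List[List[Frame]]:
    """Hypothetical helper: (B, T, N) -> (B*S, K, N) with K = seg_len.

    Assumes non-overlapping segments; the tail of each (non-empty) sequence
    is zero-padded so that T becomes a multiple of seg_len.
    """
    folded: List[List[Frame]] = []
    for seq in batch:
        n_feat = len(seq[0])
        pad = (-len(seq)) % seg_len  # frames needed to reach a multiple of K
        padded = seq + [[0.0] * n_feat for _ in range(pad)]
        # Each K-frame slice becomes its own "batch" entry for the LSTM.
        for start in range(0, len(padded), seg_len):
            folded.append(padded[start:start + seg_len])
    return folded


# Two utterances (B=2) of T=5 frames with N=1 feature, segment length K=2.
x = [[[float(t)] for t in range(5)], [[float(10 + t)] for t in range(5)]]
segments = fold_segments(x, seg_len=2)
print(len(segments), len(segments[0]))  # 6 2  (B*S = 2*3 segments of K = 2 frames)
```

Each folded segment is then an independent sequence from the LSTM's point of view. Information across segments has to travel through the hidden and cell states.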
Parameters:
- input_size (int) – Dimension of the input feature. The input should have shape (batch, seq_len, input_size).
- hidden_size (int) – Dimension of the hidden state.
- dropout (float , optional) – Dropout ratio. Default is 0.
- bidirectional (bool , optional) – Whether the LSTM layers are bidirectional. Default is False.
- norm_type (str , optional) – Normalization type, either ‘gLN’ or ‘cLN’. ‘cLN’ is for causal implementation. Default is ‘cLN’.
Returns: A tuple (output, (h, c)), where output (torch.Tensor) is the processed output of shape (batch, seq_len, input_size) and (h, c) (tuple) are the hidden and cell states of the LSTM.
Return type: tuple
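The practical difference between the two norm_type options can be sketched as follows. This assumes 'cLN' denotes channel-wise layer normalization (mean and variance computed per frame, over channels) and 'gLN' denotes global layer normalization (statistics pooled over channels and time). Learnable gain and bias are omitted, and the function names are hypothetical. Only the channel-wise variant leaves earlier frames untouched when a future frame changes, which is what makes it usable in a causal implementation.

```python
import math
from typing import List

Seq = List[List[float]]  # (T, N): T frames, N channels


def channelwise_ln(x: Seq, eps: float = 1e-8) -> Seq:
    """Assumed 'cLN': normalize each frame over its own channels (no look-ahead)."""
    out = []
    for frame in x:
        mean = sum(frame) / len(frame)
        var = sum((v - mean) ** 2 for v in frame) / len(frame)
        out.append([(v - mean) / math.sqrt(var + eps) for v in frame])
    return out


def global_ln(x: Seq, eps: float = 1e-8) -> Seq:
    """Assumed 'gLN': normalize with statistics pooled over all frames and channels."""
    values = [v for frame in x for v in frame]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [[(v - mean) / math.sqrt(var + eps) for v in frame] for frame in x]


x = [[1.0, 2.0], [3.0, 5.0], [0.0, 4.0]]
x_future_changed = [[1.0, 2.0], [3.0, 5.0], [100.0, -7.0]]  # only the last frame differs

# Channel-wise: the first two frames are unaffected by the future change (causal).
print(channelwise_ln(x)[:2] == channelwise_ln(x_future_changed)[:2])  # True
# Global: the change leaks into every frame through the pooled statistics.
print(global_ln(x)[0] == global_ln(x_future_changed)[0])  # False
```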
Examples
>>> import torch
>>> from espnet2.enh.layers.skim import SegLSTM
>>> model = SegLSTM(input_size=16, hidden_size=32)
>>> input_tensor = torch.randn(4, 10, 16)  # (batch_size, seq_len, input_size)
>>> output, (h, c) = model(input_tensor, None)  # hc=None: zero initial states
>>> print(output.shape)  # torch.Size([4, 10, 16])
NOTE
When hc is None (as in the first SkiM block), the hidden (h) and cell (c) states are initialized to zero.
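The state convention in the note can be illustrated with a toy, single-unit LSTM written in plain Python. The weights are made up, and toy_lstm is a hypothetical stand-in for the LSTM inside SegLSTM. It only shows that hc=None behaves like all-zero states and that the returned (h, c) can seed a later call. In SkiM, the returned states are typically processed by MemLSTM before seeding the next block, rather than being fed back directly.

```python
import math
from typing import List, Optional, Tuple

State = Tuple[float, float]  # (h, c) for a single hidden unit


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def toy_lstm(xs: List[float], hc: Optional[State]) -> Tuple[List[float], State]:
    """Hypothetical unidirectional single-unit LSTM with fixed, made-up weights."""
    h, c = (0.0, 0.0) if hc is None else hc  # hc=None -> zero initial states
    outputs = []
    for x in xs:
        i = sigmoid(0.5 * x + 0.3 * h)         # input gate
        f = sigmoid(-0.2 * x + 0.4 * h + 1.0)  # forget gate
        g = math.tanh(0.8 * x - 0.1 * h)       # candidate cell value
        o = sigmoid(0.6 * x + 0.2 * h)         # output gate
        c = f * c + i * g
        h = o * math.tanh(c)
        outputs.append(h)
    return outputs, (h, c)  # outputs plus the final (h, c)


seq = [0.5, -1.0, 2.0, 0.0]
out_none, state_none = toy_lstm(seq, None)
out_zero, state_zero = toy_lstm(seq, (0.0, 0.0))
print(out_none == out_zero)  # True: hc=None is the same as zero states

# Carrying (h, c) across two calls reproduces a single pass over the sequence.
first, hc_mid = toy_lstm(seq[:2], None)
second, hc_end = toy_lstm(seq[2:], hc_mid)
print(first + second == out_none, hc_end == state_none)  # True True
```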
forward(input, hc)
Performs a forward pass of the SegLSTM model.
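The input sequence is run through the LSTM, starting from the hidden and cell states in hc (or from zero states if hc is None), and mapped back to input_size, so the output has the same shape as the input. The final hidden and cell states are returned alongside the output so that, in SkiM, they can be passed on to the MemLSTM.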
- Parameters:
  - input (torch.Tensor) – Input features of shape (batch, seq_len, input_size). Inside SkiM, batch is B*S because the segments are folded into the batch dimension.
  - hc (tuple or None) – Hidden and cell states (h, c) from the previous layer, each of shape (d, B*S, H), where:
    - d: number of directions (1, or 2 if bidirectional)
    - B: batch size
    - S: number of segments
    - H: hidden state size
  If hc is None, h and c are initialized to zero.
- Returns: A tuple (output, (h, c)):
  - output (torch.Tensor): processed output of shape (batch, seq_len, input_size).
  - (h, c) (tuple): final hidden and cell states, each of shape (d, batch, H).
- Return type: tuple
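The shape bookkeeping for hc can be summarized with a small helper. seglstm_shapes is hypothetical and only restates the documented shapes. It assumes that the input to SegLSTM is the folded (B*S, K, input_size) tensor of K-frame segments and that d is 2 only when bidirectional=True.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int]


def seglstm_shapes(batch: int, n_segments: int, seg_len: int, input_size: int,
                   hidden_size: int, bidirectional: bool) -> Dict[str, Shape]:
    """Hypothetical helper: expected tensor shapes for one SegLSTM call in SkiM."""
    d = 2 if bidirectional else 1  # number of LSTM directions
    bs = batch * n_segments        # segments folded into the batch axis
    return {
        "input": (bs, seg_len, input_size),
        "h": (d, bs, hidden_size),
        "c": (d, bs, hidden_size),
        "output": (bs, seg_len, input_size),  # same shape as the input
    }


shapes = seglstm_shapes(batch=4, n_segments=3, seg_len=10, input_size=16,
                        hidden_size=32, bidirectional=True)
print(shapes["h"])  # (2, 12, 32)
```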
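Examples
>>> import torch
>>> from espnet2.enh.layers.skim import SegLSTM
>>> model = SegLSTM(input_size=16, hidden_size=11)
>>> x = torch.randn(3, 10, 16)  # (B*S, seq_len, input_size)
>>> hc = (torch.zeros(1, 3, 11), torch.zeros(1, 3, 11))  # (d, B*S, H) initial states
>>> output, (h, c) = model(x, hc)
>>> print(output.shape, h.shape)  # torch.Size([3, 10, 16]) torch.Size([1, 3, 11])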
NOTE
This method handles both unidirectional and bidirectional configurations; with bidirectional=True, d = 2 in the state shapes. For causal processing, use bidirectional=False together with norm_type='cLN'.