espnet2.enh.layers.skim.SkiM
class espnet2.enh.layers.skim.SkiM(input_size, hidden_size, output_size, dropout=0.0, num_blocks=2, segment_size=20, bidirectional=True, mem_type='hc', norm_type='gLN', seg_overlap=False)
Bases: Module
Skipping Memory Net (SkiM) for low-latency real-time continuous speech separation.
This class implements the SkiM model as described in the paper: “SkiM: Skipping Memory LSTM for Low-Latency Real-Time Continuous Speech Separation” (https://arxiv.org/abs/2201.10800).
input_size
Dimension of the input feature.
- Type: int
hidden_size
Dimension of the hidden state.
- Type: int
output_size
Dimension of the output feature.
- Type: int
dropout
Dropout ratio. Default is 0.
- Type: float
num_blocks
Number of basic SkiM blocks.
- Type: int
segment_size
Segmentation size for splitting long features.
- Type: int
bidirectional
Whether the RNN layers are bidirectional.
- Type: bool
mem_type
Controls whether the hidden (or cell) state of SegLSTM is processed by MemLSTM. Options are ‘hc’, ‘h’, ‘c’, ‘id’, or None. In ‘id’ mode, both the hidden and cell states are returned unchanged. When mem_type is None, the MemLSTM is removed.
- Type: str or None
norm_type
Normalization type; can be ‘gLN’ or ‘cLN’. cLN is for causal implementation.
- Type: str
seg_overlap
Whether the segmentation will reserve 50% overlap for adjacent segments. Default is False.
- Type: bool
Parameters:
- input_size (int) – Dimension of the input feature.
- hidden_size (int) – Dimension of the hidden state.
- output_size (int) – Dimension of the output feature.
- dropout (float) – Dropout ratio. Default is 0.
- num_blocks (int) – Number of basic SkiM blocks.
- segment_size (int) – Segmentation size for splitting long features.
- bidirectional (bool) – Whether the RNN layers are bidirectional.
- mem_type (str or None) – Controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. Options are ‘hc’, ‘h’, ‘c’, ‘id’, or None.
- norm_type (str) – Normalization type; can be ‘gLN’ or ‘cLN’. cLN is for causal implementation.
- seg_overlap (bool) – Whether to reserve 50% overlap for adjacent segments.
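To make segment_size and seg_overlap more concrete, the following is a minimal pure-Python sketch of splitting a long sequence into fixed-length segments, with and without 50% overlap. The helper `split_into_segments` is hypothetical and written only for illustration; the padding rule it uses (zero-pad the tail just enough to fill the last segment) is an assumption, and ESPnet's internal padding may differ.

```python
import math


def split_into_segments(frames, segment_size, seg_overlap=False):
    """Split a list of frames into fixed-length segments (illustrative only).

    Without overlap the hop equals segment_size; with 50% overlap the hop
    is segment_size // 2. The tail is zero-padded to fill the last segment.
    """
    hop = segment_size // 2 if seg_overlap else segment_size
    # Number of segments needed so that every frame is covered at least once.
    num_segments = math.ceil(max(len(frames) - segment_size, 0) / hop) + 1
    padded_len = (num_segments - 1) * hop + segment_size
    padded = list(frames) + [0] * (padded_len - len(frames))
    return [padded[i * hop : i * hop + segment_size] for i in range(num_segments)]


frames = list(range(1, 51))  # 50 dummy scalar "frames"
plain = split_into_segments(frames, segment_size=20)
overlapped = split_into_segments(frames, segment_size=20, seg_overlap=True)
print(len(plain), len(overlapped))  # 3 4
```

With overlap, each segment shares its first half with the second half of the previous one, so more segments are produced for the same input length.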
######### Examples
>>> import torch
>>> from espnet2.enh.layers.skim import SkiM
>>> model = SkiM(
... input_size=16,
... hidden_size=11,
... output_size=16,
... dropout=0.0,
... num_blocks=4,
... segment_size=20,
... bidirectional=False,
... mem_type="hc",
... norm_type="cLN",
... seg_overlap=False,
... )
>>> input_tensor = torch.randn(3, 100, 16)
>>> output = model(input_tensor)
####### NOTE This implementation is designed for continuous speech separation tasks with a focus on low-latency processing.
forward(input)
Forward pass of the SkiM model.
This method splits the input sequence into segments, processes each segment with a SegLSTM, and, between blocks, passes the SegLSTM hidden and cell states through a MemLSTM. Each state has the shape (d, B*S, H), where d is the number of directions, B is the batch size, S is the number of segments, and H is the hidden size. The memory type (mem_type) controls how the hidden and cell states are processed; if it is “id”, the original states are passed on without modification.
- Parameters:
- input (torch.Tensor) – Input features of shape (B, T, N), where B is the batch size, T is the number of frames, and N is the input feature dimension.
- Returns: The output of shape (B, T, D), where D is the output feature dimension.
- Return type: torch.Tensor
####### NOTE For the causal setup (non-bidirectional), the hidden and cell states are adjusted to ensure causality.
######### Examples
>>> import torch
>>> from espnet2.enh.layers.skim import SkiM
>>> model = SkiM(input_size=16, hidden_size=11, output_size=16)
>>> input_tensor = torch.randn(3, 100, 16) # Batch of 3, 100 frames
>>> output = model(input_tensor)
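The causality note above can be illustrated with a small sketch. One common way to keep segment-level memory causal is to delay the per-segment states by one segment, so that segment s only receives the state computed from segment s - 1 and the first segment receives zeros. That this matches the adjustment made here is an assumption; the helper `shift_states_causally` is hypothetical and uses plain Python lists in place of tensors.

```python
def shift_states_causally(per_segment_states, zero_state):
    """Delay per-segment states by one segment (illustrative only).

    Segment s receives the state computed from segment s - 1, and the
    first segment receives `zero_state`, so no segment sees the future.
    """
    return [zero_state] + list(per_segment_states[:-1])


states = ["m0", "m1", "m2"]  # placeholder states for segments 0, 1, 2
shifted = shift_states_causally(states, zero_state="zeros")
print(shifted)  # ['zeros', 'm0', 'm1']
```

The state of the last segment is dropped because no later segment exists to consume it.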
forward_stream(input_frame, states)
Process a single frame of input in a streaming manner.
This method updates the internal state of the SkiM model based on the provided input frame and the current states. It allows the model to handle streaming inputs efficiently by maintaining memory across segments.
Parameters:
- input_frame (torch.Tensor) – The input frame of shape (B, 1, N), where B is the batch size and N is the feature dimension.
- states (dict) – A dictionary containing the current states of the model, including the current step and segment states.
Returns: A tuple (output, states):
- output (torch.Tensor) – The processed output of shape (B, 1, D), where D is the output feature dimension.
- states (dict) – The updated states dictionary, which includes the current step and updated segment states.
Return type: tuple
####### NOTE The states dictionary is initialized when an empty dictionary is passed. It keeps track of the current step and the segment-level states for each block.
######### Examples
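>>> import torch
>>> from espnet2.enh.layers.skim import SkiM
>>> # Causal (unidirectional) configuration, as used for streaming
>>> model = SkiM(input_size=16, hidden_size=11, output_size=16, bidirectional=False, norm_type="cLN")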
>>> input_frame = torch.randn(3, 1, 16) # Batch of 3, 1 time step
>>> states = {}
>>> output, states = model.forward_stream(input_frame, states)
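The state-passing pattern described above can be sketched without the model itself. The toy function below is a hypothetical stand-in, not the ESPnet code: it lazily initializes the states dictionary on the first call, counts steps, and at each segment boundary folds the finished segment's state into a longer-term memory before restarting the segment-level state. The key names (`current_step`, `seg_state`, `mem_state`) and the scalar arithmetic are assumptions chosen for illustration.

```python
def toy_forward_stream(frame, states, segment_size=4):
    """Toy stand-in for one streaming step (hypothetical, scalar-valued)."""
    if not states:
        # Lazily initialize the state on the first call.
        states = {"current_step": 0, "seg_state": 0.0, "mem_state": 0.0}
    if states["current_step"] and states["current_step"] % segment_size == 0:
        # Segment boundary: fold the finished segment into the memory
        # and restart the segment-level state.
        states["mem_state"] += states["seg_state"]
        states["seg_state"] = 0.0
    states["seg_state"] += frame
    states["current_step"] += 1
    output = states["seg_state"] + states["mem_state"]
    return output, states


states = {}
outputs = []
for frame in [1.0] * 10:  # feed ten frames one at a time
    out, states = toy_forward_stream(frame, states)
    outputs.append(out)
print(states["current_step"], states["mem_state"], states["seg_state"])  # 10 8.0 2.0
```

The caller passes the returned dictionary back in on every step, which is the same calling pattern as in the example above.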