espnet2.enh.layers.skim.MemLSTM
class espnet2.enh.layers.skim.MemLSTM(hidden_size, dropout=0.0, bidirectional=False, mem_type='hc', norm_type='cLN')
Bases: Module
Memory LSTM (MemLSTM) for the SkiM model.
This class implements the MemLSTM layer used in the SkiM model described in “SkiM: Skipping Memory LSTM for Low-Latency Real-Time Continuous Speech Separation” (https://arxiv.org/abs/2201.10800).
hidden_size
Dimension of the hidden state.
- Type: int
dropout
Dropout ratio. Default is 0.
- Type: float
bidirectional
Whether the LSTM layers are bidirectional. Default is False.
- Type: bool
mem_type
Controls how the hidden (or cell) state of SegLSTM will be processed by MemLSTM. Options are ‘hc’, ‘h’, ‘c’, or ‘id’.
- Type: str
norm_type
Normalization type. Options are ‘gLN’ or ‘cLN’. ‘cLN’ is intended for the causal implementation.
- Type: str
Parameters:
- hidden_size (int) – Dimension of the hidden state.
- dropout (float, optional) – Dropout ratio. Default is 0.
- bidirectional (bool, optional) – Whether the LSTM layers are bidirectional. Default is False.
- mem_type (str, optional) – Controls how the hidden (or cell) state of SegLSTM will be processed by MemLSTM. Options are ‘hc’, ‘h’, ‘c’, or ‘id’. Default is ‘hc’.
- norm_type (str, optional) – Normalization type. Options are ‘gLN’ or ‘cLN’. Default is ‘cLN’.
Raises: AssertionError – If mem_type is not one of the supported types (‘hc’, ‘h’, ‘c’, ‘id’).
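The constructor's validation and the choice of state networks can be sketched in plain Python. This is a hypothetical helper (memlstm_config is not part of ESPnet). It assumes that 'hc' builds networks for both states, 'h' and 'c' build one network each, 'id' builds none, and that each segment's state is flattened to d * hidden_size features, with d = 2 when bidirectional and 1 otherwise.

```python
# Hypothetical helper, not part of ESPnet: mirrors the mem_type check and
# (by assumption) which state networks MemLSTM would build.
SUPPORTED_MEM_TYPES = ("hc", "h", "c", "id")


def memlstm_config(hidden_size, bidirectional=False, mem_type="hc"):
    """Return the assumed per-segment feature size and the state networks used."""
    assert mem_type in SUPPORTED_MEM_TYPES, f"unsupported mem_type: {mem_type}"
    num_directions = 2 if bidirectional else 1
    # Assumption: the (d, H) state of each segment is flattened to d * H features.
    feature_size = num_directions * hidden_size
    nets = []
    if mem_type in ("hc", "h"):
        nets.append("h_net")  # processes the SegLSTM hidden states
    if mem_type in ("hc", "c"):
        nets.append("c_net")  # processes the SegLSTM cell states
    return {"feature_size": feature_size, "nets": nets}


print(memlstm_config(128, bidirectional=True, mem_type="hc"))
# {'feature_size': 256, 'nets': ['h_net', 'c_net']}
```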
Examples:
>>> import torch
>>> from espnet2.enh.layers.skim import MemLSTM
>>> mem_lstm = MemLSTM(hidden_size=128, dropout=0.1,
...                    bidirectional=True, mem_type='hc',
...                    norm_type='gLN')
>>> hc = (torch.randn(2, 32, 128), torch.randn(2, 32, 128))  # (h, c): d=2, B*S=32, H=128
>>> h_out, c_out = mem_lstm(hc, S=4)  # S=4 segments, so B=8
Note: The forward method expects hc to be a tuple of hidden and cell states, and S to be the number of segments.
extra_repr() → str
Returns a string representation of the MemLSTM module’s configuration.
This method provides information about the memory type and whether the LSTM layers are bidirectional, which is useful for debugging and logging purposes.
mem_type
The type of memory used by the MemLSTM. Can be one of ‘hc’, ‘h’, ‘c’, or ‘id’.
- Type: str
bidirectional
Indicates if the LSTM layers are bidirectional.
- Type: bool
Returns: A string summarizing the configuration of the MemLSTM.
Return type: str
Examples:
>>> mem_lstm = MemLSTM(hidden_size=128, bidirectional=True, mem_type='hc')
>>> print(mem_lstm.extra_repr())
Mem_type: hc, bidirectional: True
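The summary string follows a simple pattern. A minimal sketch of the formatting, written as a hypothetical standalone function rather than the actual method, and assuming the exact format shown in the example output above:

```python
def memlstm_extra_repr(mem_type, bidirectional):
    """Hypothetical stand-in for MemLSTM.extra_repr(), assuming the format shown above."""
    return f"Mem_type: {mem_type}, bidirectional: {bidirectional}"


print(memlstm_extra_repr("hc", True))  # Mem_type: hc, bidirectional: True
```

PyTorch's nn.Module.__repr__ includes the extra_repr() string when a module is printed, which is why this method is convenient for logging model configurations.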
forward(hc, S)
Forward pass for the MemLSTM layer.
This method processes the hidden and cell states from the SegLSTM and applies the MemLSTM transformation based on the specified memory type. The function handles both identity mode and different memory types (‘hc’, ‘h’, ‘c’) to compute the new hidden and cell states.
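The per-mem_type behaviour can be sketched in plain Python on lists of per-segment values. This sketch assumes a residual update (each state plus its normalized network output) and that the state not selected by mem_type is replaced with zeros. The names memlstm_update, h_fn, and c_fn are hypothetical; h_fn and c_fn stand in for the normalized LSTM outputs and are not ESPnet functions.

```python
def memlstm_update(h, c, mem_type, h_fn, c_fn):
    """Sketch of MemLSTM's mem_type dispatch on lists of per-segment values.

    Assumptions: residual updates x + f(x), and zeroing of the unprocessed state.
    h_fn / c_fn are hypothetical callables standing in for norm(LSTM(x)).
    """
    if mem_type == "id":
        return h, c  # identity mode: pass the SegLSTM states through

    def residual(x, fn):
        return [xi + yi for xi, yi in zip(x, fn(x))]

    if mem_type == "hc":
        return residual(h, h_fn), residual(c, c_fn)
    if mem_type == "h":
        return residual(h, h_fn), [0.0] * len(c)
    if mem_type == "c":
        return [0.0] * len(h), residual(c, c_fn)
    raise ValueError(f"unsupported mem_type: {mem_type}")


def double(xs):
    """Toy stand-in for norm(LSTM(x))."""
    return [2.0 * x for x in xs]


print(memlstm_update([1.0, 2.0], [3.0, 4.0], "h", double, double))
# ([3.0, 6.0], [0.0, 0.0])
```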
- Parameters:
- hc (tuple) – A tuple containing the hidden and cell states from SegLSTM. Each state should have the shape (d, B*S, H), where:
- d: number of directions (1 for unidirectional, 2 for bidirectional)
- B: batch size
- S: number of segments
- H: hidden size
- S (int) – Number of segments in the SegLSTM.
- Returns: A tuple (h, c) containing the updated hidden and cell states. Each state has the same shape as the input, (d, B*S, H).
- Return type: tuple
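The B*S axis packs every segment of every batch item into one dimension. The plain-Python sketch below shows how that axis can be split into B groups of S segments and flattened back. It assumes batch-major ordering (row b * S + s holds segment s of batch item b); group_by_batch and flatten_batch are hypothetical helpers, not ESPnet functions.

```python
def group_by_batch(rows, S):
    """Split a flat list of B*S per-segment entries into B lists of S entries."""
    if S <= 0 or len(rows) % S != 0:
        raise ValueError("the B*S axis must be divisible by S")
    B = len(rows) // S
    # Assumption: batch-major layout, row b * S + s belongs to batch b, segment s.
    return [rows[b * S:(b + 1) * S] for b in range(B)]


def flatten_batch(groups):
    """Inverse of group_by_batch: concatenate the B groups back into B*S rows."""
    return [row for group in groups for row in group]


rows = list(range(10))  # B*S = 10 entries with S = 5, so B = 2
print(group_by_batch(rows, 5))  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```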
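Note: If self.mem_type is set to ‘id’, the MemLSTM networks are bypassed and the input hidden and cell states are passed through. If self.bidirectional is False, the returned states are shifted by one segment for causal processing: segment s receives the memory state computed for segment s-1, and the first segment receives zeros.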
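The causal shift for unidirectional models can be sketched on segment-grouped values. This is a plain-Python illustration (causal_shift is a hypothetical helper, not ESPnet code) that assumes the states are already grouped per batch item as lists of S segments.

```python
def causal_shift(batches, zero=0.0):
    """Shift memory states one segment later within each batch item.

    Segment s receives the state computed for segment s - 1, and the first
    segment receives `zero`, so no segment sees memory from its own or later segments.
    """
    return [[zero] + segments[:-1] if segments else [] for segments in batches]


print(causal_shift([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
# [[0.0, 1.0, 2.0], [0.0, 4.0, 5.0]]
```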
Examples:
>>> import torch
>>> mem_lstm = MemLSTM(hidden_size=128, mem_type='hc')  # unidirectional, so d=1
>>> h = torch.randn(1, 10, 128)  # hidden state, (d, B*S, H) with B=2, S=5
>>> c = torch.randn(1, 10, 128)  # cell state, same shape
>>> hc = (h, c)
>>> S = 5  # number of segments
>>> h_out, c_out = mem_lstm(hc, S)
forward_one_step(hc, state)
Forward one step in the MemLSTM processing.
This method computes the next hidden and cell states given the current hidden and cell states. It processes the input based on the memory type specified during the initialization of the MemLSTM class.
- Parameters:
- hc (tuple) – A tuple containing the current hidden state (h) and cell state (c). The shapes are expected to be (d, B, H), where d is the number of directions (1 for unidirectional, 2 for bidirectional), B is the batch size, and H is the hidden size.
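- state (list) – A two-element list holding the recurrent states of the MemLSTM networks that process the hidden and cell states, respectively. Each entry is either None (start from zero states) or the state returned by a previous call. The list is updated in place and also returned.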
- Returns: A tuple (hc, state), where hc is the updated (h, c) pair and state is the updated state list. The hidden and cell states keep the shape (d, B, H).
- Return type: tuple
Note: This method does not modify the states when mem_type is set to ‘id’.
Examples:
>>> import torch
>>> mem_lstm = MemLSTM(hidden_size=128, mem_type='hc')  # unidirectional, so d=1
>>> hc = (torch.zeros(1, 4, 128), torch.zeros(1, 4, 128))  # (d, B, H) with B=4
>>> state = [None, None]  # no previous MemLSTM state yet
>>> new_hc, new_state = mem_lstm.forward_one_step(hc, state)
- Raises: Unsupported mem_type values are rejected earlier, by the constructor, with an AssertionError.
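forward_one_step is meant to be called once per new segment, with the state list carried between calls. The plain-Python sketch below illustrates that contract using a toy running-sum recurrence in place of the LSTMs. The names toy_rnn and one_step are hypothetical, not ESPnet code, and the sketch assumes the state list holds one entry per state network and starts as [None, None]. Feeding segments one at a time gives the same result as processing the whole sequence at once.

```python
def toy_rnn(xs, state=None):
    """Toy recurrence standing in for an LSTM: a running sum carried in `state`."""
    total = 0.0 if state is None else state
    outputs = []
    for x in xs:
        total += x
        outputs.append(total)
    return outputs, total


def one_step(h, c, state):
    """Process one segment, updating state[0] (for h) and state[1] (for c) in place."""
    h_out, state[0] = toy_rnn([h], state[0])
    c_out, state[1] = toy_rnn([c], state[1])
    return (h_out[0], c_out[0]), state


state = [None, None]
h_stream = []
for h, c in [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]:
    (h_new, _), state = one_step(h, c, state)
    h_stream.append(h_new)

print(h_stream)  # [1.0, 3.0, 6.0]
print(h_stream == toy_rnn([1.0, 2.0, 3.0])[0])  # True
```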