espnet2.enh.layers.complexnn.NavieComplexLSTM
class espnet2.enh.layers.complexnn.NavieComplexLSTM(input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False)
Bases: Module
A naive implementation of a complex-valued Long Short-Term Memory (LSTM).
This LSTM processes complex-valued input by splitting it into real and imaginary parts and passing each through its own real-valued LSTM. The outputs of these two LSTMs are then combined to produce the final complex-valued output.
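To make the combination rule concrete, the sketch below builds the same structure out of two ordinary nn.LSTM modules and pairs their outputs in the pattern of complex multiplication, (a + jb)(c + jd) = (ac - bd) + j(ad + bc). This is a minimal illustration of the technique, not the ESPnet source itself:
>>> import torch
>>> import torch.nn as nn
>>> real_lstm = nn.LSTM(input_size=2, hidden_size=2)  # each LSTM sees one part per call
>>> imag_lstm = nn.LSTM(input_size=2, hidden_size=2)
>>> x_r = torch.randn(10, 1, 2)  # (seq_len, batch, half the features): real part
>>> x_i = torch.randn(10, 1, 2)  # imaginary part
>>> r2r, _ = real_lstm(x_r)
>>> r2i, _ = imag_lstm(x_r)
>>> i2r, _ = real_lstm(x_i)
>>> i2i, _ = imag_lstm(x_i)
>>> real_out = r2r - i2i  # real part of the combined output
>>> imag_out = i2r + r2i  # imaginary part of the combined output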
bidirectional
If True, the LSTM will be bidirectional.
- Type: bool
input_dim
The input dimension for the LSTM, half of the input size.
- Type: int
rnn_units
The number of hidden units for the LSTM, half of the hidden size.
- Type: int
real_lstm
LSTM for processing the real part of the input.
- Type: nn.LSTM
imag_lstm
LSTM for processing the imaginary part of the input.
- Type: nn.LSTM
projection_dim
Dimension for the output projection layer.
- Type: int or None
r_trans
Linear transformation for the real output.
- Type: nn.Linear or None
i_trans
Linear transformation for the imaginary output.
- Type: nn.Linear or None
Parameters:
- input_size (int) – The size of the input (must be even for complex inputs).
- hidden_size (int) – The number of hidden units in the LSTM (must be even).
- projection_dim (int or None) – The dimension of the projection layer (if None, no projection is applied).
- bidirectional (bool) – If True, the LSTM will be bidirectional.
- batch_first (bool) – If True, input and output tensors are provided as (batch, seq, feature).
Returns: A list containing the real and imaginary outputs.
Return type: list
Raises: ValueError – If the input size or the hidden size is not even.
######### Examples
>>> import torch
>>> lstm = NavieComplexLSTM(input_size=4, hidden_size=4)
>>> inputs = [torch.randn(10, 2), torch.randn(10, 2)]  # 10 time steps, 2 features per part
>>> outputs = lstm(inputs)
>>> real_out, imag_out = outputs
>>> real_out.shape, imag_out.shape
(torch.Size([10, 2]), torch.Size([10, 2]))
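A bidirectional configuration doubles the per-part output width, since each half-size LSTM runs in both directions (shapes shown assume no projection):
>>> bi_lstm = NavieComplexLSTM(input_size=4, hidden_size=4, bidirectional=True)
>>> real_out, imag_out = bi_lstm([torch.randn(10, 2, 2), torch.randn(10, 2, 2)])
>>> real_out.shape  # (hidden_size // 2) * 2 directions = 4 features
torch.Size([10, 2, 4])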
####### NOTE The input must be either a list [real, imag] of two tensors or a single tensor that can be split into real and imaginary components.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
flatten_parameters()
Flatten the parameters of the LSTM layers for efficient training.
This is mainly a performance optimization: it lets each LSTM store its weights in a single contiguous block of memory, which fast backends such as cuDNN require, and which also helps when training with packed sequences.
The method calls flatten_parameters() on both the real and imaginary LSTM layers to ensure their parameters are properly flattened.
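Conceptually, the method is just a pass-through to the two sub-LSTMs; a sketch consistent with the description above (not necessarily the verbatim source):
>>> def flatten_parameters(self):
...     self.real_lstm.flatten_parameters()
...     self.imag_lstm.flatten_parameters()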
real_lstm
The real-valued LSTM layer.
- Type: nn.LSTM
imag_lstm
The imaginary-valued LSTM layer.
- Type: nn.LSTM
######### Examples
>>> model = NavieComplexLSTM(input_size=4, hidden_size=8)
>>> model.flatten_parameters()
####### NOTE This method should be called before training when using packed sequences to ensure optimal performance.
- Raises: RuntimeError – If the LSTM layers have not been properly initialized.
forward(inputs)
Computes the forward pass of the NavieComplexLSTM.
This method takes complex-valued input and processes it through two separate LSTM layers for the real and imaginary parts. It then combines the outputs to produce the final complex output.
- Parameters: inputs (Union[torch.Tensor, List[torch.Tensor]]) – A tensor or a list containing the real and imaginary parts of the input. If a tensor, it should have shape (seq_len, batch_size, input_size), where input_size is the total number of input features (real + imaginary). If a list, it should contain two tensors: the first for the real part and the second for the imaginary part.
- Returns: A list containing two tensors: the processed real and imaginary parts. Each tensor will have the shape (seq_len, batch_size, output_size) where output_size is determined by the hidden size and whether projection_dim is set.
- Return type: List[torch.Tensor]
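The per-part output width can be computed ahead of time. The sketch below follows the halving convention documented for input_size and hidden_size; treating projection_dim the same way (halved across the real and imaginary parts) is an assumption, not something this page states explicitly:
>>> def per_part_output_size(hidden_size, projection_dim=None, bidirectional=False):
...     if projection_dim is not None:
...         return projection_dim // 2  # assumed halving, mirroring hidden_size
...     return (hidden_size // 2) * (2 if bidirectional else 1)
>>> per_part_output_size(4)
2
>>> per_part_output_size(4, bidirectional=True)
4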
######### Examples
>>> import torch
>>> lstm = NavieComplexLSTM(input_size=4, hidden_size=4)
>>> real_input = torch.randn(10, 2, 2) # (seq_len, batch_size, real_dim)
>>> imag_input = torch.randn(10, 2, 2) # (seq_len, batch_size, imag_dim)
>>> outputs = lstm([real_input, imag_input])
>>> real_out, imag_out = outputs
>>> print(real_out.shape)  # torch.Size([10, 2, 2]); output_size = hidden_size // 2 here
>>> print(imag_out.shape)  # torch.Size([10, 2, 2])
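When the data starts out as a complex tensor, the explicit list form is the least ambiguous way to call the layer; .real and .imag give float views of the two components:
>>> spec = torch.randn(10, 2, 2, dtype=torch.complex64)  # complex-valued features
>>> real_out, imag_out = lstm([spec.real.contiguous(), spec.imag.contiguous()])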
####### NOTE The input size must be divisible by 2, since the layer expects the input features to consist of real and imaginary components of equal size.
- Raises: ValueError – If the input dimensions do not match the expected input shape, or if the input size is not divisible by 2.