espnet2.gan_codec.shared.encoder.seanet.SLSTM
class espnet2.gan_codec.shared.encoder.seanet.SLSTM(dimension: int, num_layers: int = 2, skip: bool = True)
Bases: Module
SLSTM is a custom Long Short-Term Memory (LSTM) module designed to handle
inputs with a convolutional layout. It abstracts the complexities of hidden state management and input data arrangement, providing a simplified interface for sequential data processing.
skip
If True, adds the input to the output of the LSTM for a skip connection. Defaults to True.
- Type: bool
lstm
The LSTM layer used for processing the input data.
- Type: nn.LSTM
Parameters:
- dimension (int) – The number of expected features in the input (also the number of output features).
- num_layers (int , optional) – The number of recurrent layers. Defaults to 2.
- skip (bool , optional) – Whether to use a skip connection by adding the input to the output. Defaults to True.
Returns: The output tensor, with the same shape as the input tensor but with the features processed by the LSTM.
Return type: torch.Tensor
####### Examples
>>> model = SLSTM(dimension=128)
>>> input_tensor = torch.randn(10, 128, 32) # (batch_size, channels, seq_len)
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape) # Output shape: torch.Size([10, 128, 32])
NOTE
The input tensor is expected to be in the convolutional layout (batch_size, channels, seq_len) and is permuted to (seq_len, batch_size, channels) before being passed to the LSTM layer.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
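The behavior described above can be sketched as follows. This is a minimal illustration, not the exact ESPnet implementation: it assumes the LSTM uses `dimension` for both input and hidden size, and that the skip connection simply adds the (permuted) input to the LSTM output.

```python
import torch
import torch.nn as nn


class SLSTM(nn.Module):
    """Sketch of an LSTM wrapper for inputs in convolutional layout."""

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        # Assumed: input size and hidden size are both `dimension`.
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Convolutional layout (batch, channels, seq_len)
        # -> LSTM layout (seq_len, batch, channels).
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            # Residual connection: add the input to the LSTM output.
            y = y + x
        # Restore the convolutional layout (batch, channels, seq_len).
        return y.permute(1, 2, 0)
```

The permute-process-permute pattern lets the module drop into a stack of 1D convolutions without any layout changes on the caller's side.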
forward(x)
Applies the LSTM to the input tensor and optionally adds the input
tensor to the output for skip connections.
The input tensor is expected to be in a convolutional layout with the shape (batch_size, channels, sequence_length). The output will have the same shape as the input tensor.
- Parameters:
- x (torch.Tensor) – Input tensor with shape (batch_size, channels, sequence_length).
- Returns: Output tensor after applying LSTM and skip connection if enabled.
- Return type: torch.Tensor
####### Examples
>>> slstm = SLSTM(dimension=128)
>>> input_tensor = torch.randn(32, 128, 100) # (batch_size, channels, seq_len)
>>> output_tensor = slstm(input_tensor)
>>> print(output_tensor.shape) # Output shape: torch.Size([32, 128, 100])
NOTE
The input tensor is permuted to (sequence_length, batch_size, channels) before being passed to the LSTM.
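The layout change in this note is a pure tensor reshuffle, which can be checked directly; the shapes below match the forward example above:

```python
import torch

x = torch.randn(32, 128, 100)  # (batch_size, channels, seq_len)

# Permute to the layout nn.LSTM expects without batch_first:
# (seq_len, batch_size, channels).
y = x.permute(2, 0, 1)
assert y.shape == (100, 32, 128)

# The inverse permutation restores the convolutional layout exactly.
assert torch.equal(y.permute(1, 2, 0), x)
```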