espnet2.asr.state_spaces.base.SequenceModule
class espnet2.asr.state_spaces.base.SequenceModule(*args, **kwargs)
Bases: Module
SequenceModule is an abstract class that defines the interface for sequence models in a neural network framework. It transforms an input tensor of shape (n_batch, l_sequence, d_model) into an output tensor of shape (n_batch, l_sequence, d_output).
Subclasses must implement the forward method and set the d_model and d_output attributes, which together define a standard sequence-to-sequence transformation. The class also provides optional methods for state management, which enable recurrent (step-by-step) processing and state decoding.
d_model
The model dimension, generally the same as the input dimension. This attribute must be set during initialization.
- Type: int
d_output
The output dimension of the model. This attribute must also be set during initialization.
- Type: int
forward(x, state=None, **kwargs)
Performs a forward pass through the model, mapping input tensors to output tensors.
default_state(*batch_shape, device=None)
Creates an initial state for a batch of inputs.
step(x, state=None, **kwargs)
Processes one step of the input sequence recurrently.
state_to_tensor
A property that returns a function to map the hidden state to a tensor.
d_state
A property that returns the dimension of the output of state_to_tensor.
- Raises: NotImplementedError – If the required attributes or methods are not implemented in a subclass.
############# Examples
To create a custom sequence model, subclass SequenceModule and implement the required methods:
```python
from espnet2.asr.state_spaces.base import SequenceModule


class MySequenceModel(SequenceModule):
    def __init__(self, d_model, d_output):
        super().__init__()
        self.d_model = d_model
        self.d_output = d_output

    def forward(self, x, state=None, **kwargs):
        # Implement the transformation logic here.
        return transformed_x, new_state
```
When using the model:
```python
model = MySequenceModel(d_model=128, d_output=64)
# input_tensor has shape (n_batch, l_sequence, 128)
output, state = model.forward(input_tensor)
```
######## NOTE This class is part of the ESPnet2 ASR framework and serves as a foundational building block for various sequence models.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
property d_model
Model dimension (generally same as input dimension).
This attribute is required for all SequenceModule instantiations. It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model.
property d_output
Output dimension of model.
This attribute is required for all instances of SequenceModule. It is used by the rest of the pipeline (e.g., model backbone, decoder) to track the internal shapes of the full model. The dimension must be specified during the instantiation of the model. If not set, a NotImplementedError will be raised.
- Returns: The output dimension of the model.
- Return type: int
- Raises: NotImplementedError – If d_output is not specified during instantiation.
############# Examples
>>> model = SomeSequenceModel(d_model=128, d_output=64)
>>> model.d_output
64
>>> model = SomeSequenceModel(d_model=128) # d_output not set
>>> model.d_output # Raises NotImplementedError
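The behaviour described above (a value that must be assigned at instantiation, with NotImplementedError raised on access otherwise) can be sketched with a property and a setter. This is a standalone illustration under stated assumptions, not the ESPnet source: the class name `RequiresOutputDim` and the private attribute `_d_output` are hypothetical.

```python
class RequiresOutputDim:
    """Hypothetical stand-in illustrating a required, settable attribute."""

    @property
    def d_output(self):
        # Fail loudly if nothing has assigned a value yet.
        if getattr(self, "_d_output", None) is None:
            raise NotImplementedError("d_output must be set during instantiation")
        return self._d_output

    @d_output.setter
    def d_output(self, value):
        self._d_output = value


layer = RequiresOutputDim()
try:
    layer.d_output
except NotImplementedError as err:
    print(f"unset: {err}")

layer.d_output = 64
print(layer.d_output)  # 64
```

With this pattern, a subclass only needs `self.d_output = ...` in its `__init__`; forgetting the assignment surfaces as an error at first access rather than as a silent shape mismatch later in the pipeline.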
property d_state
Return dimension of output of self.state_to_tensor.
default_state(*batch_shape, device=None)
Create initial state for a batch of inputs.
This method is intended to be overridden by subclasses to provide an appropriate initial state based on the input batch shape and device. The default implementation returns None, indicating that no initial state is used.
- Parameters:
- *batch_shape – Variable-length argument list representing the shape of the batch. This can be used to determine the size of the initial state.
- device (torch.device , optional) – The device on which to create the initial state. If not specified, the default device will be used.
- Returns: Initial state tensor for the given batch shape, or None if no state is required.
############# Examples
>>> module = SequenceIdentity(d_model=128)
>>> initial_state = module.default_state(32, 10) # Batch shape (32, 10)
>>> print(initial_state) # Should output: None
######## NOTE Subclasses that require a state should implement this method to return a valid tensor based on the specified batch shape and device.
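As a rough sketch of such an override, the layer below returns a zero tensor with one hidden vector per batch element. The class `StatefulLayer` and its `d_hidden` argument are hypothetical, and the layer subclasses `nn.Module` directly so that the snippet runs without ESPnet installed; in practice it would subclass SequenceModule.

```python
import torch
from torch import nn


class StatefulLayer(nn.Module):
    """Hypothetical layer; in ESPnet it would subclass SequenceModule."""

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.d_model = d_model
        self.d_output = d_model
        self.d_hidden = d_hidden

    def default_state(self, *batch_shape, device=None):
        # One zero vector of size d_hidden per batch element.
        return torch.zeros(*batch_shape, self.d_hidden, device=device)


layer = StatefulLayer(d_model=128, d_hidden=16)
state = layer.default_state(32)
print(state.shape)  # torch.Size([32, 16])
```

Appending the hidden size to `batch_shape` keeps the method usable for any number of leading batch dimensions.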
forward(x, state=None, **kwargs)
Forward pass.
A sequence-to-sequence transformation with an optional state.
This method takes an input tensor of shape (batch, length, self.d_model) and transforms it to (batch, length, self.d_output). The function also returns a “state” which can contain additional information, such as hidden states for RNN and SSM layers. Some transformer layers (e.g., Transformer-XL) may also utilize this state.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (batch, length, d_model).
- state (optional) – The initial state for recurrent processing, which may be None if not applicable.
- **kwargs – Additional keyword arguments that may be relevant for specific implementations.
- Returns: A tuple containing:
  - Output tensor of shape (batch, length, d_output).
  - Updated state (if applicable), which can be None.
- Return type: Tuple[torch.Tensor, Optional]
############# Examples
>>> model = SequenceIdentity(d_model=128)
>>> input_tensor = torch.randn(32, 10, 128) # (batch, length, d_model)
>>> output, state = model.forward(input_tensor)
>>> print(output.shape) # Should output: torch.Size([32, 10, 128])
######## NOTE Implementations of this method should ensure that the output shape aligns with the specified d_output attribute.
- Raises: NotImplementedError – If the method is not implemented in a derived class.
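The shape contract above can be illustrated with a stateless layer that projects d_model to d_output. This is a minimal sketch: `ProjectionLayer` is a hypothetical name, and it subclasses `nn.Module` directly so that the snippet runs without ESPnet installed.

```python
import torch
from torch import nn


class ProjectionLayer(nn.Module):
    """Hypothetical stateless layer; in ESPnet it would subclass SequenceModule."""

    def __init__(self, d_model, d_output):
        super().__init__()
        self.d_model = d_model
        self.d_output = d_output
        self.proj = nn.Linear(d_model, d_output)

    def forward(self, x, state=None, **kwargs):
        # x: (batch, length, d_model) -> (batch, length, d_output)
        return self.proj(x), None  # no recurrent state to carry forward


model = ProjectionLayer(d_model=128, d_output=64)
y, state = model(torch.randn(32, 10, 128))
print(y.shape, state)  # torch.Size([32, 10, 64]) None
```

Returning an `(output, state)` tuple even when the state is None lets callers treat stateless and stateful layers uniformly.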
property state_to_tensor
Returns a function that maps the hidden state to a tensor.
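The relationship between state_to_tensor and d_state can be sketched as follows, assuming a layer whose state is a tuple of two tensors. The class `TupleStateLayer`, its `d_hidden` argument, and the choice to concatenate the tuple are all hypothetical; the snippet subclasses `nn.Module` directly so that it runs without ESPnet installed.

```python
import torch
from torch import nn


class TupleStateLayer(nn.Module):
    """Hypothetical layer whose state is a (hidden, cell) tuple."""

    def __init__(self, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden

    @property
    def state_to_tensor(self):
        # Return a function, not a tensor: callers apply it to a state.
        return lambda state: torch.cat(state, dim=-1)

    @property
    def d_state(self):
        # Width of the tensor produced by state_to_tensor.
        return 2 * self.d_hidden


layer = TupleStateLayer(d_hidden=16)
state = (torch.zeros(32, 16), torch.zeros(32, 16))
flat = layer.state_to_tensor(state)
print(flat.shape[-1] == layer.d_state)  # True
```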
step(x, state=None, **kwargs)
Step the model recurrently for one step of the input sequence.
This method processes a single input step and updates the model’s state. It is typically used in recurrent architectures, such as RNNs, to compute the next output based on the current input and the previous state.
The method generally has the following signature: (B, H1) -> (B, H2), where:
- B is the batch size
- H1 is the dimension of the input
- H2 is the dimension of the output
Parameters:
- x (torch.Tensor) – Input tensor of shape (B, H1), where B is the batch size and H1 is the input dimension.
- state (optional) – The previous hidden state of the model, which can be used for recurrent processing.
- **kwargs – Additional keyword arguments that may be used by specific implementations.
Returns: A tuple containing the output tensor of shape (B, H2), where H2 is the output dimension of the model, and the updated state after processing the input (which may be None).
Return type: Tuple[torch.Tensor, Optional]
Raises: NotImplementedError – If the method is not implemented in a subclass.
############# Examples
>>> model = MyRNNModel(d_model=128)
>>> input_tensor = torch.randn(32, 128) # Batch size of 32
>>> output, new_state = model.step(input_tensor, state=prev_state)
######## NOTE This method must be overridden in subclasses of SequenceModule to provide specific recurrent behavior.
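To show how step is typically driven, the sketch below feeds a sequence one time step at a time and threads the state through each call. The `RunningMean` layer is a hypothetical toy model (its output at step t is the mean of all inputs so far), and it subclasses `nn.Module` directly so that the snippet runs without ESPnet installed.

```python
import torch
from torch import nn


class RunningMean(nn.Module):
    """Hypothetical recurrent layer; state is (running sum, step count)."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.d_output = d_model

    def default_state(self, *batch_shape, device=None):
        return torch.zeros(*batch_shape, self.d_model, device=device), 0

    def step(self, x, state=None, **kwargs):
        # x: (B, H1) -> output: (B, H2); here H1 == H2 == d_model.
        total, count = state
        total, count = total + x, count + 1
        return total / count, (total, count)


model = RunningMean(d_model=4)
x = torch.randn(2, 5, 4)          # (batch, length, d_model)
state = model.default_state(2)
outputs = []
for t in range(x.shape[1]):       # feed one time step at a time
    y_t, state = model.step(x[:, t], state=state)
    outputs.append(y_t)
y = torch.stack(outputs, dim=1)   # back to (batch, length, d_output)
print(y.shape)  # torch.Size([2, 5, 4])
```

Starting from `default_state` and passing the returned state back into the next call is what makes the loop equivalent to processing the whole sequence at once.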