espnet2.asr.state_spaces.base.TransposedModule
espnet2.asr.state_spaces.base.TransposedModule(module)
Transposed module.
This function is a class decorator that wraps a SequenceModule subclass so that its constructor accepts an additional transposed keyword argument. When transposed is True, the last two dimensions of the input tensor are swapped before the forward pass and swapped back on the output. This lets the module take input of shape (n_batch, d_model, l_sequence) instead of the default (n_batch, l_sequence, d_model).
transposed
If True, the last two dimensions of the input are swapped before the forward pass and swapped back afterwards. Defaults to False.
Type: bool
Parameters: module (type) – A subclass of SequenceModule to wrap.
Returns: A subclass of module whose constructor accepts an extra transposed keyword argument and whose forward method performs the transposition.
Return type: type
Examples
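>>> import torch
>>> from espnet2.asr.state_spaces.base import SequenceModule, TransposedModule
>>> @TransposedModule
... class MyModule(SequenceModule):
...     def __init__(self, d_model):
...         super().__init__()
...         self.d_model = d_model
...         self.d_output = d_model
...
...     def forward(self, x, state=None):
...         # Inside the module, x has shape (n_batch, l_sequence, d_model)
...         return x, state
...
>>> my_module = MyModule(d_model=128, transposed=True)
>>> input_tensor = torch.randn(32, 128, 10)  # (n_batch, d_model, l_sequence)
>>> output, state = my_module(input_tensor)
>>> output.shape
torch.Size([32, 128, 10])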
NOTE
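This decorator overrides the forward method: when transposed is True, the input is transposed before the parent class's forward method is called and the output is transposed back. The wrapped module itself therefore always sees input of shape (n_batch, l_sequence, d_model).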
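The wrapping pattern described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not ESPnet's implementation: SequenceModule below is a hypothetical stand-in for the real base class, and transposed_module and PositionwiseLinear are hypothetical names introduced only for this sketch. It checks that a wrapped layer called with transposed=True on (n_batch, d_model, l_sequence) input matches transposing the input, running an untransposed copy, and transposing the result back.

```python
import torch
import torch.nn as nn


class SequenceModule(nn.Module):
    """Hypothetical stand-in for ESPnet's SequenceModule (not the real class).

    Maps (n_batch, l_sequence, d_model) -> (n_batch, l_sequence, d_output)
    and returns an (output, state) pair.
    """

    def forward(self, x, state=None):
        return x, None


def transposed_module(module):
    """Sketch of a TransposedModule-style class decorator (assumption)."""

    class Wrapped(module):
        def __init__(self, *args, transposed=False, **kwargs):
            super().__init__(*args, **kwargs)
            self.transposed = transposed

        def forward(self, x, state=None):
            if self.transposed:
                # (n_batch, d_model, l_sequence) -> (n_batch, l_sequence, d_model)
                x = x.transpose(-1, -2)
            y, next_state = super().forward(x, state)
            if self.transposed:
                # (n_batch, l_sequence, d_output) -> (n_batch, d_output, l_sequence)
                y = y.transpose(-1, -2)
            return y, next_state

    # Keep the decorated class's name for readable reprs
    Wrapped.__name__ = module.__name__
    Wrapped.__qualname__ = module.__qualname__
    return Wrapped


@transposed_module
class PositionwiseLinear(SequenceModule):
    """Hypothetical toy layer: the same linear map at every position."""

    def __init__(self, d_model, d_output):
        super().__init__()
        self.d_model = d_model
        self.d_output = d_output
        self.linear = nn.Linear(d_model, d_output)

    def forward(self, x, state=None):
        return self.linear(x), state


torch.manual_seed(0)
layer = PositionwiseLinear(d_model=8, d_output=4, transposed=True)
x = torch.randn(2, 8, 5)  # (n_batch, d_model, l_sequence)
y, state = layer(x)
print(y.shape)  # torch.Size([2, 4, 5])

# Reference path: an untransposed copy with identical weights
plain = PositionwiseLinear(d_model=8, d_output=4)
plain.load_state_dict(layer.state_dict())
y_ref, _ = plain(x.transpose(-1, -2))  # input (2, 5, 8), output (2, 5, 4)
print(torch.allclose(y, y_ref.transpose(-1, -2)))  # True
```

Swapping only the last two dimensions keeps the batch dimension untouched, so the same wrapper works for any leading shape as long as the feature and sequence axes are the final two.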