espnet2.asr.state_spaces.residual.Residual
class espnet2.asr.state_spaces.residual.Residual(i_layer, d_input, d_model, alpha=1.0, beta=1.0)
Bases: Module
Residual connection with constant affine weights.
This class combines the block input x and the output y of the wrapped layer as alpha * x + beta * y, where alpha and beta are fixed constants rather than learned parameters. Depending on their values, it can simulate a standard residual connection (alpha = beta = 1.0), no residual connection (alpha = 0.0), or "constant gates" (other fixed weightings of the two paths).
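The three regimes can be seen with a small standalone function. This is a minimal sketch, not ESPnet code: `residual_combine` is a hypothetical helper that applies alpha * x + beta * y (or only beta * y when alpha is 0.0), where x is the block input and y is the wrapped layer's output. Because it uses only `*` and `+`, it accepts plain Python floats here and would equally accept torch tensors of matching shape.

```python
def residual_combine(x, y, alpha=1.0, beta=1.0):
    """Hypothetical helper mirroring the documented Residual rule.

    x: input to the block (skip path); y: output of the wrapped layer.
    """
    # Scale the layer output by the constant beta.
    y = beta * y
    # alpha == 0.0 drops the skip path entirely ("no residual").
    return alpha * x + y if alpha else y


x, y = 2.0, 3.0
print(residual_combine(x, y))                          # standard residual: 5.0
print(residual_combine(x, y, alpha=0.0))               # no residual: 3.0
print(residual_combine(x, y, alpha=0.25, beta=0.75))   # constant gate: 2.75
```

Since alpha and beta are ordinary Python floats, no gradient flows into them; the mix between the two paths stays fixed for the whole training run.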
i_layer
The index of the current layer.
- Type: int
d_input
The dimensionality of the input tensor.
- Type: int
d_model
The dimensionality of the model’s output tensor.
- Type: int
alpha
The constant weight applied to the block input x (the skip path).
- Type: float
beta
The constant weight applied to the wrapped layer's output y.
- Type: float
Parameters:
- i_layer (int) – The index of the layer in the network.
- d_input (int) – The input dimensionality.
- d_model (int) – The output dimensionality.
- alpha (float, optional) – Constant weight for the block input x. Defaults to 1.0.
- beta (float, optional) – Constant weight for the layer output y. Defaults to 1.0.
Returns: Calling the module (see forward) returns the tensor produced by the residual connection.
Return type: Tensor
Raises: AssertionError – If d_input is not equal to d_model and alpha is not 0.0.
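The assertion follows from the formula: alpha * x + beta * y only makes sense when x (size d_input) and y (size d_model) have the same feature size, and when alpha is 0.0 the input x is never used, so the sizes may differ. A minimal sketch of that check, using a hypothetical function name `check_residual_dims` that returns d_model (the value reported by the d_output property):

```python
def check_residual_dims(d_input: int, d_model: int, alpha: float = 1.0) -> int:
    """Hypothetical mirror of the Residual constructor's dimension check.

    Returns the output dimensionality (d_model) when the configuration is valid.
    """
    # x (size d_input) is added to y (size d_model), so the sizes must agree
    # unless alpha == 0.0 removes x from the sum.
    assert d_input == d_model or alpha == 0.0, "d_input must equal d_model unless alpha == 0.0"
    return d_model


print(check_residual_dims(128, 128))            # 128: sizes match
print(check_residual_dims(64, 128, alpha=0.0))  # 128: x is dropped, mismatch allowed
try:
    check_residual_dims(64, 128, alpha=1.0)
except AssertionError as err:
    print("AssertionError:", err)
```

In practice, setting alpha to 0.0 is the only way to use this class around a layer that changes the feature size.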
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.residual import Residual
>>> residual = Residual(i_layer=1, d_input=128, d_model=128, alpha=0.5)
>>> x = torch.randn(10, 128)
>>> y = torch.randn(10, 128)
>>> output = residual(x, y, transposed=False)
>>> print(output.shape)
torch.Size([10, 128])
NOTE
This implementation is part of a larger framework for state-space models in speech processing tasks.
property d_output
The output dimensionality of the module; equal to d_model.
forward(x, y, transposed)
Compute the output of the residual connection.
This method scales y by beta and, if alpha is non-zero, adds alpha * x, returning alpha * x + beta * y (or beta * y when alpha is 0.0). With alpha = beta = 1.0 this is a standard residual connection; with alpha = 0.0 the skip path is removed. The transposed flag does not change the computation: the operation is elementwise, so it applies equally to tensors in (batch, length, d_model) and (batch, d_model, length) layout.
- Parameters:
- x (torch.Tensor) – The input to the block (the skip path), weighted by alpha.
- y (torch.Tensor) – The output of the wrapped layer, weighted by beta.
- transposed (bool) – Whether the tensors use the (batch, d_model, length) layout. Accepted for interface compatibility; it does not affect the result, because the operation is elementwise.
- Returns: The resulting tensor after applying the residual connection.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.residual import Residual
>>> residual = Residual(i_layer=1, d_input=10, d_model=10, alpha=1.0)
>>> x = torch.randn(5, 10)
>>> y = torch.randn(5, 10)
>>> output = residual(x, y, transposed=False)
>>> output.shape
torch.Size([5, 10])
NOTE
The output will be alpha * x + beta * y if alpha is non-zero; otherwise, it will return y scaled by beta.
- Raises:
- RuntimeError – Not raised explicitly; this method performs no shape validation. If alpha is non-zero and x and y cannot be broadcast together, PyTorch raises it during the addition.
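Because the computation is elementwise, the memory layout of x and y does not matter, which is why the transposed flag has no effect on the result. The sketch below illustrates this with a hypothetical function `residual_forward` that restates the computation from the NOTE; it is not the ESPnet source. Skipping the multiplication when beta is exactly 1.0 is an assumed shortcut that does not change the output.

```python
import torch


def residual_forward(x, y, transposed, alpha=1.0, beta=1.0):
    """Hypothetical restatement of Residual.forward for illustration.

    `transposed` is accepted for signature compatibility only; the
    computation below is elementwise, so it does not depend on layout.
    """
    # Skip the multiplication when beta is exactly 1.0 (same result, less work).
    y = beta * y if beta != 1.0 else y
    # With alpha == 0.0 the skip input x is ignored entirely.
    return alpha * x + y if alpha else y


torch.manual_seed(0)
x = torch.randn(2, 50, 16)  # (batch, length, d_model)
y = torch.randn(2, 50, 16)

out = residual_forward(x, y, transposed=False, alpha=0.5)
# Same data in channels-first layout: (batch, d_model, length).
out_t = residual_forward(x.transpose(1, 2), y.transpose(1, 2), transposed=True, alpha=0.5)

print(out.shape)                                             # torch.Size([2, 50, 16])
print(torch.allclose(out, out_t.transpose(1, 2)))            # True
print(torch.equal(residual_forward(x, y, False, alpha=0.0), y))  # True
```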