espnet2.asr.state_spaces.residual.Affine
class espnet2.asr.state_spaces.residual.Affine(*args, scalar=True, gamma=0.0, **kwargs)
Bases: Residual
Residual connection with learnable scalar multipliers on the main branch.
This class implements a residual connection in which a learnable multiplier scales the residual-branch output y before it is added to the skip input x. The skip input keeps the fixed weight alpha inherited from Residual. Depending on the scalar attribute, the multiplier is either a single scalar shared across all dimensions or a vector with one entry per dimension.
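The behaviour described above can be sketched as a small PyTorch module. This is a minimal, hypothetical illustration, not the ESPnet source: the class name AffineResidualSketch is invented. Two details are assumptions drawn from this page's notes: the multiplier starts at beta * i_layer ** (-gamma), and the output is alpha * x + c * y.

```python
import torch
from torch import nn


class AffineResidualSketch(nn.Module):
    """Hypothetical sketch: alpha * x + c * y with a learnable multiplier c."""

    def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0,
                 scalar=True, gamma=0.0):
        super().__init__()
        self.alpha = alpha  # fixed weight on the skip input x
        # Assumed initial value beta * i_layer ** (-gamma); i_layer >= 1 is assumed.
        c = beta * i_layer ** (-gamma)
        # One shared scalar, or one multiplier per input dimension.
        d = 1 if scalar else d_input
        self.affine = nn.Parameter(c * torch.ones(d))

    def forward(self, x, y, transposed=False):
        c = self.affine
        if transposed:
            # (d,) -> (d, 1) so c broadcasts over a trailing length axis.
            c = c.unsqueeze(-1)
        return self.alpha * x + c * y
```

With scalar=False and i_layer=4, gamma=0.5, every multiplier entry starts at 4 ** -0.5 = 0.5.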
scalar
If True, uses a single scalar multiplier; if False, uses one multiplier per dimension.
- Type: bool
gamma
The exponent applied to the layer index when initializing the affine parameters. With gamma > 0, deeper layers start with smaller multipliers.
- Type: float
affine
The learnable multiplier(s) applied to y. It has one entry when scalar is True, otherwise one entry per input dimension.
- Type: torch.nn.Parameter
Parameters:
- *args – Positional arguments forwarded to the parent Residual class (for example i_layer, d_input, d_model).
- scalar (bool, optional) – Whether to use a single scalar multiplier (True) or one multiplier per dimension (False). Default is True.
- gamma (float, optional) – The exponent (power) applied to the layer index in the initialization. Default is 0.0.
- **kwargs – Additional keyword arguments forwarded to the parent Residual class (for example alpha, beta).
Returns: None
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.residual import Affine
>>> affine_layer = Affine(i_layer=1, d_input=4, d_model=4,
...                       alpha=1.0, beta=1.0, scalar=True,
...                       gamma=0.0)
>>> x = torch.rand(2, 4)
>>> y = torch.rand(2, 4)
>>> output = affine_layer(x, y, transposed=False)
>>> print(output.shape)
torch.Size([2, 4])
NOTE
The multipliers are initialized to scale * layer_num**(-power), that is, beta * i_layer**(-gamma). Here i_layer and beta are the layer index and scale passed to the parent Residual class, and gamma is the power.
forward(x, y, transposed)
Computes the forward pass of the affine residual connection.
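This method combines the two inputs as alpha * x + affine * y. Here alpha is the fixed skip weight inherited from Residual and affine is the learnable multiplier. If transposed is True, a trailing axis is added to the multiplier so that it broadcasts over inputs laid out as (batch, dim, length).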
- Parameters:
- x (torch.Tensor) – The skip-connection input (typically the input to the block), scaled by alpha.
- y (torch.Tensor) – The output of the residual branch, scaled by the learnable multiplier.
- transposed (bool) – Whether the inputs use the (batch, dim, length) layout. If True, the multiplier is reshaped to broadcast along the last axis.
- Returns: The output tensor after applying the affine transformation and residual connection.
- Return type: torch.Tensor
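The reshaping applied when transposed is True can be illustrated with plain tensor broadcasting. This sketch does not call Affine. It assumes a per-dimension multiplier (scalar=False). It also assumes that non-transposed inputs are (batch, length, dim) and transposed inputs are (batch, dim, length), as is common in state-space sequence code.

```python
import torch

d, batch, length = 4, 2, 7
c = torch.arange(1.0, d + 1)  # per-dimension multiplier, shape (d,)

# Non-transposed layout (batch, length, d): c aligns with the last axis as-is.
y = torch.ones(batch, length, d)
out = c * y

# Transposed layout (batch, d, length): add a trailing axis, (d,) -> (d, 1).
y_t = torch.ones(batch, d, length)
out_t = c.unsqueeze(-1) * y_t

# Without the extra axis, the trailing sizes d=4 and length=7 cannot broadcast.
try:
    _ = c * y_t
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True
```

Both layouts scale each feature dimension by the same entry of c. Only the axis that c must line up with changes.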
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.residual import Affine
>>> affine_layer = Affine(i_layer=1, d_input=4, d_model=4)
>>> x = torch.randn(2, 4)  # Batch of 2 with 4 features
>>> y = torch.randn(2, 4)
>>> output = affine_layer(x, y, transposed=False)
>>> print(output.shape)
torch.Size([2, 4])
NOTE
The learnable parameters are initialized based on the layer index and a specified gamma value, which controls the scaling factor.
- Raises:
- ValueError – If the dimensions of x and y do not match the expected input sizes.