espnet2.asr.state_spaces.residual.DecayResidual
class espnet2.asr.state_spaces.residual.DecayResidual(*args, power=0.5, l2=True)
Bases: Residual
Residual connection that can decay the linear combination depending on depth.
This class implements a residual connection whose two weights depend on layer depth. The output is a linear combination of the input tensor x and the residual tensor y; as the layer index increases, the weight on y decays and the weight on x grows, which gives more controlled flow of information through deeper layers.
power
The exponent used to compute the decay factor for the residual connection: the weight on y is i_layer ** (-power). A higher power results in faster decay with depth.
- Type: float
l2
If True, alpha is chosen so that alpha ** 2 + beta ** 2 = 1 (L2 normalization); otherwise alpha = 1 - beta, so that the two weights sum to 1.
- Type: bool
Parameters:
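- *args – Positional arguments forwarded to the parent Residual class (i_layer, d_input, d_model). They must be passed positionally; power and l2 are the only keyword arguments this constructor accepts.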
- power (float) – Exponent for decay computation (default is 0.5).
- l2 (bool) – Flag to determine if L2 normalization is used (default is True).
Returns: The output of the residual connection after applying decay.
Return type: Tensor
####### Examples
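>>> import torch
>>> from espnet2.asr.state_spaces.residual import DecayResidual
>>> # Parent-class arguments (i_layer, d_input, d_model) are positional.
>>> decay_residual = DecayResidual(2, 10, 10)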
>>> x = torch.randn(1, 10)
>>> y = torch.randn(1, 10)
>>> output = decay_residual(x, y, transposed=False)
>>> print(output.shape)
torch.Size([1, 10])
NOTE
The behavior of this class depends on the layer index (i_layer) it is constructed with. As the layer index increases, the decay becomes more pronounced: the weight on y shrinks and the weight on x grows.
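To make the depth dependence concrete, the following is a minimal sketch in plain Python. It assumes the weights follow beta = i_layer ** (-power), with alpha derived from beta as described for the l2 attribute, and that layers are indexed from 1. The helper name `decay_weights` is hypothetical and is not part of ESPnet.

```python
import math


def decay_weights(i_layer, power=0.5, l2=True):
    """Hypothetical helper: return (alpha, beta) for a 1-indexed layer.

    Assumes beta = i_layer ** (-power); alpha is derived from beta.
    """
    beta = i_layer ** (-power)
    # l2=True keeps alpha**2 + beta**2 == 1; otherwise alpha + beta == 1.
    alpha = math.sqrt(1.0 - beta**2) if l2 else 1.0 - beta
    return alpha, beta


# Print the weights at a few depths with the default power and l2 settings.
for i_layer in (1, 2, 4, 16):
    alpha, beta = decay_weights(i_layer)
    print(f"layer {i_layer:>2}: alpha={alpha:.3f}  beta={beta:.3f}")
```

Under these assumptions the first layer returns y unchanged (alpha is 0, beta is 1), and deeper layers shift weight toward x.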
forward(x, y, transposed)
Computes the output of the DecayResidual layer, which applies a residual connection that can decay the linear combination of inputs based on the layer’s depth.
The output is computed as:
output = alpha * x + beta * y
where beta = i_layer ** (-power), and alpha = sqrt(1 - beta ** 2) if l2 is True, otherwise alpha = 1 - beta. As the layer index grows, beta shrinks toward 0 and alpha approaches 1, so deeper layers pass x through almost unchanged and scale down the contribution of y.
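The weighted sum can be sketched without any framework, since it uses only arithmetic operators. The function below is a hypothetical stand-in, not the ESPnet method itself; it assumes beta = i_layer ** (-power) and a 1-indexed layer, and it is applied to plain floats here, although the same expression would act elementwise on tensors of matching shape.

```python
def decay_blend(x, y, i_layer, power=0.5, l2=True):
    """Hypothetical stand-in for the blend output = alpha * x + beta * y."""
    beta = i_layer ** (-power)
    if l2:
        # L2-normalized weights: alpha**2 + beta**2 == 1
        alpha = (1.0 - beta**2) ** 0.5
    else:
        # Linear weights: alpha + beta == 1
        alpha = 1.0 - beta
    return alpha * x + beta * y


# At layer 4 with power=0.5, beta is 0.5.
only_y = decay_blend(0.0, 4.0, i_layer=4)            # x contributes nothing
only_x = decay_blend(3.0, 0.0, i_layer=4, l2=False)  # alpha = 1 - 0.5
print(only_y, only_x)
```

Setting one input to zero isolates the weight applied to the other, which is a quick way to check which normalization is in effect.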
- Parameters:
- x (torch.Tensor) – The input tensor to the layer, typically the output from the previous layer.
- y (torch.Tensor) – The residual tensor, typically the output from a different path in the network.
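- transposed (bool) – A flag indicating whether the inputs are in transposed layout (e.g., channels before the length dimension). It is part of the shared residual interface; the weighted sum itself is elementwise and does not depend on the layout.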
- Returns: The resulting tensor after applying the decayed residual connection.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.residual import DecayResidual
>>> decay_residual = DecayResidual(2, 3, 3)  # i_layer, d_input, d_model (positional)
>>> x = torch.randn(5, 3) # Batch of 5 with 3 features
>>> y = torch.randn(5, 3) # Batch of 5 with 3 features
>>> output = decay_residual(x, y, transposed=False)
>>> print(output.shape)
torch.Size([5, 3])
NOTE
The power attribute controls the rate of decay for the linear combination, where higher values result in faster decay.
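As a rough illustration of how power sets the decay rate, the sketch below tabulates the weight on y at a few depths for several exponents. It assumes that weight is i_layer ** (-power) with 1-indexed layers; `residual_weight` is a hypothetical helper, not an ESPnet function.

```python
def residual_weight(i_layer, power):
    """Hypothetical helper: weight applied to y at a 1-indexed layer."""
    return i_layer ** (-power)


# Compare how quickly the weight on y falls off for different exponents.
for power in (0.25, 0.5, 1.0):
    weights = [round(residual_weight(i, power), 3) for i in (1, 4, 16)]
    print(f"power={power}: weights at layers 1, 4, 16 -> {weights}")
```

With a larger exponent the weight at a given depth is smaller, so the contribution of y fades sooner.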
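- Raises: RuntimeError – If the shapes of x and y cannot be broadcast together. The error comes from the underlying tensor arithmetic rather than from an explicit check in this method.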