espnet2.asr.state_spaces.residual.Highway
class espnet2.asr.state_spaces.residual.Highway(*args, scaling_correction=False, elemwise=False)
Bases: Residual
Highway Residual connection with learned gating mechanisms.
This class implements a Highway connection that combines the skip input x and the sublayer output y through a learned sigmoid gate. Optionally, the carried x term is multiplied by a constant scaling correction, and y can be transformed by a learned per-feature weight vector (element-wise multiplication) instead of a full linear layer.
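To make the gating concrete, here is a minimal PyTorch sketch of a highway-style residual. It is not the ESPnet source: the class name HighwaySketch, the attribute names Wx and Wy, the single d_model constructor argument, and the formula (gate r = sigmoid(Wx(x) + Wy(y)), output c * (1 - r) * x + r * Wy(y), with c the scaling factor on the carried x term) are assumptions for illustration.

```python
import torch
import torch.nn as nn


class HighwaySketch(nn.Module):
    """Hypothetical stand-in for Highway; not the ESPnet class itself."""

    def __init__(self, d_model: int, scaling_correction: bool = False, elemwise: bool = False):
        super().__init__()
        # Assumed: the flag is stored as a float factor (1.732, roughly sqrt(3), when enabled).
        self.scaling_correction = 1.732 if scaling_correction else 1.0
        self.elemwise = elemwise
        self.Wx = nn.Linear(d_model, d_model)  # x-side contribution to the gate
        if elemwise:
            # Per-feature weight vector, applied by element-wise multiplication.
            self.Wy = nn.Parameter(torch.randn(d_model))
        else:
            self.Wy = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Transform y, then build a per-feature gate in (0, 1) from both inputs.
        y = self.Wy * y if self.elemwise else self.Wy(y)
        r = torch.sigmoid(self.Wx(x) + y)
        # Carry (scaled) x where the gate is low, pass transformed y where it is high.
        return self.scaling_correction * (1.0 - r) * x + r * y


layer = HighwaySketch(256, elemwise=True)
out = layer(torch.randn(32, 256), torch.randn(32, 256))
print(out.shape)  # torch.Size([32, 256])
```

Because the gate is computed per feature, the output keeps the shape of x and y.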
scaling_correction
Constant factor stored as a float: 1.732 (approximately √3) if scaling_correction=True was passed, otherwise 1.0. It multiplies the carried x term of the output.
- Type: float
elemwise
If True, y is transformed by a learned per-feature weight vector (element-wise multiplication); otherwise, by a full linear layer.
- Type: bool
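The elemwise flag mainly trades expressiveness for parameters. The sketch below counts the parameters of the y-transform under the assumption that the non-element-wise variant is a square linear layer with bias and the element-wise variant is one weight per feature; wy_param_count is a hypothetical helper, not part of ESPnet.

```python
def wy_param_count(d_input: int, elemwise: bool) -> int:
    """Hypothetical helper: parameters in the y-transform under the assumed design."""
    if elemwise:
        return d_input  # one learned weight per feature
    return d_input * d_input + d_input  # full linear layer: weight matrix + bias


for d in (64, 256):
    print(d, wy_param_count(d, elemwise=True), wy_param_count(d, elemwise=False))
# 64 64 4160
# 256 256 65792
```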
Parameters:
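- *args – Positional arguments forwarded to the parent Residual class, such as i_layer, d_input, and d_model. Pass them positionally: apart from scaling_correction and elemwise, the constructor accepts no keyword arguments.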
- scaling_correction (bool) – If True, the carried x term is scaled by 1.732 instead of 1.0 (default: False).
- elemwise (bool) – If True, y is transformed by a learned per-feature weight vector instead of a full linear layer (default: False).
Returns: The output tensor resulting from the Highway connection.
Return type: Tensor
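The constructor signature Highway(*args, scaling_correction=False, elemwise=False) only forwards positional arguments to Residual. The stand-in classes below (ResidualStub and HighwayStub are hypothetical, and the real Residual may take further arguments) reproduce that signature to show why i_layer, d_input, and d_model must be passed positionally.

```python
class ResidualStub:
    """Hypothetical, simplified stand-in for Residual."""

    def __init__(self, i_layer, d_input, d_model):
        self.i_layer, self.d_input, self.d_model = i_layer, d_input, d_model


class HighwayStub(ResidualStub):
    """Hypothetical stand-in mirroring Highway's documented constructor signature."""

    def __init__(self, *args, scaling_correction=False, elemwise=False):
        super().__init__(*args)  # only positional arguments reach the base class
        self.scaling_correction = 1.732 if scaling_correction else 1.0
        self.elemwise = elemwise


ok = HighwayStub(1, 256, 256, elemwise=True)  # positional base arguments: accepted
print(ok.d_model, ok.elemwise)  # 256 True

try:
    HighwayStub(i_layer=1, d_input=256, d_model=256)  # keyword base arguments
except TypeError as err:
    print("TypeError:", err)  # unexpected keyword argument 'i_layer'
```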
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.residual import Highway
>>> highway_layer = Highway(1, 256, 256)  # i_layer, d_input, d_model (positional)
>>> x = torch.randn(32, 256) # Batch of 32
>>> y = torch.randn(32, 256) # Batch of 32
>>> output = highway_layer(x, y)
>>> print(output.shape)
torch.Size([32, 256])
NOTE
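Highway connections let each layer learn, per feature, how much of its input to carry through unchanged. This near-identity path can ease the training of deep stacks by helping gradients propagate, mitigating the vanishing gradient problem.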
forward(x, y, transposed=False)
Perform the forward pass of the Highway residual connection.
This method computes a per-feature gate r = sigmoid(W_x x + W_y y) from learned projections of both inputs and returns the blend c · (1 − r) · x + r · (W_y y), where W_y is a linear layer (or an element-wise weight vector when elemwise=True) and c is the scaling correction (1.0 unless enabled). A gate near 0 carries x through; a gate near 1 passes the transformed y.
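A single-feature worked example of the blend helps show the two extremes of the gate. This sketch assumes the gate is r = sigmoid(gate_x + y_t) and the output is scale * (1 - r) * x + r * y_t; highway_gate_scalar is a hypothetical helper in which the learned projections are replaced by precomputed numbers (gate_x stands for the x-side term, y_t for the already-transformed y).

```python
import math


def highway_gate_scalar(x: float, y_t: float, gate_x: float, scale: float = 1.0) -> float:
    """Hypothetical one-feature highway blend with precomputed projections."""
    r = 1.0 / (1.0 + math.exp(-(gate_x + y_t)))  # sigmoid gate in (0, 1)
    return scale * (1.0 - r) * x + r * y_t


print(highway_gate_scalar(2.0, 0.0, 0.0))  # gate 0.5: 0.5*2.0 + 0.5*0.0 = 1.0
print(round(highway_gate_scalar(2.0, 0.0, -10.0), 4))  # gate near 0: 1.9999, close to x
```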
- Parameters:
- x (torch.Tensor) – The skip (residual) input, typically the input to the wrapped sublayer, with features in the last dimension.
- y (torch.Tensor) – The sublayer output to be gated with x; it must have the same shape as x.
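- transposed (bool, optional) – Whether the inputs use a channels-first (B, H, L) layout. Defaults to False. This argument appears to be accepted for interface compatibility with Residual: the gate projections act on the last dimension, so pass features-last tensors.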
- Returns: The output tensor resulting from the combination of x and y according to the Highway mechanism.
- Return type: torch.Tensor
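If your tensors are channels-first (B, H, L), a conservative approach, assuming the layer's projections act on the last dimension regardless of transposed, is to move features last before the call and back afterwards. The sketch uses a plain nn.Linear (gate_proj, a hypothetical stand-in) in place of the Highway layer.

```python
import torch
import torch.nn as nn

gate_proj = nn.Linear(64, 64)  # hypothetical stand-in: mixes the last dimension only

x_bhl = torch.randn(8, 64, 100)  # (batch, features H, length L)
x_blh = x_bhl.transpose(1, 2)  # move features last: (B, L, H)
out_blh = gate_proj(x_blh)
out_bhl = out_blh.transpose(1, 2)  # restore (B, H, L) for the caller
print(out_bhl.shape)  # torch.Size([8, 64, 100])
```

Calling gate_proj on x_bhl directly fails, because its last dimension (100) does not match the 64 input features.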
####### Examples
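>>> import torch
>>> from espnet2.asr.state_spaces.residual import Highway
>>> highway_layer = Highway(0, 64, 64)  # i_layer, d_input, d_model (positional)
>>> x = torch.randn(10, 64) # Batch of 10 with 64 features
>>> y = torch.randn(10, 64) # Another batch of 10 with 64 features
>>> output = highway_layer(x, y)  # call the module rather than .forward()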
>>> print(output.shape)
torch.Size([10, 64])
NOTE
The scaling correction is fixed at construction time, for example Highway(0, 64, 64, scaling_correction=True), and multiplies the carried x term by 1.732 instead of 1.0.
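One consequence worth noting: with the correction enabled, the weights on x and y no longer sum to 1, so the output is not a convex combination of the two. The sketch below assumes the output is scale * (1 - r) * x + r * y_t with r the sigmoid gate; blend_weights is a hypothetical helper.

```python
import math


def blend_weights(gate_logit: float, scaling_correction: bool) -> tuple:
    """Hypothetical helper: (weight on x, weight on y) for one feature."""
    scale = 1.732 if scaling_correction else 1.0  # assumed flag-to-factor mapping
    r = 1.0 / (1.0 + math.exp(-gate_logit))  # sigmoid gate
    return scale * (1.0 - r), r


for flag in (False, True):
    wx, wy = blend_weights(0.0, flag)
    print(flag, round(wx, 3), round(wy, 3), round(wx + wy, 3))
# False 0.5 0.5 1.0
# True 0.866 0.5 1.366
```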