espnet2.asr.state_spaces.components.ReversibleInstanceNorm1dOutput
class espnet2.asr.state_spaces.components.ReversibleInstanceNorm1dOutput(norm_input)
Bases: Module
ReversibleInstanceNorm1dOutput is a module that undoes the instance normalization that ReversibleInstanceNorm1dInput applies to a 1D input tensor.
This module takes the output of a ReversibleInstanceNorm1dInput instance and reverses the normalization by re-applying the mean and standard deviation that the input module stored during its forward pass. This is useful in reversible neural networks, where the original input must be recovered after normalization.
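The following equations show the per-instance transform that the pair of modules inverts. They are a sketch under stated assumptions, not taken from the source: statistics are computed per sample b and channel d over the length axis l (with L positions), σ is the population standard deviation, and both directions use the same stored σ and μ.

```latex
\mu_{b,d} = \frac{1}{L}\sum_{l=1}^{L} x_{b,d,l}, \qquad
\sigma_{b,d} = \sqrt{\frac{1}{L}\sum_{l=1}^{L}\left(x_{b,d,l} - \mu_{b,d}\right)^{2}}

\hat{x}_{b,d,l} = \frac{x_{b,d,l} - \mu_{b,d}}{\sigma_{b,d}}
\quad\Longrightarrow\quad
\hat{x}_{b,d,l}\,\sigma_{b,d} + \mu_{b,d} = x_{b,d,l}
```

The inverse multiplies by exactly the stored σ and adds exactly the stored μ. Any small constant folded into σ for numerical stability therefore cancels, provided σ is nonzero.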
transposed
Indicates whether tensors use the (B, D, L) layout (True) or the (B, L, D) layout (False).
- Type: bool
weight
Learnable weight parameter taken from the input normalization module.
- Type: torch.nn.Parameter
bias
Learnable bias parameter taken from the input normalization module.
- Type: torch.nn.Parameter
norm_input
Instance of the input normalization module that holds the mean and standard deviation computed during normalization.
- Type: ReversibleInstanceNorm1dInput
Parameters: norm_input (ReversibleInstanceNorm1dInput) – The ReversibleInstanceNorm1dInput instance whose stored mean and standard deviation are used to reverse the normalization.
Returns: The output tensor after applying the reverse normalization.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.components import (
...     ReversibleInstanceNorm1dInput, ReversibleInstanceNorm1dOutput)
>>> norm_input = ReversibleInstanceNorm1dInput(d=10)
>>> output_layer = ReversibleInstanceNorm1dOutput(norm_input)
>>> input_tensor = torch.randn(32, 5, 10)  # (B, L, D): batch 32, length 5, 10 features
>>> normalized_tensor = norm_input(input_tensor)
>>> output_tensor = output_layer(normalized_tensor)
NOTE
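The input tensor must have the layout expected by ReversibleInstanceNorm1dInput: (B, L, D) when transposed is False, or (B, D, L) when it is True. Because the statistics are read from norm_input, call this module only after norm_input has processed the batch being restored.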
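The layout rule in this note determines which axis the statistics are reduced over. Below is a minimal sketch, not the library's code: the helper name instance_stats is hypothetical, and it assumes per-sample, per-channel population statistics over the length axis. It reduces over the middle axis for (B, L, D) and over the last axis for (B, D, L), so both layouts yield the same statistics up to a transpose.

```python
import torch


def instance_stats(x: torch.Tensor, transposed: bool):
    """Hypothetical helper: per-sample, per-channel mean and population std over L.

    x is assumed to be (B, L, D) when transposed is False and (B, D, L) when True.
    """
    # The length axis is the middle one for (B, L, D) and the last one for (B, D, L).
    dim = -1 if transposed else -2
    mean = x.mean(dim=dim, keepdim=True)
    std = ((x - mean) ** 2).mean(dim=dim, keepdim=True).sqrt()
    return mean, std  # keepdim=True, so both broadcast against x


torch.manual_seed(0)
x_bld = torch.randn(4, 7, 3)   # B=4, L=7, D=3
x_bdl = x_bld.transpose(1, 2)  # the same data viewed as (B, D, L)

m_bld, s_bld = instance_stats(x_bld, transposed=False)  # shapes (4, 1, 3)
m_bdl, s_bdl = instance_stats(x_bdl, transposed=True)   # shapes (4, 3, 1)
```

Reducing over the appropriate axis avoids transposing the tensor. Transposing to a common layout first and then reducing over the last axis is an equally valid design.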
forward(x)
Apply the inverse of the instance normalization performed by norm_input.
This method reverses the normalization performed by the corresponding input module. It re-scales the normalized tensor back to the original space using the mean and standard deviation computed during the most recent forward pass of the input module.
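The inverse step can also be written as a pair of stateless functions. This is a hypothetical sketch, not the ESPnet implementation. The names normalize and denormalize, the (B, D, L) layout, the use of the population standard deviation, and the stabilizing constant eps=1e-4 are all assumptions made for illustration.

```python
import torch


def normalize(x: torch.Tensor, eps: float = 1e-4):
    """Hypothetical functional counterpart of the input module, (B, D, L) layout.

    Returns the normalized tensor together with the statistics needed to undo it.
    """
    mean = x.mean(dim=-1, keepdim=True)
    # Population std over the length axis; eps is an assumed stabilizer.
    std = ((x - mean) ** 2).mean(dim=-1, keepdim=True).sqrt() + eps
    return (x - mean) / std, mean, std


def denormalize(x_hat: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Inverse step: re-scale by the stored std, then shift by the stored mean."""
    return x_hat * std + mean


torch.manual_seed(0)
x = torch.randn(5, 10, 20)  # (B, D, L)
x_hat, mean, std = normalize(x)
x_rec = denormalize(x_hat, mean, std)  # equals x up to float rounding
```

This sketch returns the statistics explicitly. ReversibleInstanceNorm1dOutput instead reads them from norm_input, which is why the two modules must process the same batch in order.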
Parameters: x (Tensor) – Normalized tensor produced by norm_input, in the (B, L, D) layout if transposed is False or the (B, D, L) layout if it is True.
Returns: The output tensor after applying the inverse normalization.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.components import (
...     ReversibleInstanceNorm1dInput, ReversibleInstanceNorm1dOutput)
>>> norm_input = ReversibleInstanceNorm1dInput(d=10)
>>> norm_output = ReversibleInstanceNorm1dOutput(norm_input)
>>> x = torch.randn(5, 20, 10)  # (B, L, D): batch 5, length 20, 10 features
>>> normalized = norm_input(x)  # apply normalization
>>> output = norm_output(normalized)  # apply inverse normalization
>>> assert torch.allclose(x, output, atol=1e-6)  # original x recovered up to float error