espnet2.asr.state_spaces.components.ReversibleInstanceNorm1dInput
class espnet2.asr.state_spaces.components.ReversibleInstanceNorm1dInput(d, transposed=False)
Bases: Module
Reversible Instance Normalization for 1D inputs.
This module implements reversible instance normalization for 1D inputs. It computes the mean and standard deviation of the input tensor and normalizes the input with them. The statistics are kept, so the normalization can be inverted later to recover the original input.
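The normalize-then-invert idea described above can be sketched in plain PyTorch. This is a minimal illustration, not the module's implementation: the helper names `normalize` and `denormalize`, the stabilizer `EPS`, and the choice to reduce over the length axis of a (B, L, D) tensor are all assumptions made for the example.

```python
import torch

EPS = 1e-4  # assumed stabilizer to avoid division by zero; hypothetical value


def normalize(x):
    """Normalize a (B, L, D) tensor over its length axis (dim=1).

    Returns the normalized tensor together with the statistics
    that are needed to undo the normalization.
    """
    mean = x.mean(dim=1, keepdim=True)                        # (B, 1, D)
    std = x.std(dim=1, unbiased=False, keepdim=True) + EPS    # (B, 1, D)
    return (x - mean) / std, mean, std


def denormalize(y, mean, std):
    """Invert normalize() using the stored statistics."""
    return y * std + mean


torch.manual_seed(0)
x = torch.randn(4, 50, 8) * 3.0 + 5.0   # (Batch, Length, Features)
y, mean, std = normalize(x)
x_rec = denormalize(y, mean, std)        # matches x up to float rounding
```

Keeping `mean` and `std` alongside the output is what makes the operation reversible: without them the original scale and offset of each sequence would be lost.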
transposed
Indicates whether the input is in transposed form (BDL) or not (BLD).
- Type: bool
norm
Instance normalization layer.
- Type: nn.InstanceNorm1d
Parameters:
- d (int) – The number of features in the input tensor.
- transposed (bool, optional) – If True, the input is expected in shape (B, D, L). Defaults to False, i.e. input shape (B, L, D).
Returns: The normalized input tensor.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.components import ReversibleInstanceNorm1dInput
>>> norm_layer = ReversibleInstanceNorm1dInput(d=10, transposed=False)
>>> input_tensor = torch.randn(32, 5, 10)  # (Batch, Length, Features)
>>> normalized_tensor = norm_layer(input_tensor)
>>> print(normalized_tensor.shape)
torch.Size([32, 5, 10])
NOTE
The module accepts both layouts: (B, L, D) when transposed is False and (B, D, L) when transposed is True. The output has the same shape and layout as the input.
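As a rough illustration of the two layouts, the sketch below normalizes the same data once as (B, L, D) and once as (B, D, L) and checks that the results agree after transposing. The helper `instance_normalize` and its `eps` argument are hypothetical stand-ins written for this example; they assume the statistics are taken over the length axis and do not reproduce the module's code.

```python
import torch


def instance_normalize(x, transposed=False, eps=1e-4):
    """Normalize over the length axis for either layout.

    transposed=False: x is (B, L, D), length is dim -2.
    transposed=True:  x is (B, D, L), length is dim -1.
    """
    length_dim = -1 if transposed else -2
    mean = x.mean(dim=length_dim, keepdim=True)
    std = x.std(dim=length_dim, unbiased=False, keepdim=True) + eps
    return (x - mean) / std


torch.manual_seed(0)
x_bld = torch.randn(2, 7, 3)        # (Batch, Length, Features)
x_bdl = x_bld.transpose(1, 2)       # same data as (Batch, Features, Length)

y_bld = instance_normalize(x_bld, transposed=False)
y_bdl = instance_normalize(x_bdl, transposed=True)

# Both calls normalize the same sequences, so the outputs differ only in layout.
same = torch.allclose(y_bld, y_bdl.transpose(1, 2), atol=1e-5)
```

The flag only tells the layer which axis holds the sequence length; the values being normalized are the same in both cases.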
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Apply reversible instance normalization to x.
The mean and standard deviation are computed over the length dimension, separately for each sample and feature, and the input is normalized with them. The statistics are stored, so the original input can be reconstructed from the normalized output later.
transposed
A flag indicating whether the input tensor is in a transposed format (BDL) or not (BLD).
- Type: bool
norm
The instance normalization layer.
- Type: nn.InstanceNorm1d
Parameters:
- x (Tensor) – The input tensor, of shape (B, L, D) if transposed is False, otherwise (B, D, L).
Returns: The normalized tensor, with the same shape as the input.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.asr.state_spaces.components import ReversibleInstanceNorm1dInput
>>> layer = ReversibleInstanceNorm1dInput(d=64, transposed=True)
>>> input_tensor = torch.randn(32, 64, 10)  # (Batch, Features, Length)
>>> output_tensor = layer(input_tensor)
>>> print(output_tensor.shape)
torch.Size([32, 64, 10])
NOTE
Because the per-instance statistics are stored, the normalization can be undone later. This is useful when a model's output has to be mapped back to the original scale of its input.