espnet2.asr.state_spaces.components.get_initializer
espnet2.asr.state_spaces.components.get_initializer(name, activation=None)
Get the appropriate weight initializer based on the specified name and activation function.
This function returns a callable that initializes a weight tensor using the chosen initialization method and takes the activation function into account. Supported methods are ‘uniform’, ‘normal’, ‘xavier’, ‘zero’, and ‘one’. If either the initializer name or the activation function is not recognized, it raises a NotImplementedError.
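The reason the activation matters can be illustrated with Kaiming (He) uniform initialization. This reference does not say which formula ‘uniform’ uses, so the following assumes it behaves like `torch.nn.init.kaiming_uniform_` with fan-in scaling. Entries are drawn from $U(-b, b)$, and the gains $g$ are the values PyTorch lists for `torch.nn.init.calculate_gain`:

$$
b = g \sqrt{\frac{3}{n_{\text{in}}}}, \qquad
g = \begin{cases}
\sqrt{2} & \text{ReLU} \\
5/3 & \tanh \\
1 & \text{sigmoid, linear}
\end{cases}
$$

For a $3 \times 5$ weight ($n_{\text{in}} = 5$) followed by ReLU, $b = \sqrt{2}\sqrt{3/5} = \sqrt{6/5} \approx 1.095$. With a linear activation, $b = \sqrt{3/5} \approx 0.775$.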
- Parameters:
- name (str) – The name of the initializer. Supported values are ‘uniform’, ‘normal’, ‘xavier’, ‘zero’, and ‘one’.
- activation (str, optional) – The activation function to take into account during initialization. Supported values include ‘relu’, ‘tanh’, ‘sigmoid’, ‘gelu’, ‘swish’, and ‘linear’. Defaults to None.
- Returns: A callable that initializes weights based on the specified initializer and activation function.
- Return type: Callable
- Raises:
- NotImplementedError – If the specified initializer name or activation function is not supported.
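The validation and dispatch described above can be sketched in plain Python. This is a hypothetical illustration, not ESPnet's code. Several details are assumptions:

- the name `get_initializer_sketch`;
- a nested list standing in for a tensor;
- the Kaiming and Xavier formulas used for ‘uniform’, ‘normal’, and ‘xavier’;
- GELU and Swish reusing the ReLU gain.

```python
import math
import random
from functools import partial

# Assumed gains (values as listed for torch.nn.init.calculate_gain).
# GELU and Swish borrowing the ReLU gain is an assumption of this sketch.
_GAINS = {
    None: 1.0, "linear": 1.0, "sigmoid": 1.0, "tanh": 5.0 / 3.0,
    "relu": math.sqrt(2.0), "gelu": math.sqrt(2.0), "swish": math.sqrt(2.0),
}


def _fill(weight, sample):
    """Overwrite every entry of a nested-list matrix in place."""
    for row in weight:
        for j in range(len(row)):
            row[j] = sample()


def _scaled(weight, gain, kind):
    """Fill an (out_features x in_features) nested list with scaled random values."""
    fan_out, fan_in = len(weight), len(weight[0])
    if kind == "uniform":  # assumed Kaiming uniform: U(-b, b), b = gain * sqrt(3 / fan_in)
        bound = gain * math.sqrt(3.0 / fan_in)
        _fill(weight, lambda: random.uniform(-bound, bound))
    elif kind == "normal":  # assumed Kaiming normal: std = gain / sqrt(fan_in)
        std = gain / math.sqrt(fan_in)
        _fill(weight, lambda: random.gauss(0.0, std))
    else:  # assumed Xavier normal with gain 1: std = sqrt(2 / (fan_in + fan_out))
        std = math.sqrt(2.0 / (fan_in + fan_out))
        _fill(weight, lambda: random.gauss(0.0, std))


def get_initializer_sketch(name, activation=None):
    """Hypothetical stand-in for get_initializer; returns an in-place initializer."""
    if activation not in _GAINS:
        raise NotImplementedError(f"activation {activation!r} not supported")
    gain = _GAINS[activation]
    if name in ("uniform", "normal", "xavier"):
        return partial(_scaled, gain=gain, kind=name)
    if name == "zero":
        return partial(_fill, sample=lambda: 0.0)
    if name == "one":
        return partial(_fill, sample=lambda: 1.0)
    raise NotImplementedError(f"initializer {name!r} not supported")


# Usage: initialize a 3x5 "weight" for a ReLU layer in place.
weight = [[0.0] * 5 for _ in range(3)]
get_initializer_sketch("uniform", activation="relu")(weight)
```

Returning a `functools.partial` leaves the callable with a single weight argument. That matches how the examples call the returned initializer.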
Examples
```python
import torch
from espnet2.asr.state_spaces.components import get_initializer

# Get a uniform initializer for ReLU activation
initializer = get_initializer("uniform", activation="relu")
# Apply the initializer to a tensor
weight_tensor = torch.empty(3, 5)
initializer(weight_tensor)

# Get a Xavier initializer for sigmoid activation
initializer = get_initializer("xavier", activation="sigmoid")
weight_tensor = torch.empty(4, 4)
initializer(weight_tensor)
```