espnet2.enh.layers.ncsnpp_utils.up_or_down_sampling.get_weight
espnet2.enh.layers.ncsnpp_utils.up_or_down_sampling.get_weight(module, shape, weight_var='weight', kernel_init=None)
Get/create weight tensor for a convolution or fully-connected layer.
This function retrieves or initializes the weight tensor stored on module under the name weight_var. A newly created weight has the given shape and can optionally be initialized with kernel_init.
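The get-or-create behaviour described above can be illustrated in plain PyTorch. This is a hypothetical sketch, not the ESPnet implementation: the name get_or_create_weight, the shape check, and storing the new tensor with register_parameter are all assumptions made for illustration.

```python
import torch
from torch import nn


def get_or_create_weight(module, shape, weight_var="weight", kernel_init=None):
    """Hypothetical sketch: return module.<weight_var>, creating it if absent."""
    shape = tuple(shape)
    existing = getattr(module, weight_var, None)
    if isinstance(existing, torch.Tensor):
        # Retrieve: the module already holds a tensor under this name.
        if tuple(existing.shape) != shape:
            raise ValueError(
                f"{weight_var} has shape {tuple(existing.shape)}, expected {shape}"
            )
        return existing
    # Create: use kernel_init(shape) if given, otherwise allocate
    # uninitialized memory with torch.empty.
    data = kernel_init(shape) if kernel_init is not None else torch.empty(shape)
    param = nn.Parameter(data)
    # register_parameter raises KeyError if the name clashes with a
    # non-parameter attribute (for example a method such as "forward").
    module.register_parameter(weight_var, param)
    return param


linear = nn.Linear(10, 5)
w = get_or_create_weight(linear, (5, 10))  # returns linear.weight itself

bare = nn.Module()
p = get_or_create_weight(bare, (3, 4), kernel_init=lambda s: torch.randn(s) * 0.01)
```

Because the created tensor is wrapped in nn.Parameter and registered on the module, it shows up in module.parameters() and is therefore picked up by an optimizer.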
- Parameters:
  - module (nn.Module) – The neural network module from which to retrieve or create the weight tensor.
  - shape (tuple) – The desired shape of the weight tensor.
  - weight_var (str, optional) – The variable name of the weight tensor (default: "weight").
  - kernel_init (callable, optional) – A function that takes the shape and returns the initial weight tensor. If None, the weights are left uninitialized.
- Returns: The initialized or retrieved weight tensor.
- Return type: torch.Tensor
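A kernel_init callable in the style of the custom_init example takes a shape and returns a tensor. The two initializers below are illustrative assumptions, not part of ESPnet. This page does not state which weight layout shape should follow. For reference, PyTorch's own layers store nn.Linear weights as (out_features, in_features) and nn.Conv2d weights as (out_channels, in_channels // groups, kH, kW).

```python
import torch
from torch import nn


def kaiming_kernel_init(shape):
    """Hypothetical kernel_init: He-normal initialization for the given shape."""
    w = torch.empty(shape)
    nn.init.kaiming_normal_(w)  # fills w in place; requires at least 2 dims
    return w


def small_normal_init(shape, std=0.01):
    """Hypothetical kernel_init: scaled standard normal, like custom_init."""
    return torch.randn(shape) * std


# Shapes of PyTorch's built-in layers, for comparison.
linear_shape = tuple(nn.Linear(10, 5).weight.shape)   # (5, 10)
conv_shape = tuple(nn.Conv2d(3, 16, 3).weight.shape)  # (16, 3, 3, 3)

w_lin = kaiming_kernel_init(linear_shape)
w_conv = small_normal_init(conv_shape)
```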
Examples
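>>> import torch
>>> from torch import nn
>>> from espnet2.enh.layers.ncsnpp_utils.up_or_down_sampling import get_weight
>>> module = nn.Linear(10, 5)
>>> weight = get_weight(module, (5, 10))
>>> weight.shape
torch.Size([5, 10])
>>> def custom_init(shape):
...     return torch.randn(shape) * 0.01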
>>> weight = get_weight(module, (5, 10), kernel_init=custom_init)
>>> weight.mean()
tensor(..., grad_fn=<MeanBackward0>)
NOTE
The function assumes that the module passed in can store a parameter under the name given by weight_var.
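Whether an ordinary PyTorch module can store a parameter under a given name can be checked up front. The helper can_store_parameter below is hypothetical. It roughly mirrors the name checks that nn.Module.register_parameter applies: the name must be a non-empty string without ".", and it must not clash with an existing attribute unless that attribute is already a parameter.

```python
import torch
from torch import nn


def can_store_parameter(module: nn.Module, name: str) -> bool:
    """Hypothetical check: would module.register_parameter(name, ...) succeed?"""
    if not isinstance(name, str) or name == "" or "." in name:
        return False
    # Parameters registered directly on this module (not on submodules).
    own_params = dict(module.named_parameters(recurse=False))
    # A free name is fine; an existing parameter may be replaced; any other
    # attribute (method, buffer, plain attribute) would make registration fail.
    return not hasattr(module, name) or name in own_params


module = nn.Module()
if can_store_parameter(module, "weight"):
    module.register_parameter("weight", nn.Parameter(torch.zeros(5, 10)))
```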