espnet2.layers.create_adapter_utils.replace_module
espnet2.layers.create_adapter_utils.replace_module(parent_module: Module, child_name: str, old_module: Module, new_module: Module)
Replace the target module within a parent module with a new module.
This function replaces the child module registered under child_name in parent_module with new_module. It also copies the weight and bias of old_module to new_module and moves new_module to the same device as old_module.
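The behavior described above can be sketched in plain PyTorch. This is a minimal, hypothetical re-implementation, not the ESPnet source: replace_module_sketch is an illustrative name. It assumes old_module exposes a weight parameter, as nn.Linear does, and it "copies" by reusing the old Parameter objects. The real function may transfer the values differently.

```python
import torch
from torch import nn


def replace_module_sketch(
    parent_module: nn.Module,
    child_name: str,
    old_module: nn.Module,
    new_module: nn.Module,
) -> None:
    """Hypothetical sketch of the documented behavior (not the ESPnet code).

    Assumes old_module has a ``weight`` parameter (e.g. nn.Linear).
    """
    # Mirror the documented AttributeError for an unknown child name.
    if not hasattr(parent_module, child_name):
        raise AttributeError(
            f"{type(parent_module).__name__} has no child module {child_name!r}"
        )

    # Remember where the old parameters live before swapping.
    device = old_module.weight.device

    # Register new_module under the same attribute name in the parent.
    setattr(parent_module, child_name, new_module)

    # Carry over the weight and, if present, the bias. Assigning the
    # Parameter objects shares them with old_module (an assumption of this
    # sketch); new_module.weight.data.copy_(...) would copy the values instead.
    new_module.weight = old_module.weight
    if getattr(old_module, "bias", None) is not None:
        new_module.bias = old_module.bias

    # Keep the new module on the old module's device.
    new_module.to(device)


parent = nn.Sequential(nn.Linear(10, 5), nn.ReLU())
old_module = parent[0]
new_module = nn.Linear(10, 5)
replace_module_sketch(parent, "0", old_module, new_module)
assert parent[0] is new_module
assert torch.equal(parent[0].weight, old_module.weight)
```

Sharing the Parameter objects keeps the pretrained values without allocating new tensors. The trade-off is that updates made through either module are visible in both.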
- Parameters:
- parent_module (torch.nn.Module) – The parent module containing the child module to be replaced.
- child_name (str) – The name of the child module to be replaced.
- old_module (torch.nn.Module) – The module to be replaced.
- new_module (torch.nn.Module) – The new module that will replace the old module.
- Raises: AttributeError – If the specified child_name does not exist in parent_module.
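child_name appears to name a direct child of parent_module, as "0" does in the Examples section. A dotted path such as those yielded by named_modules() would therefore need to be split into the parent and the final attribute name first. resolve_parent_and_child below is a hypothetical helper, not an ESPnet function. It assumes PyTorch 1.9 or newer, which provides nn.Module.get_submodule.

```python
from typing import List, Tuple

from torch import nn


def resolve_parent_and_child(
    model: nn.Module, full_name: str
) -> Tuple[nn.Module, str, nn.Module]:
    """Hypothetical helper: map a dotted module name to the
    (parent_module, child_name, old_module) arguments of replace_module."""
    # "encoder.layers.0.fc" -> parent path "encoder.layers.0", child "fc";
    # a top-level name such as "fc" yields an empty parent path.
    parent_name, _, child_name = full_name.rpartition(".")
    # get_submodule("") returns the model itself.
    parent_module = model.get_submodule(parent_name)
    old_module = getattr(parent_module, child_name)
    return parent_module, child_name, old_module


model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
    nn.Linear(4, 2),
)

# Collect names first so the model is not mutated while iterating over it.
targets: List[str] = [
    name for name, m in model.named_modules() if isinstance(m, nn.Linear)
]
resolved = [resolve_parent_and_child(model, name) for name in targets]
# Each triple, plus a replacement module, is what replace_module would receive.
```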
NOTE
- This function currently does not handle the addition of hooks or the requires_grad attribute for the new module.
Examples
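>>> import torch
>>> from espnet2.layers.create_adapter_utils import replace_module
>>> parent = torch.nn.Sequential(
...     torch.nn.Linear(10, 5),
...     torch.nn.ReLU(),
... )
>>> old_module = parent[0]
>>> new_module = torch.nn.Linear(10, 5)
>>> replace_module(parent, "0", old_module, new_module)
>>> assert parent[0] is new_module  # new module is registered under "0"
>>> assert torch.equal(parent[0].weight, old_module.weight)  # weight carried over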