espnet2.layers.create_adapter_utils.get_submodules
espnet2.layers.create_adapter_utils.get_submodules(model: Module, key: str)
Retrieve the parent module, the target name, and the target submodule for a given key in the model.
This function navigates through the hierarchical structure of a PyTorch module to locate and return the parent module, the name of the target submodule, and the target submodule itself based on the provided key.
- Parameters:
- model (torch.nn.Module) – The parent model from which to retrieve the submodules.
- key (str) – The key representing the path to the target submodule in the model, using dot notation.
- Returns: A tuple containing:
- parent_module (torch.nn.Module): The parent module containing the target submodule.
- target_name (str): The name of the target submodule.
- target_module (torch.nn.Module): The target submodule itself.
- Return type: tuple
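To make the dot-notation lookup concrete, here is a minimal pure-Python sketch of the general technique: split the key on dots, walk to the parent, then fetch the target by its final name. The helper `split_and_resolve` and the `SimpleNamespace` stand-ins are hypothetical and only for illustration; this is not ESPnet's implementation, and a real call would pass a `torch.nn.Module`.

```python
from types import SimpleNamespace


def split_and_resolve(root, key):
    """Hypothetical helper: resolve a dotted key by walking attributes."""
    # Everything before the last dot is the path to the parent.
    *parent_path, target_name = key.split(".")
    parent = root
    for name in parent_path:
        parent = getattr(parent, name)
    target = getattr(parent, target_name)
    return parent, target_name, target


# Stand-in objects (assumption); a real model would be a torch.nn.Module.
attn = SimpleNamespace(kind="attention")
encoder = SimpleNamespace(attn=attn)
model = SimpleNamespace(encoder=encoder)

parent, name, target = split_and_resolve(model, "encoder.attn")
print(parent is encoder, name, target is attn)
```

Returning the parent and the name alongside the target is useful when a caller wants to swap the submodule in place, for example with `setattr(parent, name, replacement)`, which is presumably why an adapter utility exposes all three.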
Examples
>>> import torch
>>> from espnet2.layers.create_adapter_utils import get_submodules
>>> class MyModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.layer1 = torch.nn.Linear(10, 5)
...         self.layer2 = torch.nn.Linear(5, 2)
...
>>> # torch.nn.Module already provides get_submodule(), so no override is needed.
>>> model = MyModel()
>>> parent, name, module = get_submodules(model, "layer2")
>>> print(name)
layer2
>>> print(module)
Linear(in_features=5, out_features=2, bias=True)