espnet2.asr.decoder.hugging_face_transformers_decoder.get_hugging_face_model_lm_head
espnet2.asr.decoder.hugging_face_transformers_decoder.get_hugging_face_model_lm_head(model)
Get the language model head from a Hugging Face Transformers model.
This function retrieves the language model head from a given Hugging Face Transformers model. It checks the model for an lm_head or embed_out attribute and returns whichever is present; if neither attribute is found, an exception is raised.
- Parameters: model – A Hugging Face Transformers model instance from which to extract the language model head.
- Returns: The language model head of the specified model.
- Raises: Exception – If neither the lm_head nor the embed_out attribute can be found in the model.
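A minimal sketch of the lookup described above, assuming lm_head is checked before embed_out and using an illustrative function name and error message (not the exact ESPnet implementation):

```python
def _get_lm_head_sketch(model):
    """Illustrative only: mirror the attribute lookup described above."""
    # Most causal LM wrappers (e.g. GPT-2) expose the output projection as ``lm_head``.
    if hasattr(model, "lm_head"):
        return model.lm_head
    # Some architectures (e.g. GPT-NeoX) expose it as ``embed_out`` instead.
    if hasattr(model, "embed_out"):
        return model.embed_out
    # Neither attribute exists: raise, as the documented behavior specifies.
    raise Exception("Could not find the LM head (no lm_head or embed_out attribute)")
```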
Examples
>>> from transformers import AutoModelForCausalLM
>>> from espnet2.asr.decoder.hugging_face_transformers_decoder import (
...     get_hugging_face_model_lm_head,
... )
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> lm_head = get_hugging_face_model_lm_head(model)
>>> print(lm_head)  # e.g. Linear(in_features=768, out_features=50257, bias=False) for GPT-2
NOTE
Ensure that the model is a valid Hugging Face Transformers model instance before calling this function.
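For orientation, the head returned for GPT-2-style models is a linear projection from hidden states to vocabulary logits. A hedged usage sketch, reusing the model and lm_head from the example above (shapes assume the "gpt2" checkpoint):

```python
import torch

# Hypothetical tensor of decoder hidden states: (batch, time, hidden_size).
hidden_states = torch.randn(1, 10, model.config.hidden_size)

# Applying the retrieved head yields vocabulary logits: (batch, time, vocab_size).
logits = lm_head(hidden_states)
print(logits.shape)  # torch.Size([1, 10, 50257]) for "gpt2"
```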