diff --git a/graph_of_thoughts/controller/llamachat_hf.py b/graph_of_thoughts/controller/llamachat_hf.py index 0065cf8..d423a50 100644 --- a/graph_of_thoughts/controller/llamachat_hf.py +++ b/graph_of_thoughts/controller/llamachat_hf.py @@ -8,7 +8,6 @@ import os import torch -import transformers from typing import List, Dict, Union from .abstract_language_model import AbstractLanguageModel @@ -48,6 +47,7 @@ class Llama2HF(AbstractLanguageModel): # Important: must be done before importing transformers os.environ["TRANSFORMERS_CACHE"] = self.config["cache_dir"] + import transformers hf_model_id = f"meta-llama/{self.model_id}" model_config = transformers.AutoConfig.from_pretrained(hf_model_id)