Fix: import the transformers library only after the environment variables are set

This commit is contained in:
Nils Blach 2023-10-18 10:34:27 +09:00 committed by Nils Blach
parent b41d653f21
commit 8590e4bd65

View File

@ -8,7 +8,6 @@
import os
import torch
import transformers
from typing import List, Dict, Union
from .abstract_language_model import AbstractLanguageModel
@ -48,6 +47,7 @@ class Llama2HF(AbstractLanguageModel):
# Important: must be done before importing transformers
os.environ["TRANSFORMERS_CACHE"] = self.config["cache_dir"]
import transformers
hf_model_id = f"meta-llama/{self.model_id}"
model_config = transformers.AutoConfig.from_pretrained(hf_model_id)