Fix: import the transformers library only after setting the environment variables
This commit is contained in:
parent
b41d653f21
commit
8590e4bd65
@ -8,7 +8,6 @@
|
||||
|
||||
import os
|
||||
import torch
|
||||
import transformers
|
||||
from typing import List, Dict, Union
|
||||
from .abstract_language_model import AbstractLanguageModel
|
||||
|
||||
@ -48,6 +47,7 @@ class Llama2HF(AbstractLanguageModel):
|
||||
|
||||
# Important: must be done before importing transformers
|
||||
os.environ["TRANSFORMERS_CACHE"] = self.config["cache_dir"]
|
||||
import transformers
|
||||
|
||||
hf_model_id = f"meta-llama/{self.model_id}"
|
||||
model_config = transformers.AutoConfig.from_pretrained(hf_model_id)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user