2024-10-11 10:14:03 +02:00

120 lines
4.6 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Ales Kubicek
import os
import torch
from typing import List, Dict, Union
from .abstract_language_model import AbstractLanguageModel
class Llama2HF(AbstractLanguageModel):
"""
An interface to use LLaMA 2 models through the HuggingFace library.
"""
def __init__(
self, config_path: str = "", model_name: str = "llama7b-hf", cache: bool = False
) -> None:
"""
Initialize an instance of the Llama2HF class with configuration, model details, and caching options.
:param config_path: Path to the configuration file. Defaults to an empty string.
:type config_path: str
:param model_name: Specifies the name of the LLaMA model variant. Defaults to "llama7b-hf".
Used to select the correct configuration.
:type model_name: str
:param cache: Flag to determine whether to cache responses. Defaults to False.
:type cache: bool
"""
super().__init__(config_path, model_name, cache)
self.config: Dict = self.config[model_name]
# Detailed id of the used model.
self.model_id: str = self.config["model_id"]
# Costs for 1000 tokens.
self.prompt_token_cost: float = self.config["prompt_token_cost"]
self.response_token_cost: float = self.config["response_token_cost"]
# The temperature is defined as the randomness of the model's output.
self.temperature: float = self.config["temperature"]
# Top K sampling.
self.top_k: int = self.config["top_k"]
# The maximum number of tokens to generate in the chat completion.
self.max_tokens: int = self.config["max_tokens"]
# Important: must be done before importing transformers
os.environ["TRANSFORMERS_CACHE"] = self.config["cache_dir"]
import transformers
hf_model_id = f"meta-llama/{self.model_id}"
model_config = transformers.AutoConfig.from_pretrained(hf_model_id)
bnb_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_id)
self.model = transformers.AutoModelForCausalLM.from_pretrained(
hf_model_id,
trust_remote_code=True,
config=model_config,
quantization_config=bnb_config,
device_map="auto",
)
self.model.eval()
torch.no_grad()
self.generate_text = transformers.pipeline(
model=self.model, tokenizer=self.tokenizer, task="text-generation"
)
def query(self, query: str, num_responses: int = 1) -> List[Dict]:
"""
Query the LLaMA 2 model for responses.
:param query: The query to be posed to the language model.
:type query: str
:param num_responses: Number of desired responses, default is 1.
:type num_responses: int
:return: Response(s) from the LLaMA 2 model.
:rtype: List[Dict]
"""
if self.cache and query in self.response_cache:
return self.response_cache[query]
sequences = []
query = f"<s><<SYS>>You are a helpful assistant. Always follow the intstructions precisely and output the response exactly in the requested format.<</SYS>>\n\n[INST] {query} [/INST]"
for _ in range(num_responses):
sequences.extend(
self.generate_text(
query,
do_sample=True,
top_k=self.top_k,
num_return_sequences=1,
eos_token_id=self.tokenizer.eos_token_id,
max_length=self.max_tokens,
)
)
response = [
{"generated_text": sequence["generated_text"][len(query) :].strip()}
for sequence in sequences
]
if self.cache:
self.response_cache[query] = response
return response
def get_response_texts(self, query_responses: List[Dict]) -> List[str]:
"""
Extract the response texts from the query response.
:param query_responses: The response list of dictionaries generated from the `query` method.
:type query_responses: List[Dict]
:return: List of response strings.
:rtype: List[str]
"""
return [query_response["generated_text"] for query_response in query_responses]