# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach

import backoff
import os
import random
import time
from typing import List, Dict, Union
from openai import OpenAI, OpenAIError
from openai.types.chat.chat_completion import ChatCompletion

from .abstract_language_model import AbstractLanguageModel

class ChatGPT(AbstractLanguageModel):
    """
    The ChatGPT class handles interactions with the OpenAI models using the provided configuration.

    Inherits from the AbstractLanguageModel and implements its abstract methods.
    """

    def __init__(
        self, config_path: str = "", model_name: str = "chatgpt", cache: bool = False
    ) -> None:
        """
        Initialize the ChatGPT instance with configuration, model details, and caching options.

        :param config_path: Path to the configuration file. Defaults to "".
        :type config_path: str
        :param model_name: Name of the model, default is 'chatgpt'. Used to select the correct configuration.
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        :raises ValueError: If no API key is found in the environment or the configuration.
        """
        super().__init__(config_path, model_name, cache)
        # Narrow the full configuration down to this model's section.
        self.config: Dict = self.config[model_name]
        # The model_id is the id of the model that is used for chatgpt, i.e. gpt-4, gpt-3.5-turbo, etc.
        self.model_id: str = self.config["model_id"]
        # The prompt_token_cost and response_token_cost are the costs for 1000 prompt tokens and 1000 response tokens respectively.
        self.prompt_token_cost: float = self.config["prompt_token_cost"]
        self.response_token_cost: float = self.config["response_token_cost"]
        # The temperature of a model is defined as the randomness of the model's output.
        self.temperature: float = self.config["temperature"]
        # The maximum number of tokens to generate in the chat completion.
        self.max_tokens: int = self.config["max_tokens"]
        # The stop sequence is a sequence of tokens that the model will stop generating at (it will not generate the stop sequence).
        self.stop: Union[str, List[str]] = self.config["stop"]
        # The account organization is the organization that is used for chatgpt.
        # Resolve it from the environment first (the warning below refers to the
        # OPENAI_ORGANIZATION variable), falling back to the config value — the
        # same precedence used for the API key below.
        self.organization: str = os.getenv(
            "OPENAI_ORGANIZATION", self.config["organization"]
        )
        if self.organization == "":
            self.logger.warning("OPENAI_ORGANIZATION is not set")
        # The environment variable takes precedence over the config value.
        self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
        if self.api_key == "":
            raise ValueError("OPENAI_API_KEY is not set")
        # Initialize the OpenAI Client
        self.client = OpenAI(api_key=self.api_key, organization=self.organization)

    def query(
        self, query: str, num_responses: int = 1
    ) -> Union[List[ChatCompletion], ChatCompletion]:
        """
        Query the OpenAI model for responses.

        For a single response the raw ChatCompletion is returned; for multiple
        responses a list of ChatCompletion objects is returned, each possibly
        carrying several choices.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :return: Response(s) from the OpenAI model.
        :rtype: Union[List[ChatCompletion], ChatCompletion]
        """
        # NOTE(review): "respone_cache" spelling is inherited from
        # AbstractLanguageModel and must match the base class attribute.
        if self.cache and query in self.respone_cache:
            return self.respone_cache[query]

        if num_responses == 1:
            response = self.chat([{"role": "user", "content": query}], num_responses)
        else:
            response = []
            next_try = num_responses
            total_num_attempts = num_responses
            # Request batches of samples, halving the batch size after each
            # failure, until all responses are gathered or attempts run out.
            while num_responses > 0 and total_num_attempts > 0:
                try:
                    # next_try is always >= 1: (n + 1) // 2 never reaches 0 for n >= 1.
                    assert next_try > 0
                    res = self.chat([{"role": "user", "content": query}], next_try)
                    response.append(res)
                    num_responses -= next_try
                    next_try = min(num_responses, next_try)
                except Exception as e:
                    next_try = (next_try + 1) // 2
                    self.logger.warning(
                        f"Error in chatgpt: {e}, trying again with {next_try} samples"
                    )
                    time.sleep(random.randint(1, 3))
                    total_num_attempts -= 1
            if num_responses > 0:
                # All attempts exhausted; surface the shortfall instead of
                # silently returning (and caching) a partial result.
                self.logger.warning(
                    "Could not gather all requested responses; "
                    f"{num_responses} responses are missing"
                )

        if self.cache:
            self.respone_cache[query] = response
        return response

    @backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
    def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
        """
        Send chat messages to the OpenAI model and retrieves the model's response.
        Implements backoff on OpenAI error.

        :param messages: A list of message dictionaries for the chat.
        :type messages: List[Dict]
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :return: The OpenAI model's response.
        :rtype: ChatCompletion
        """
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            n=num_responses,
            stop=self.stop,
        )

        # Track cumulative token usage and derive the running cost from the
        # per-1000-token prices configured for this model.
        self.prompt_tokens += response.usage.prompt_tokens
        self.completion_tokens += response.usage.completion_tokens
        prompt_tokens_k = float(self.prompt_tokens) / 1000.0
        completion_tokens_k = float(self.completion_tokens) / 1000.0
        self.cost = (
            self.prompt_token_cost * prompt_tokens_k
            + self.response_token_cost * completion_tokens_k
        )
        self.logger.info(
            f"This is the response from chatgpt: {response}"
            f"\nThis is the cost of the response: {self.cost}"
        )
        return response

    def get_response_texts(
        self, query_response: Union[List[ChatCompletion], ChatCompletion]
    ) -> List[str]:
        """
        Extract the response texts from the query response.

        :param query_response: The response dictionary (or list of dictionaries) from the OpenAI model.
        :type query_response: Union[List[ChatCompletion], ChatCompletion]
        :return: List of response strings.
        :rtype: List[str]
        """
        # Normalize the single-completion case so both shapes returned by
        # query() flatten into one list of choice texts.
        if not isinstance(query_response, List):
            query_response = [query_response]
        return [
            choice.message.content
            for response in query_response
            for choice in response.choices
        ]
|