diff --git a/graph_of_thoughts/language_models/abstract_language_model.py b/graph_of_thoughts/language_models/abstract_language_model.py index a066eaf..cead63c 100644 --- a/graph_of_thoughts/language_models/abstract_language_model.py +++ b/graph_of_thoughts/language_models/abstract_language_model.py @@ -80,12 +80,12 @@ class AbstractLanguageModel(ABC): pass @abstractmethod - def get_response_texts(self, query_responses: Union[List[Dict], Dict]) -> List[str]: + def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]: """ Abstract method to extract response texts from the language model's response(s). :param query_responses: The responses returned from the language model. - :type query_responses: Union[List[Dict], Dict] + :type query_responses: Union[List[Any], Any] :return: List of textual responses. :rtype: List[str] """ diff --git a/graph_of_thoughts/language_models/chatgpt.py b/graph_of_thoughts/language_models/chatgpt.py index 52da92a..4f63d61 100644 --- a/graph_of_thoughts/language_models/chatgpt.py +++ b/graph_of_thoughts/language_models/chatgpt.py @@ -7,11 +7,12 @@ # main author: Nils Blach import backoff -import openai import os import random import time from typing import List, Dict, Union +from openai import OpenAI, OpenAIError +from openai.types.chat.chat_completion import ChatCompletion from .abstract_language_model import AbstractLanguageModel @@ -53,15 +54,15 @@ class ChatGPT(AbstractLanguageModel): self.organization: str = self.config["organization"] if self.organization == "": self.logger.warning("OPENAI_ORGANIZATION is not set") - else: - openai.organization = self.organization - # The api key is the api key that is used for chatgpt. Env variable OPENAI_API_KEY takes precedence over config. self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"]) if self.api_key == "": raise ValueError("OPENAI_API_KEY is not set") - openai.api_key = self.api_key + # Initialize the OpenAI Client + self.client = OpenAI(api_key=self.api_key, organization=self.organization) - def query(self, query: str, num_responses: int = 1) -> Dict: + def query( + self, query: str, num_responses: int = 1 + ) -> Union[List[ChatCompletion], ChatCompletion]: """ Query the OpenAI model for responses. @@ -100,10 +101,8 @@ class ChatGPT(AbstractLanguageModel): self.respone_cache[query] = response return response - @backoff.on_exception( - backoff.expo, openai.error.OpenAIError, max_time=10, max_tries=6 - ) - def chat(self, messages: List[Dict], num_responses: int = 1) -> Dict: + @backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6) + def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion: """ Send chat messages to the OpenAI model and retrieves the model's response. Implements backoff on OpenAI error. @@ -113,9 +112,9 @@ class ChatGPT(AbstractLanguageModel): :param num_responses: Number of desired responses, default is 1. :type num_responses: int :return: The OpenAI model's response. - :rtype: Dict + :rtype: ChatCompletion """ - response = openai.ChatCompletion.create( + response = self.client.chat.completions.create( model=self.model_id, messages=messages, temperature=self.temperature, @@ -124,8 +123,8 @@ class ChatGPT(AbstractLanguageModel): stop=self.stop, ) - self.prompt_tokens += response["usage"]["prompt_tokens"] - self.completion_tokens += response["usage"]["completion_tokens"] + self.prompt_tokens += response.usage.prompt_tokens + self.completion_tokens += response.usage.completion_tokens prompt_tokens_k = float(self.prompt_tokens) / 1000.0 completion_tokens_k = float(self.completion_tokens) / 1000.0 self.cost = ( @@ -138,19 +137,21 @@ class ChatGPT(AbstractLanguageModel): ) return response - def get_response_texts(self, query_response: Union[List[Dict], Dict]) -> List[str]: + def get_response_texts( + self, query_response: Union[List[ChatCompletion], ChatCompletion] + ) -> List[str]: """ Extract the response texts from the query response. :param query_response: The response dictionary (or list of dictionaries) from the OpenAI model. - :type query_response: Union[List[Dict], Dict] + :type query_response: Union[List[ChatCompletion], ChatCompletion] :return: List of response strings. :rtype: List[str] """ - if isinstance(query_response, Dict): + if not isinstance(query_response, List): query_response = [query_response] return [ - choice["message"]["content"] + choice.message.content for response in query_response - for choice in response["choices"] + for choice in response.choices ] diff --git a/pyproject.toml b/pyproject.toml index e41f145..ecbf97c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "graph_of_thoughts" -version = "0.0.2" +version = "0.0.3" authors = [ { name="Maciej Besta", email="maciej.besta@inf.ethz.ch" }, { name="Nils Blach", email="nils.blach@inf.ethz.ch" }, @@ -20,17 +20,17 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "backoff>=2.2.1", - "openai>=0.27.7", - "matplotlib>=3.7.1", - "numpy>=1.24.3", - "pandas>=2.0.3", - "sympy>=1.12", - "torch>=2.0.1", - "transformers>=4.31.0", - "accelerate>=0.21.0", - "bitsandbytes>=0.41.0", - "scipy>=1.10.1", + "backoff>=2.2.1,<3.0.0", + "openai>=1.0.0,<2.0.0", + "matplotlib>=3.7.1,<4.0.0", + "numpy>=1.24.3,<2.0.0", + "pandas>=2.0.3,<3.0.0", + "sympy>=1.12,<2.0", + "torch>=2.0.1,<3.0.0", + "transformers>=4.31.0,<5.0.0", + "accelerate>=0.21.0,<1.0.0", + "bitsandbytes>=0.41.0,<1.0.0", + "scipy>=1.10.1,<2.0.0", ] [project.urls]