Fixes #17; Disallow automatic major version upgrade for dependencies and update OpenAI API

This commit is contained in:
Nils Blach 2023-11-20 14:40:30 +01:00 committed by Nils Blach
parent 8f1e6ce81d
commit ac4a35ea9f
3 changed files with 34 additions and 33 deletions

View File

@@ -80,12 +80,12 @@ class AbstractLanguageModel(ABC):
pass pass
@abstractmethod @abstractmethod
def get_response_texts(self, query_responses: Union[List[Dict], Dict]) -> List[str]: def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
""" """
Abstract method to extract response texts from the language model's response(s). Abstract method to extract response texts from the language model's response(s).
:param query_responses: The responses returned from the language model. :param query_responses: The responses returned from the language model.
:type query_responses: Union[List[Dict], Dict] :type query_responses: Union[List[Any], Any]
:return: List of textual responses. :return: List of textual responses.
:rtype: List[str] :rtype: List[str]
""" """

View File

@@ -7,11 +7,12 @@
# main author: Nils Blach # main author: Nils Blach
import backoff import backoff
import openai
import os import os
import random import random
import time import time
from typing import List, Dict, Union from typing import List, Dict, Union
from openai import OpenAI, OpenAIError
from openai.types.chat.chat_completion import ChatCompletion
from .abstract_language_model import AbstractLanguageModel from .abstract_language_model import AbstractLanguageModel
@@ -53,15 +54,15 @@ class ChatGPT(AbstractLanguageModel):
self.organization: str = self.config["organization"] self.organization: str = self.config["organization"]
if self.organization == "": if self.organization == "":
self.logger.warning("OPENAI_ORGANIZATION is not set") self.logger.warning("OPENAI_ORGANIZATION is not set")
else:
openai.organization = self.organization
# The api key is the api key that is used for chatgpt. Env variable OPENAI_API_KEY takes precedence over config.
self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"]) self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
if self.api_key == "": if self.api_key == "":
raise ValueError("OPENAI_API_KEY is not set") raise ValueError("OPENAI_API_KEY is not set")
openai.api_key = self.api_key # Initialize the OpenAI Client
self.client = OpenAI(api_key=self.api_key, organization=self.organization)
def query(self, query: str, num_responses: int = 1) -> Dict: def query(
self, query: str, num_responses: int = 1
) -> Union[List[ChatCompletion], ChatCompletion]:
""" """
Query the OpenAI model for responses. Query the OpenAI model for responses.
@@ -100,10 +101,8 @@ class ChatGPT(AbstractLanguageModel):
self.respone_cache[query] = response self.respone_cache[query] = response
return response return response
@backoff.on_exception( @backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
backoff.expo, openai.error.OpenAIError, max_time=10, max_tries=6 def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
)
def chat(self, messages: List[Dict], num_responses: int = 1) -> Dict:
""" """
Send chat messages to the OpenAI model and retrieves the model's response. Send chat messages to the OpenAI model and retrieves the model's response.
Implements backoff on OpenAI error. Implements backoff on OpenAI error.
@@ -113,9 +112,9 @@ class ChatGPT(AbstractLanguageModel):
:param num_responses: Number of desired responses, default is 1. :param num_responses: Number of desired responses, default is 1.
:type num_responses: int :type num_responses: int
:return: The OpenAI model's response. :return: The OpenAI model's response.
:rtype: Dict :rtype: ChatCompletion
""" """
response = openai.ChatCompletion.create( response = self.client.chat.completions.create(
model=self.model_id, model=self.model_id,
messages=messages, messages=messages,
temperature=self.temperature, temperature=self.temperature,
@@ -124,8 +123,8 @@ class ChatGPT(AbstractLanguageModel):
stop=self.stop, stop=self.stop,
) )
self.prompt_tokens += response["usage"]["prompt_tokens"] self.prompt_tokens += response.usage.prompt_tokens
self.completion_tokens += response["usage"]["completion_tokens"] self.completion_tokens += response.usage.completion_tokens
prompt_tokens_k = float(self.prompt_tokens) / 1000.0 prompt_tokens_k = float(self.prompt_tokens) / 1000.0
completion_tokens_k = float(self.completion_tokens) / 1000.0 completion_tokens_k = float(self.completion_tokens) / 1000.0
self.cost = ( self.cost = (
@@ -138,19 +137,21 @@
) )
return response return response
def get_response_texts(self, query_response: Union[List[Dict], Dict]) -> List[str]: def get_response_texts(
self, query_response: Union[List[ChatCompletion], ChatCompletion]
) -> List[str]:
""" """
Extract the response texts from the query response. Extract the response texts from the query response.
:param query_response: The response dictionary (or list of dictionaries) from the OpenAI model. :param query_response: The response dictionary (or list of dictionaries) from the OpenAI model.
:type query_response: Union[List[Dict], Dict] :type query_response: Union[List[ChatCompletion], ChatCompletion]
:return: List of response strings. :return: List of response strings.
:rtype: List[str] :rtype: List[str]
""" """
if isinstance(query_response, Dict): if not isinstance(query_response, List):
query_response = [query_response] query_response = [query_response]
return [ return [
choice["message"]["content"] choice.message.content
for response in query_response for response in query_response
for choice in response["choices"] for choice in response.choices
] ]

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "graph_of_thoughts" name = "graph_of_thoughts"
version = "0.0.2" version = "0.0.3"
authors = [ authors = [
{ name="Maciej Besta", email="maciej.besta@inf.ethz.ch" }, { name="Maciej Besta", email="maciej.besta@inf.ethz.ch" },
{ name="Nils Blach", email="nils.blach@inf.ethz.ch" }, { name="Nils Blach", email="nils.blach@inf.ethz.ch" },
@@ -20,17 +20,17 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"backoff>=2.2.1", "backoff>=2.2.1,<3.0.0",
"openai>=0.27.7", "openai>=1.0.0,<2.0.0",
"matplotlib>=3.7.1", "matplotlib>=3.7.1,<4.0.0",
"numpy>=1.24.3", "numpy>=1.24.3,<2.0.0",
"pandas>=2.0.3", "pandas>=2.0.3,<3.0.0",
"sympy>=1.12", "sympy>=1.12,<2.0",
"torch>=2.0.1", "torch>=2.0.1,<3.0.0",
"transformers>=4.31.0", "transformers>=4.31.0,<5.0.0",
"accelerate>=0.21.0", "accelerate>=0.21.0,<1.0.0",
"bitsandbytes>=0.41.0", "bitsandbytes>=0.41.0,<1.0.0",
"scipy>=1.10.1", "scipy>=1.10.1,<2.0.0",
] ]
[project.urls] [project.urls]