Fixes #17; Disallow automatic major version upgrade for dependencies and update OpenAI API

This commit is contained in:
Nils Blach 2023-11-20 14:40:30 +01:00 committed by Nils Blach
parent 8f1e6ce81d
commit ac4a35ea9f
3 changed files with 34 additions and 33 deletions

View File

@ -80,12 +80,12 @@ class AbstractLanguageModel(ABC):
pass
@abstractmethod
def get_response_texts(self, query_responses: Union[List[Dict], Dict]) -> List[str]:
def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
"""
Abstract method to extract response texts from the language model's response(s).
:param query_responses: The responses returned from the language model.
:type query_responses: Union[List[Dict], Dict]
:type query_responses: Union[List[Any], Any]
:return: List of textual responses.
:rtype: List[str]
"""

View File

@ -7,11 +7,12 @@
# main author: Nils Blach
import backoff
import openai
import os
import random
import time
from typing import List, Dict, Union
from openai import OpenAI, OpenAIError
from openai.types.chat.chat_completion import ChatCompletion
from .abstract_language_model import AbstractLanguageModel
@ -53,15 +54,15 @@ class ChatGPT(AbstractLanguageModel):
self.organization: str = self.config["organization"]
if self.organization == "":
self.logger.warning("OPENAI_ORGANIZATION is not set")
else:
openai.organization = self.organization
# The api key is the api key that is used for chatgpt. Env variable OPENAI_API_KEY takes precedence over config.
self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
if self.api_key == "":
raise ValueError("OPENAI_API_KEY is not set")
openai.api_key = self.api_key
# Initialize the OpenAI Client
self.client = OpenAI(api_key=self.api_key, organization=self.organization)
def query(self, query: str, num_responses: int = 1) -> Dict:
def query(
self, query: str, num_responses: int = 1
) -> Union[List[ChatCompletion], ChatCompletion]:
"""
Query the OpenAI model for responses.
@ -100,10 +101,8 @@ class ChatGPT(AbstractLanguageModel):
self.respone_cache[query] = response
return response
@backoff.on_exception(
backoff.expo, openai.error.OpenAIError, max_time=10, max_tries=6
)
def chat(self, messages: List[Dict], num_responses: int = 1) -> Dict:
@backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
"""
Send chat messages to the OpenAI model and retrieves the model's response.
Implements backoff on OpenAI error.
@ -113,9 +112,9 @@ class ChatGPT(AbstractLanguageModel):
:param num_responses: Number of desired responses, default is 1.
:type num_responses: int
:return: The OpenAI model's response.
:rtype: Dict
:rtype: ChatCompletion
"""
response = openai.ChatCompletion.create(
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
temperature=self.temperature,
@ -124,8 +123,8 @@ class ChatGPT(AbstractLanguageModel):
stop=self.stop,
)
self.prompt_tokens += response["usage"]["prompt_tokens"]
self.completion_tokens += response["usage"]["completion_tokens"]
self.prompt_tokens += response.usage.prompt_tokens
self.completion_tokens += response.usage.completion_tokens
prompt_tokens_k = float(self.prompt_tokens) / 1000.0
completion_tokens_k = float(self.completion_tokens) / 1000.0
self.cost = (
@ -138,19 +137,21 @@ class ChatGPT(AbstractLanguageModel):
)
return response
def get_response_texts(self, query_response: Union[List[Dict], Dict]) -> List[str]:
def get_response_texts(
self, query_response: Union[List[ChatCompletion], ChatCompletion]
) -> List[str]:
"""
Extract the response texts from the query response.
:param query_response: The response dictionary (or list of dictionaries) from the OpenAI model.
:type query_response: Union[List[Dict], Dict]
:type query_response: Union[List[ChatCompletion], ChatCompletion]
:return: List of response strings.
:rtype: List[str]
"""
if isinstance(query_response, Dict):
if not isinstance(query_response, List):
query_response = [query_response]
return [
choice["message"]["content"]
choice.message.content
for response in query_response
for choice in response["choices"]
for choice in response.choices
]

View File

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "graph_of_thoughts"
version = "0.0.2"
version = "0.0.3"
authors = [
{ name="Maciej Besta", email="maciej.besta@inf.ethz.ch" },
{ name="Nils Blach", email="nils.blach@inf.ethz.ch" },
@ -20,17 +20,17 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"backoff>=2.2.1",
"openai>=0.27.7",
"matplotlib>=3.7.1",
"numpy>=1.24.3",
"pandas>=2.0.3",
"sympy>=1.12",
"torch>=2.0.1",
"transformers>=4.31.0",
"accelerate>=0.21.0",
"bitsandbytes>=0.41.0",
"scipy>=1.10.1",
"backoff>=2.2.1,<3.0.0",
"openai>=1.0.0,<2.0.0",
"matplotlib>=3.7.1,<4.0.0",
"numpy>=1.24.3,<2.0.0",
"pandas>=2.0.3,<3.0.0",
"sympy>=1.12,<2.0",
"torch>=2.0.1,<3.0.0",
"transformers>=4.31.0,<5.0.0",
"accelerate>=0.21.0,<1.0.0",
"bitsandbytes>=0.41.0,<1.0.0",
"scipy>=1.10.1,<2.0.0",
]
[project.urls]