Fixes #17; Disallow automatic major version upgrade for dependencies and update OpenAI API
This commit is contained in:
parent
8f1e6ce81d
commit
ac4a35ea9f
@ -80,12 +80,12 @@ class AbstractLanguageModel(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_texts(self, query_responses: Union[List[Dict], Dict]) -> List[str]:
|
||||
def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
|
||||
"""
|
||||
Abstract method to extract response texts from the language model's response(s).
|
||||
|
||||
:param query_responses: The responses returned from the language model.
|
||||
:type query_responses: Union[List[Dict], Dict]
|
||||
:type query_responses: Union[List[Any], Any]
|
||||
:return: List of textual responses.
|
||||
:rtype: List[str]
|
||||
"""
|
||||
|
||||
@ -7,11 +7,12 @@
|
||||
# main author: Nils Blach
|
||||
|
||||
import backoff
|
||||
import openai
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List, Dict, Union
|
||||
from openai import OpenAI, OpenAIError
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from .abstract_language_model import AbstractLanguageModel
|
||||
|
||||
@ -53,15 +54,15 @@ class ChatGPT(AbstractLanguageModel):
|
||||
self.organization: str = self.config["organization"]
|
||||
if self.organization == "":
|
||||
self.logger.warning("OPENAI_ORGANIZATION is not set")
|
||||
else:
|
||||
openai.organization = self.organization
|
||||
# The api key is the api key that is used for chatgpt. Env variable OPENAI_API_KEY takes precedence over config.
|
||||
self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
|
||||
if self.api_key == "":
|
||||
raise ValueError("OPENAI_API_KEY is not set")
|
||||
openai.api_key = self.api_key
|
||||
# Initialize the OpenAI Client
|
||||
self.client = OpenAI(api_key=self.api_key, organization=self.organization)
|
||||
|
||||
def query(self, query: str, num_responses: int = 1) -> Dict:
|
||||
def query(
|
||||
self, query: str, num_responses: int = 1
|
||||
) -> Union[List[ChatCompletion], ChatCompletion]:
|
||||
"""
|
||||
Query the OpenAI model for responses.
|
||||
|
||||
@ -100,10 +101,8 @@ class ChatGPT(AbstractLanguageModel):
|
||||
self.respone_cache[query] = response
|
||||
return response
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, openai.error.OpenAIError, max_time=10, max_tries=6
|
||||
)
|
||||
def chat(self, messages: List[Dict], num_responses: int = 1) -> Dict:
|
||||
@backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
|
||||
def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
|
||||
"""
|
||||
Send chat messages to the OpenAI model and retrieves the model's response.
|
||||
Implements backoff on OpenAI error.
|
||||
@ -113,9 +112,9 @@ class ChatGPT(AbstractLanguageModel):
|
||||
:param num_responses: Number of desired responses, default is 1.
|
||||
:type num_responses: int
|
||||
:return: The OpenAI model's response.
|
||||
:rtype: Dict
|
||||
:rtype: ChatCompletion
|
||||
"""
|
||||
response = openai.ChatCompletion.create(
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
temperature=self.temperature,
|
||||
@ -124,8 +123,8 @@ class ChatGPT(AbstractLanguageModel):
|
||||
stop=self.stop,
|
||||
)
|
||||
|
||||
self.prompt_tokens += response["usage"]["prompt_tokens"]
|
||||
self.completion_tokens += response["usage"]["completion_tokens"]
|
||||
self.prompt_tokens += response.usage.prompt_tokens
|
||||
self.completion_tokens += response.usage.completion_tokens
|
||||
prompt_tokens_k = float(self.prompt_tokens) / 1000.0
|
||||
completion_tokens_k = float(self.completion_tokens) / 1000.0
|
||||
self.cost = (
|
||||
@ -138,19 +137,21 @@ class ChatGPT(AbstractLanguageModel):
|
||||
)
|
||||
return response
|
||||
|
||||
def get_response_texts(self, query_response: Union[List[Dict], Dict]) -> List[str]:
|
||||
def get_response_texts(
|
||||
self, query_response: Union[List[ChatCompletion], ChatCompletion]
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extract the response texts from the query response.
|
||||
|
||||
:param query_response: The response dictionary (or list of dictionaries) from the OpenAI model.
|
||||
:type query_response: Union[List[Dict], Dict]
|
||||
:type query_response: Union[List[ChatCompletion], ChatCompletion]
|
||||
:return: List of response strings.
|
||||
:rtype: List[str]
|
||||
"""
|
||||
if isinstance(query_response, Dict):
|
||||
if not isinstance(query_response, List):
|
||||
query_response = [query_response]
|
||||
return [
|
||||
choice["message"]["content"]
|
||||
choice.message.content
|
||||
for response in query_response
|
||||
for choice in response["choices"]
|
||||
for choice in response.choices
|
||||
]
|
||||
|
||||
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "graph_of_thoughts"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
authors = [
|
||||
{ name="Maciej Besta", email="maciej.besta@inf.ethz.ch" },
|
||||
{ name="Nils Blach", email="nils.blach@inf.ethz.ch" },
|
||||
@ -20,17 +20,17 @@ classifiers = [
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
dependencies = [
|
||||
"backoff>=2.2.1",
|
||||
"openai>=0.27.7",
|
||||
"matplotlib>=3.7.1",
|
||||
"numpy>=1.24.3",
|
||||
"pandas>=2.0.3",
|
||||
"sympy>=1.12",
|
||||
"torch>=2.0.1",
|
||||
"transformers>=4.31.0",
|
||||
"accelerate>=0.21.0",
|
||||
"bitsandbytes>=0.41.0",
|
||||
"scipy>=1.10.1",
|
||||
"backoff>=2.2.1,<3.0.0",
|
||||
"openai>=1.0.0,<2.0.0",
|
||||
"matplotlib>=3.7.1,<4.0.0",
|
||||
"numpy>=1.24.3,<2.0.0",
|
||||
"pandas>=2.0.3,<3.0.0",
|
||||
"sympy>=1.12,<2.0",
|
||||
"torch>=2.0.1,<3.0.0",
|
||||
"transformers>=4.31.0,<5.0.0",
|
||||
"accelerate>=0.21.0,<1.0.0",
|
||||
"bitsandbytes>=0.41.0,<1.0.0",
|
||||
"scipy>=1.10.1,<2.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user