Fixes #17; Disallow automatic major version upgrade for dependencies and update OpenAI API
This commit is contained in:
parent
8f1e6ce81d
commit
ac4a35ea9f
@ -80,12 +80,12 @@ class AbstractLanguageModel(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_response_texts(self, query_responses: Union[List[Dict], Dict]) -> List[str]:
|
def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Abstract method to extract response texts from the language model's response(s).
|
Abstract method to extract response texts from the language model's response(s).
|
||||||
|
|
||||||
:param query_responses: The responses returned from the language model.
|
:param query_responses: The responses returned from the language model.
|
||||||
:type query_responses: Union[List[Dict], Dict]
|
:type query_responses: Union[List[Any], Any]
|
||||||
:return: List of textual responses.
|
:return: List of textual responses.
|
||||||
:rtype: List[str]
|
:rtype: List[str]
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -7,11 +7,12 @@
|
|||||||
# main author: Nils Blach
|
# main author: Nils Blach
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import openai
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Union
|
from typing import List, Dict, Union
|
||||||
|
from openai import OpenAI, OpenAIError
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
from .abstract_language_model import AbstractLanguageModel
|
from .abstract_language_model import AbstractLanguageModel
|
||||||
|
|
||||||
@ -53,15 +54,15 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
self.organization: str = self.config["organization"]
|
self.organization: str = self.config["organization"]
|
||||||
if self.organization == "":
|
if self.organization == "":
|
||||||
self.logger.warning("OPENAI_ORGANIZATION is not set")
|
self.logger.warning("OPENAI_ORGANIZATION is not set")
|
||||||
else:
|
|
||||||
openai.organization = self.organization
|
|
||||||
# The api key is the api key that is used for chatgpt. Env variable OPENAI_API_KEY takes precedence over config.
|
|
||||||
self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
|
self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
|
||||||
if self.api_key == "":
|
if self.api_key == "":
|
||||||
raise ValueError("OPENAI_API_KEY is not set")
|
raise ValueError("OPENAI_API_KEY is not set")
|
||||||
openai.api_key = self.api_key
|
# Initialize the OpenAI Client
|
||||||
|
self.client = OpenAI(api_key=self.api_key, organization=self.organization)
|
||||||
|
|
||||||
def query(self, query: str, num_responses: int = 1) -> Dict:
|
def query(
|
||||||
|
self, query: str, num_responses: int = 1
|
||||||
|
) -> Union[List[ChatCompletion], ChatCompletion]:
|
||||||
"""
|
"""
|
||||||
Query the OpenAI model for responses.
|
Query the OpenAI model for responses.
|
||||||
|
|
||||||
@ -100,10 +101,8 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
self.respone_cache[query] = response
|
self.respone_cache[query] = response
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
|
||||||
backoff.expo, openai.error.OpenAIError, max_time=10, max_tries=6
|
def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
|
||||||
)
|
|
||||||
def chat(self, messages: List[Dict], num_responses: int = 1) -> Dict:
|
|
||||||
"""
|
"""
|
||||||
Send chat messages to the OpenAI model and retrieves the model's response.
|
Send chat messages to the OpenAI model and retrieves the model's response.
|
||||||
Implements backoff on OpenAI error.
|
Implements backoff on OpenAI error.
|
||||||
@ -113,9 +112,9 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
:param num_responses: Number of desired responses, default is 1.
|
:param num_responses: Number of desired responses, default is 1.
|
||||||
:type num_responses: int
|
:type num_responses: int
|
||||||
:return: The OpenAI model's response.
|
:return: The OpenAI model's response.
|
||||||
:rtype: Dict
|
:rtype: ChatCompletion
|
||||||
"""
|
"""
|
||||||
response = openai.ChatCompletion.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_id,
|
model=self.model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
@ -124,8 +123,8 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
stop=self.stop,
|
stop=self.stop,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prompt_tokens += response["usage"]["prompt_tokens"]
|
self.prompt_tokens += response.usage.prompt_tokens
|
||||||
self.completion_tokens += response["usage"]["completion_tokens"]
|
self.completion_tokens += response.usage.completion_tokens
|
||||||
prompt_tokens_k = float(self.prompt_tokens) / 1000.0
|
prompt_tokens_k = float(self.prompt_tokens) / 1000.0
|
||||||
completion_tokens_k = float(self.completion_tokens) / 1000.0
|
completion_tokens_k = float(self.completion_tokens) / 1000.0
|
||||||
self.cost = (
|
self.cost = (
|
||||||
@ -138,19 +137,21 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_response_texts(self, query_response: Union[List[Dict], Dict]) -> List[str]:
|
def get_response_texts(
|
||||||
|
self, query_response: Union[List[ChatCompletion], ChatCompletion]
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Extract the response texts from the query response.
|
Extract the response texts from the query response.
|
||||||
|
|
||||||
:param query_response: The response dictionary (or list of dictionaries) from the OpenAI model.
|
:param query_response: The response dictionary (or list of dictionaries) from the OpenAI model.
|
||||||
:type query_response: Union[List[Dict], Dict]
|
:type query_response: Union[List[ChatCompletion], ChatCompletion]
|
||||||
:return: List of response strings.
|
:return: List of response strings.
|
||||||
:rtype: List[str]
|
:rtype: List[str]
|
||||||
"""
|
"""
|
||||||
if isinstance(query_response, Dict):
|
if not isinstance(query_response, List):
|
||||||
query_response = [query_response]
|
query_response = [query_response]
|
||||||
return [
|
return [
|
||||||
choice["message"]["content"]
|
choice.message.content
|
||||||
for response in query_response
|
for response in query_response
|
||||||
for choice in response["choices"]
|
for choice in response.choices
|
||||||
]
|
]
|
||||||
|
|||||||
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "graph_of_thoughts"
|
name = "graph_of_thoughts"
|
||||||
version = "0.0.2"
|
version = "0.0.3"
|
||||||
authors = [
|
authors = [
|
||||||
{ name="Maciej Besta", email="maciej.besta@inf.ethz.ch" },
|
{ name="Maciej Besta", email="maciej.besta@inf.ethz.ch" },
|
||||||
{ name="Nils Blach", email="nils.blach@inf.ethz.ch" },
|
{ name="Nils Blach", email="nils.blach@inf.ethz.ch" },
|
||||||
@ -20,17 +20,17 @@ classifiers = [
|
|||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"backoff>=2.2.1",
|
"backoff>=2.2.1,<3.0.0",
|
||||||
"openai>=0.27.7",
|
"openai>=1.0.0,<2.0.0",
|
||||||
"matplotlib>=3.7.1",
|
"matplotlib>=3.7.1,<4.0.0",
|
||||||
"numpy>=1.24.3",
|
"numpy>=1.24.3,<2.0.0",
|
||||||
"pandas>=2.0.3",
|
"pandas>=2.0.3,<3.0.0",
|
||||||
"sympy>=1.12",
|
"sympy>=1.12,<2.0",
|
||||||
"torch>=2.0.1",
|
"torch>=2.0.1,<3.0.0",
|
||||||
"transformers>=4.31.0",
|
"transformers>=4.31.0,<5.0.0",
|
||||||
"accelerate>=0.21.0",
|
"accelerate>=0.21.0,<1.0.0",
|
||||||
"bitsandbytes>=0.41.0",
|
"bitsandbytes>=0.41.0,<1.0.0",
|
||||||
"scipy>=1.10.1",
|
"scipy>=1.10.1,<2.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user