fix typo: respone_cache -> response_cache (#33)
This commit is contained in:
parent
15fb8e661d
commit
a939a4577c
@ -36,7 +36,7 @@ class AbstractLanguageModel(ABC):
|
|||||||
self.model_name: str = model_name
|
self.model_name: str = model_name
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.respone_cache: Dict[str, List[Any]] = {}
|
self.response_cache: Dict[str, List[Any]] = {}
|
||||||
self.load_config(config_path)
|
self.load_config(config_path)
|
||||||
self.prompt_tokens: int = 0
|
self.prompt_tokens: int = 0
|
||||||
self.completion_tokens: int = 0
|
self.completion_tokens: int = 0
|
||||||
@ -63,7 +63,7 @@ class AbstractLanguageModel(ABC):
|
|||||||
"""
|
"""
|
||||||
Clear the response cache.
|
Clear the response cache.
|
||||||
"""
|
"""
|
||||||
self.respone_cache.clear()
|
self.response_cache.clear()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def query(self, query: str, num_responses: int = 1) -> Any:
|
def query(self, query: str, num_responses: int = 1) -> Any:
|
||||||
|
|||||||
@ -73,8 +73,8 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
:return: Response(s) from the OpenAI model.
|
:return: Response(s) from the OpenAI model.
|
||||||
:rtype: Dict
|
:rtype: Dict
|
||||||
"""
|
"""
|
||||||
if self.cache and query in self.respone_cache:
|
if self.cache and query in self.response_cache:
|
||||||
return self.respone_cache[query]
|
return self.response_cache[query]
|
||||||
|
|
||||||
if num_responses == 1:
|
if num_responses == 1:
|
||||||
response = self.chat([{"role": "user", "content": query}], num_responses)
|
response = self.chat([{"role": "user", "content": query}], num_responses)
|
||||||
@ -98,7 +98,7 @@ class ChatGPT(AbstractLanguageModel):
|
|||||||
total_num_attempts -= 1
|
total_num_attempts -= 1
|
||||||
|
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.respone_cache[query] = response
|
self.response_cache[query] = response
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
|
@backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
|
||||||
|
|||||||
@ -84,8 +84,8 @@ class Llama2HF(AbstractLanguageModel):
|
|||||||
:return: Response(s) from the LLaMA 2 model.
|
:return: Response(s) from the LLaMA 2 model.
|
||||||
:rtype: List[Dict]
|
:rtype: List[Dict]
|
||||||
"""
|
"""
|
||||||
if self.cache and query in self.respone_cache:
|
if self.cache and query in self.response_cache:
|
||||||
return self.respone_cache[query]
|
return self.response_cache[query]
|
||||||
sequences = []
|
sequences = []
|
||||||
query = f"<s><<SYS>>You are a helpful assistant. Always follow the intstructions precisely and output the response exactly in the requested format.<</SYS>>\n\n[INST] {query} [/INST]"
|
query = f"<s><<SYS>>You are a helpful assistant. Always follow the intstructions precisely and output the response exactly in the requested format.<</SYS>>\n\n[INST] {query} [/INST]"
|
||||||
for _ in range(num_responses):
|
for _ in range(num_responses):
|
||||||
@ -104,7 +104,7 @@ class Llama2HF(AbstractLanguageModel):
|
|||||||
for sequence in sequences
|
for sequence in sequences
|
||||||
]
|
]
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.respone_cache[query] = response
|
self.response_cache[query] = response
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_response_texts(self, query_responses: List[Dict]) -> List[str]:
|
def get_response_texts(self, query_responses: List[Dict]) -> List[str]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user