fix typo: respone_cache -> response_cache (#33)

This commit is contained in:
Robert Gerstenberger 2024-10-11 10:14:03 +02:00 committed by GitHub
parent 15fb8e661d
commit a939a4577c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 8 deletions

View File

@@ -36,7 +36,7 @@ class AbstractLanguageModel(ABC):
         self.model_name: str = model_name
         self.cache = cache
         if self.cache:
-            self.respone_cache: Dict[str, List[Any]] = {}
+            self.response_cache: Dict[str, List[Any]] = {}
         self.load_config(config_path)
         self.prompt_tokens: int = 0
         self.completion_tokens: int = 0
@@ -63,7 +63,7 @@ class AbstractLanguageModel(ABC):
         """
         Clear the response cache.
         """
-        self.respone_cache.clear()
+        self.response_cache.clear()

     @abstractmethod
     def query(self, query: str, num_responses: int = 1) -> Any:

View File

@@ -73,8 +73,8 @@ class ChatGPT(AbstractLanguageModel):
         :return: Response(s) from the OpenAI model.
         :rtype: Dict
         """
-        if self.cache and query in self.respone_cache:
-            return self.respone_cache[query]
+        if self.cache and query in self.response_cache:
+            return self.response_cache[query]
         if num_responses == 1:
             response = self.chat([{"role": "user", "content": query}], num_responses)
@@ -98,7 +98,7 @@ class ChatGPT(AbstractLanguageModel):
             total_num_attempts -= 1
         if self.cache:
-            self.respone_cache[query] = response
+            self.response_cache[query] = response
         return response

     @backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)

View File

@@ -84,8 +84,8 @@ class Llama2HF(AbstractLanguageModel):
         :return: Response(s) from the LLaMA 2 model.
         :rtype: List[Dict]
         """
-        if self.cache and query in self.respone_cache:
-            return self.respone_cache[query]
+        if self.cache and query in self.response_cache:
+            return self.response_cache[query]
         sequences = []
         query = f"<s><<SYS>>You are a helpful assistant. Always follow the intstructions precisely and output the response exactly in the requested format.<</SYS>>\n\n[INST] {query} [/INST]"
         for _ in range(num_responses):
@@ -104,7 +104,7 @@ class Llama2HF(AbstractLanguageModel):
             for sequence in sequences
         ]
         if self.cache:
-            self.respone_cache[query] = response
+            self.response_cache[query] = response
         return response

     def get_response_texts(self, query_responses: List[Dict]) -> List[str]: