93 lines
2.9 KiB
Python
93 lines
2.9 KiB
Python
# Copyright (c) 2023 ETH Zurich.
|
|
# All rights reserved.
|
|
#
|
|
# Use of this source code is governed by a BSD-style license that can be
|
|
# found in the LICENSE file.
|
|
#
|
|
# main author: Nils Blach
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Dict, Union, Any
|
|
import json
|
|
import os
|
|
import logging
|
|
|
|
|
|
class AbstractLanguageModel(ABC):
|
|
"""
|
|
Abstract base class that defines the interface for all language models.
|
|
"""
|
|
|
|
def __init__(
|
|
self, config_path: str = "", model_name: str = "", cache: bool = False
|
|
) -> None:
|
|
"""
|
|
Initialize the AbstractLanguageModel instance with configuration, model details, and caching options.
|
|
|
|
:param config_path: Path to the config file. Defaults to "".
|
|
:type config_path: str
|
|
:param model_name: Name of the language model. Defaults to "".
|
|
:type model_name: str
|
|
:param cache: Flag to determine whether to cache responses. Defaults to False.
|
|
:type cache: bool
|
|
"""
|
|
self.logger = logging.getLogger(self.__class__.__name__)
|
|
self.config: Dict = None
|
|
self.model_name: str = model_name
|
|
self.cache = cache
|
|
if self.cache:
|
|
self.response_cache: Dict[str, List[Any]] = {}
|
|
self.load_config(config_path)
|
|
self.prompt_tokens: int = 0
|
|
self.completion_tokens: int = 0
|
|
self.cost: float = 0.0
|
|
|
|
def load_config(self, path: str) -> None:
|
|
"""
|
|
Load configuration from a specified path.
|
|
|
|
:param path: Path to the config file. If an empty path provided,
|
|
default is `config.json` in the current directory.
|
|
:type path: str
|
|
"""
|
|
if path == "":
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
path = os.path.join(current_dir, "config.json")
|
|
|
|
with open(path, "r") as f:
|
|
self.config = json.load(f)
|
|
|
|
self.logger.debug(f"Loaded config from {path} for {self.model_name}")
|
|
|
|
def clear_cache(self) -> None:
|
|
"""
|
|
Clear the response cache.
|
|
"""
|
|
self.response_cache.clear()
|
|
|
|
@abstractmethod
|
|
def query(self, query: str, num_responses: int = 1) -> Any:
|
|
"""
|
|
Abstract method to query the language model.
|
|
|
|
:param query: The query to be posed to the language model.
|
|
:type query: str
|
|
:param num_responses: The number of desired responses.
|
|
:type num_responses: int
|
|
:return: The language model's response(s).
|
|
:rtype: Any
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
|
|
"""
|
|
Abstract method to extract response texts from the language model's response(s).
|
|
|
|
:param query_responses: The responses returned from the language model.
|
|
:type query_responses: Union[List[Any], Any]
|
|
:return: List of textual responses.
|
|
:rtype: List[str]
|
|
"""
|
|
pass
|