# Copyright (c) 2023 ETH Zurich.
#                    All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach

from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union, Any
import json
import os
import logging


class AbstractLanguageModel(ABC):
    """
    Abstract base class that defines the interface for all language models.

    Concrete subclasses must implement `query` and `get_response_texts`.
    """

    def __init__(
        self, config_path: str = "", model_name: str = "", cache: bool = False
    ) -> None:
        """
        Initialize the AbstractLanguageModel instance with configuration, model details, and caching options.

        :param config_path: Path to the config file. Defaults to "".
        :type config_path: str
        :param model_name: Name of the language model. Defaults to "".
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        :raises OSError: If the config file cannot be opened.
        :raises json.JSONDecodeError: If the config file is not valid JSON.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        # Stays None only until load_config() below succeeds.
        self.config: Optional[Dict[str, Any]] = None
        self.model_name: str = model_name
        self.cache = cache
        # Always create the cache dict, even when caching is disabled, so that
        # clear_cache() never hits a missing attribute. The attribute name keeps
        # its historical spelling ("respone") because code outside this file
        # may reference it.
        self.respone_cache: Dict[str, List[Any]] = {}
        self.load_config(config_path)
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.cost: float = 0.0

    def load_config(self, path: str) -> None:
        """
        Load configuration from a specified path and store it in `self.config`.

        :param path: Path to the config file. If an empty path provided,
                     default is `config.json` in the directory of this module.
        :type path: str
        :raises OSError: If the file cannot be opened.
        :raises json.JSONDecodeError: If the file content is not valid JSON.
        """
        if path == "":
            current_dir = os.path.dirname(os.path.abspath(__file__))
            path = os.path.join(current_dir, "config.json")

        # Explicit UTF-8 so the config loads identically regardless of the
        # platform's default locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

        # Lazy %-style arguments: the message is only formatted when DEBUG
        # logging is actually enabled.
        self.logger.debug("Loaded config from %s for %s", path, self.model_name)

    def clear_cache(self) -> None:
        """
        Clear the response cache.

        This is a no-op when no cache dict exists on the instance (for
        example when a subclass did not run this class's `__init__`).
        """
        response_cache = getattr(self, "respone_cache", None)
        if response_cache is not None:
            response_cache.clear()

    @abstractmethod
    def query(self, query: str, num_responses: int = 1) -> Any:
        """
        Abstract method to query the language model.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: The number of desired responses.
        :type num_responses: int
        :return: The language model's response(s).
        :rtype: Any
        """
        pass

    @abstractmethod
    def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
        """
        Abstract method to extract response texts from the language model's response(s).

        :param query_responses: The responses returned from the language model.
        :type query_responses: Union[List[Any], Any]
        :return: List of textual responses.
        :rtype: List[str]
        """
        pass