add API and Openrouter
This commit is contained in:
parent
363421c61c
commit
4baaaa438c
@ -31,6 +31,15 @@ pip install -e .
|
||||
In order to use the framework, you need to have access to an LLM.
|
||||
Please follow the instructions in the [Controller README](graph_of_thoughts/controller/README.md) to configure the LLM of your choice.
|
||||
|
||||
### OpenRouter and OpenAI-compatible HTTP API
|
||||
|
||||
1. Install API extras: `pip install "graph_of_thoughts[api]"` (or `pip install -e ".[api]"` from a source checkout).
|
||||
2. Copy [`graph_of_thoughts/language_models/config.openrouter.example.yaml`](graph_of_thoughts/language_models/config.openrouter.example.yaml) to `config.openrouter.yaml`, add your [OpenRouter](https://openrouter.ai/) keys and model ids, and either place it in `graph_of_thoughts/language_models/` or point `OPENROUTER_CONFIG` at your file.
|
||||
3. Run the server: `got-openrouter-api` (or `python -m graph_of_thoughts.api`).
|
||||
4. Call `POST /v1/chat/completions` with a standard OpenAI-style JSON body (`messages`, optional `model`, `temperature`, `max_tokens`). The server runs a small Graph of Operations (generate multiple candidates, score, keep the best) via OpenRouter.
|
||||
|
||||
Details: [Language models README](graph_of_thoughts/language_models/README.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
The following code snippet shows how to use the framework to solve the sorting problem for a list of 32 numbers using a CoT-like approach.
|
||||
|
||||
1
graph_of_thoughts/api/__init__.py
Normal file
1
graph_of_thoughts/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""HTTP API helpers (FastAPI) for running Graph of Thoughts with OpenRouter."""
|
||||
4
graph_of_thoughts/api/__main__.py
Normal file
4
graph_of_thoughts/api/__main__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Entry point for ``python -m graph_of_thoughts.api``: delegates to app.run(),
# which starts the uvicorn server for the OpenAI-compatible HTTP API.
from graph_of_thoughts.api.app import run


if __name__ == "__main__":
    run()
|
||||
192
graph_of_thoughts/api/app.py
Normal file
192
graph_of_thoughts/api/app.py
Normal file
@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2023 ETH Zurich.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license that can be
|
||||
# found in the LICENSE file.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from graph_of_thoughts.api.got_openai_pipeline import (
|
||||
ChatCompletionParser,
|
||||
ChatCompletionPrompter,
|
||||
build_default_chat_graph,
|
||||
extract_assistant_text,
|
||||
format_chat_messages,
|
||||
)
|
||||
from graph_of_thoughts.controller import Controller
|
||||
from graph_of_thoughts.language_models.openrouter import (
|
||||
OpenRouter,
|
||||
OpenRouterBadRequestError,
|
||||
OpenRouterRateLimitError,
|
||||
)
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"FastAPI and Pydantic are required for the HTTP API. "
|
||||
'Install with: pip install "graph_of_thoughts[api]"'
|
||||
) from e
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """A single chat message in the OpenAI chat-completions format."""

    # Speaker role; typical values are "system", "user", or "assistant".
    role: str
    # Message text. Only plain string content is supported here.
    content: str
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions`` (OpenAI-compatible subset)."""

    # Optional model id forwarded to OpenRouter; when omitted, the backend
    # picks one of its configured models per request.
    model: Optional[str] = None
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # Streaming is not supported; stream=true is rejected with HTTP 400.
    stream: Optional[bool] = False
    # Only one choice is supported: ge=1, le=1 constrains n to exactly 1.
    n: Optional[int] = Field(default=1, ge=1, le=1)
|
||||
|
||||
|
||||
def _get_config_path() -> str:
|
||||
return os.environ.get(
|
||||
"OPENROUTER_CONFIG",
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"language_models",
|
||||
"config.openrouter.yaml",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _run_controller(lm: OpenRouter, user_text: str) -> str:
    """Run the generate/score/keep-best graph once and return the winning answer."""
    operations_graph = build_default_chat_graph(num_candidates=3)
    controller = Controller(
        lm,
        operations_graph,
        ChatCompletionPrompter(),
        ChatCompletionParser(),
        {"input": user_text},
    )
    controller.run()
    final_thoughts = controller.get_final_thoughts()
    return extract_assistant_text(final_thoughts)
|
||||
|
||||
|
||||
# FastAPI application exposing an OpenAI-compatible surface (/v1/models and
# /v1/chat/completions) backed by a Graph of Operations run over OpenRouter.
app = FastAPI(
    title="Graph of Thoughts (OpenRouter)",
    version="0.1.0",
    description="OpenAI-compatible chat completions backed by Graph of Operations + OpenRouter.",
)
|
||||
|
||||
|
||||
@app.on_event("startup")
def _startup() -> None:
    # Configure root logging once at server startup; the level is taken from
    # the LOG_LEVEL environment variable and defaults to INFO.
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
|
||||
|
||||
|
||||
@app.get("/v1/models")
def list_models() -> Dict[str, Any]:
    """List the configured OpenRouter model ids in OpenAI ``/v1/models`` form."""
    config_file = _get_config_path()
    if not os.path.isfile(config_file):
        # No config present yet: report an empty list rather than erroring.
        return {"object": "list", "data": []}
    from graph_of_thoughts.language_models.openrouter import load_openrouter_config

    config = load_openrouter_config(config_file)
    model_ids = config.get("models") or []
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    entries = []
    for model_id in model_ids:
        entries.append(
            {
                "id": model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openrouter",
            }
        )
    return {"object": "list", "data": entries}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
def chat_completions(body: ChatCompletionRequest) -> JSONResponse:
    """Handle an OpenAI-style chat completion request.

    Flattens the message list into a single prompt, runs the Graph of
    Operations pipeline against OpenRouter, and returns an OpenAI-shaped
    ``chat.completion`` payload. Raises HTTP 400 for unsupported options or
    upstream bad requests, 429 for upstream rate limits, and 500 when the
    OpenRouter config file is missing.
    """
    # Unsupported OpenAI options are rejected up front.
    if body.stream:
        raise HTTPException(
            status_code=400,
            detail="stream=true is not supported; use stream=false.",
        )
    if body.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported.")

    path = _get_config_path()
    if not os.path.isfile(path):
        raise HTTPException(
            status_code=500,
            detail=f"OpenRouter config not found at {path}. Set OPENROUTER_CONFIG.",
        )

    # A fresh model instance per request, so token usage below is per-request.
    lm = OpenRouter(config_path=path)
    user_text = format_chat_messages(
        [{"role": m.role, "content": m.content} for m in body.messages]
    )
    try:
        # Install per-request model/temperature/max_tokens overrides; the
        # inner try/finally guarantees they are cleared even on failure.
        lm.set_request_overrides(
            model=body.model,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
        try:
            answer = _run_controller(lm, user_text)
        finally:
            lm.clear_request_overrides()
    except OpenRouterRateLimitError as e:
        raise HTTPException(status_code=429, detail=str(e)) from e
    except OpenRouterBadRequestError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Best-effort model id for the response: explicit request value first,
    # then the ids the backend actually observed, then a configured fallback.
    model_id = (
        body.model
        or lm.generation_model_id
        or lm.last_model_id
        or (lm.models[0] if lm.models else "openrouter")
    )
    resp_id = f"chatcmpl-{uuid.uuid4().hex}"
    now = int(time.time())
    payload = {
        "id": resp_id,
        "object": "chat.completion",
        "created": now,
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": answer},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": lm.prompt_tokens,
            "completion_tokens": lm.completion_tokens,
            "total_tokens": lm.prompt_tokens + lm.completion_tokens,
        },
    }
    return JSONResponse(content=payload)
|
||||
|
||||
|
||||
def run() -> None:
    """Start the uvicorn server (entry point of the ``got-openrouter-api`` script).

    Honors the HOST, PORT, and RELOAD environment variables.
    """
    import uvicorn

    bind_host = os.environ.get("HOST", "0.0.0.0")
    bind_port = int(os.environ.get("PORT", "8000"))
    auto_reload = os.environ.get("RELOAD", "").lower() in ("1", "true", "yes")
    uvicorn.run(
        "graph_of_thoughts.api.app:app",
        host=bind_host,
        port=bind_port,
        reload=auto_reload,
    )


if __name__ == "__main__":
    run()
|
||||
123
graph_of_thoughts/api/got_openai_pipeline.py
Normal file
123
graph_of_thoughts/api/got_openai_pipeline.py
Normal file
@ -0,0 +1,123 @@
|
||||
# Copyright (c) 2023 ETH Zurich.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license that can be
|
||||
# found in the LICENSE file.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from graph_of_thoughts.operations import GraphOfOperations, operations
|
||||
from graph_of_thoughts.operations.thought import Thought
|
||||
from graph_of_thoughts.parser import Parser
|
||||
from graph_of_thoughts.prompter import Prompter
|
||||
|
||||
|
||||
def format_chat_messages(messages: List[Dict[str, str]]) -> str:
    """Flatten an OpenAI-style message list into one prompt string.

    Each message is rendered as ``ROLE:\\ncontent`` (role upper-cased) and
    messages are separated by a blank line. A missing role defaults to
    ``user``, missing content to ``""``, and non-string content is coerced
    with ``str``.
    """

    def render(message: Dict[str, str]) -> str:
        role = message.get("role", "user")
        content = message.get("content", "")
        if not isinstance(content, str):
            content = str(content)
        return f"{role.upper()}:\n{content}"

    return "\n\n".join(render(message) for message in messages)
|
||||
|
||||
|
||||
class ChatCompletionPrompter(Prompter):
    """Prompter for a small generate → score → keep-best Graph of Operations."""

    def generate_prompt(self, num_branches: int, **kwargs: Any) -> str:
        """Build the candidate-generation prompt.

        ``kwargs["input"]`` carries the flattened chat transcript produced by
        ``format_chat_messages``; ``num_branches`` is unused because the
        pipeline samples multiple responses from a single prompt.
        """
        problem = kwargs.get("input", "")
        return (
            "You are a careful assistant. Read the conversation below and produce "
            "one candidate answer for the USER's latest needs.\n\n"
            f"{problem}\n\n"
            "Reply with your answer only, no preamble."
        )

    def score_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        """Build a prompt that grades all candidates in one call.

        The model is asked to reply with a bare JSON array of numbers so that
        ``ChatCompletionParser.parse_score_answer`` can recover one float per
        candidate.
        """
        lines = [
            "You evaluate candidate answers for the same problem. "
            "Score each candidate from 0 (worst) to 10 (best) on correctness, "
            "completeness, and relevance.",
            "",
            "Return ONLY a JSON array of numbers, one score per candidate in order, e.g. [7, 5, 9].",
            "",
        ]
        for i, st in enumerate(state_dicts):
            cand = st.get("candidate", "")
            lines.append(f"Candidate {i}:\n{cand}\n")
        return "\n".join(lines)

    # The remaining Prompter hooks are not part of the generate/score/keep-best
    # graph and fail loudly if an operation ever calls them.
    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        raise RuntimeError("aggregation_prompt is not used by the chat completion pipeline")

    def improve_prompt(self, **kwargs: Any) -> str:
        raise RuntimeError("improve_prompt is not used by the chat completion pipeline")

    def validation_prompt(self, **kwargs: Any) -> str:
        raise RuntimeError("validation_prompt is not used by the chat completion pipeline")
|
||||
|
||||
|
||||
class ChatCompletionParser(Parser):
    """Parser side of the generate → score → keep-best chat pipeline."""

    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """Turn each raw completion text into a candidate state dict."""
        return [
            {"candidate": (text or "").strip(), "branch_index": index}
            for index, text in enumerate(texts)
        ]

    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """Recover one numeric score per candidate, padding with 0.0 if short."""
        raw = texts[0] if texts else ""
        scores = self._scores_from_text(raw, len(states))
        shortfall = len(states) - len(scores)
        if shortfall > 0:
            scores = scores + [0.0] * shortfall
        return scores[: len(states)]

    def _scores_from_text(self, raw: str, n: int) -> List[float]:
        """Parse scores from *raw*: a JSON array first, then a number-regex fallback."""
        raw = raw.strip()
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            try:
                return [float(x) for x in parsed]
            except (ValueError, TypeError):
                pass
        nums = re.findall(r"-?\d+(?:\.\d+)?", raw)
        return [float(x) for x in nums[:n]]

    # The remaining Parser hooks are not reachable from the chat graph and
    # fail loudly if an operation ever calls them.
    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        raise RuntimeError("parse_aggregation_answer is not used")

    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        raise RuntimeError("parse_improve_answer is not used")

    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        raise RuntimeError("parse_validation_answer is not used")
|
||||
|
||||
|
||||
def build_default_chat_graph(num_candidates: int = 3) -> GraphOfOperations:
    """Build the default generate → score → keep-best Graph of Operations.

    ``num_candidates`` answers are sampled from a single prompt, scored
    jointly in one call (``combined_scoring=True``), and only the single
    best-scoring candidate is kept.
    """
    graph = GraphOfOperations()
    steps = (
        operations.Generate(
            num_branches_prompt=1, num_branches_response=num_candidates
        ),
        operations.Score(combined_scoring=True),
        operations.KeepBestN(1),
    )
    for step in steps:
        graph.append_operation(step)
    return graph
|
||||
|
||||
|
||||
def extract_assistant_text(final_thoughts_list: List[List[Thought]]) -> str:
    """Pull the answer text out of ``Controller.get_final_thoughts`` output.

    The controller returns one thought list per leaf operation; the winning
    candidate sits in the first thought of the first leaf. Returns ``""``
    when there are no leaves, no thoughts, or no candidate in the state.
    """
    first_leaf = final_thoughts_list[0] if final_thoughts_list else []
    if not first_leaf:
        return ""
    winning_state = first_leaf[0].state or {}
    return str(winning_state.get("candidate", ""))
|
||||
@ -4,6 +4,7 @@ The Language Models module is responsible for managing the large language models
|
||||
|
||||
Currently, the framework supports the following LLMs:
|
||||
- GPT-4 / GPT-3.5 (Remote - OpenAI API)
|
||||
- OpenRouter (Remote - [OpenRouter](https://openrouter.ai/) OpenAI-compatible API, multi-key / multi-model rotation)
|
||||
- LLaMA-2 (Local - HuggingFace Transformers)
|
||||
|
||||
The following sections describe how to instantiate individual LLMs and how to add new LLMs to the framework.
|
||||
@ -28,12 +29,26 @@ The following sections describe how to instantiate individual LLMs and how to ad
|
||||
|
||||
- Instantiate the language model based on the selected configuration key (predefined / custom).
|
||||
```python
|
||||
lm = controller.ChatGPT(
|
||||
"path/to/config.json",
|
||||
from graph_of_thoughts.language_models import ChatGPT
|
||||
|
||||
lm = ChatGPT(
|
||||
"path/to/config.json",
|
||||
model_name=<configuration key>
|
||||
)
|
||||
```
|
||||
|
||||
### OpenRouter
|
||||
- Copy `config.openrouter.example.yaml` (or `.json`) to `config.openrouter.yaml` next to this module, or pass an explicit path.
|
||||
- Set `api_keys` (list) and `models` (list). Each request picks a **random** key and a **random** model (uniform over the lists). If the HTTP API passes a `model` field, that model id is used for that request instead of a random one.
|
||||
- Optional: `http_referer` and `x_title` for OpenRouter attribution headers (see [OpenRouter docs](https://openrouter.ai/docs)).
|
||||
- HTTP **429** responses trigger exponential backoff and further rotation; **400** responses are retried a limited number of times with a new key/model pair, then surfaced as an error.
|
||||
|
||||
```python
|
||||
from graph_of_thoughts.language_models import OpenRouter
|
||||
|
||||
lm = OpenRouter("/path/to/config.openrouter.yaml")
|
||||
```
|
||||
|
||||
### LLaMA-2
|
||||
- Requires local hardware to run inference and a HuggingFace account.
|
||||
- Adjust the predefined `llama7b-hf`, `llama13b-hf` or `llama70b-hf` configurations or create a new configuration with an unique key.
|
||||
@ -50,8 +65,10 @@ lm = controller.ChatGPT(
|
||||
|
||||
- Instantiate the language model based on the selected configuration key (predefined / custom).
|
||||
```python
|
||||
lm = controller.Llama2HF(
|
||||
"path/to/config.json",
|
||||
from graph_of_thoughts.language_models import Llama2HF
|
||||
|
||||
lm = Llama2HF(
|
||||
"path/to/config.json",
|
||||
model_name=<configuration key>
|
||||
)
|
||||
```
|
||||
|
||||
@ -1,3 +1,10 @@
|
||||
from .abstract_language_model import AbstractLanguageModel
|
||||
from .chatgpt import ChatGPT
|
||||
from .llamachat_hf import Llama2HF
|
||||
from .openrouter import (
|
||||
OpenRouter,
|
||||
OpenRouterBadRequestError,
|
||||
OpenRouterError,
|
||||
OpenRouterRateLimitError,
|
||||
load_openrouter_config,
|
||||
)
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
{
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_keys": [
|
||||
"sk-or-v1-replace-me-1",
|
||||
"sk-or-v1-replace-me-2"
|
||||
],
|
||||
"models": [
|
||||
"openai/gpt-4o-mini",
|
||||
"anthropic/claude-3.5-haiku"
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
"stop": null,
|
||||
"prompt_token_cost": 0.0,
|
||||
"response_token_cost": 0.0,
|
||||
"max_retries_429": 8,
|
||||
"max_retries_400": 3,
|
||||
"base_backoff_seconds": 1.0,
|
||||
"http_referer": "",
|
||||
"x_title": ""
|
||||
}
|
||||
@ -0,0 +1,29 @@
|
||||
# Copy to config.openrouter.yaml (or set path explicitly) and fill in keys.
|
||||
# Per chat request, an API key and model are chosen at random (uniform) from the lists.
|
||||
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
|
||||
api_keys:
|
||||
- sk-or-v1-replace-me-1
|
||||
- sk-or-v1-replace-me-2
|
||||
|
||||
models:
|
||||
- openai/gpt-4o-mini
|
||||
- anthropic/claude-3.5-haiku
|
||||
|
||||
temperature: 0.7
|
||||
max_tokens: 4096
|
||||
stop: null
|
||||
|
||||
# Optional cost accounting (set to 0 if unknown)
|
||||
prompt_token_cost: 0.0
|
||||
response_token_cost: 0.0
|
||||
|
||||
# Retries after HTTP 429 / 400 (each retry uses a fresh random key + model)
|
||||
max_retries_429: 8
|
||||
max_retries_400: 3
|
||||
base_backoff_seconds: 1.0
|
||||
|
||||
# Optional OpenRouter attribution headers (recommended by OpenRouter)
|
||||
http_referer: ""
|
||||
x_title: ""
|
||||
287
graph_of_thoughts/language_models/openrouter.py
Normal file
287
graph_of_thoughts/language_models/openrouter.py
Normal file
@ -0,0 +1,287 @@
|
||||
# Copyright (c) 2023 ETH Zurich.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license that can be
|
||||
# found in the LICENSE file.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from openai import APIStatusError, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from .abstract_language_model import AbstractLanguageModel
|
||||
|
||||
|
||||
# Exception hierarchy for the OpenRouter client. The two subclasses are the
# errors ``chat`` raises once its retry budgets are exhausted; callers in the
# HTTP API layer catch them to map upstream failures to HTTP 429 / 400.


class OpenRouterError(Exception):
    """Base error for OpenRouter integration."""


class OpenRouterBadRequestError(OpenRouterError):
    """Raised when OpenRouter returns HTTP 400 after retries."""


class OpenRouterRateLimitError(OpenRouterError):
    """Raised when OpenRouter returns HTTP 429 after retries."""
|
||||
|
||||
|
||||
def load_openrouter_config(path: str) -> Dict[str, Any]:
    """Load a YAML or JSON OpenRouter configuration file.

    Public wrapper around ``_load_config_file``; the format is selected by
    file extension (``.yaml``/``.yml`` → YAML, anything else → JSON).
    """
    return _load_config_file(path)
|
||||
|
||||
|
||||
def _load_config_file(path: str) -> Dict[str, Any]:
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
if ext in (".yaml", ".yml"):
|
||||
data = yaml.safe_load(f)
|
||||
else:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Config at {path} must be a JSON/YAML object")
|
||||
return data
|
||||
|
||||
|
||||
class OpenRouter(AbstractLanguageModel):
    """
    OpenRouter-backed language model with per-request rotation of API keys and models.

    Configuration is loaded from YAML or JSON (see ``config.openrouter.example.yaml``).
    Each attempt uses a random API key and — unless an override is set — a random
    model id from the configured lists. HTTP 429 responses back off exponentially;
    HTTP 400 responses are retried a bounded number of times with a freshly rotated
    key/model pair before being surfaced as typed errors.
    """

    def __init__(
        self,
        config_path: str = "",
        model_name: str = "openrouter",
        cache: bool = False,
    ) -> None:
        """
        Initialize the OpenRouter model.

        :param config_path: Path to the YAML/JSON config; "" selects
            ``config.openrouter.yaml`` next to this module (see ``load_config``).
        :param model_name: Label for this rotating configuration, not an
            OpenRouter model id.
        :param cache: Whether to cache query responses.
        """
        self._rotation_model_name = model_name
        # Per-request parameter overrides installed by the HTTP API layer.
        self._request_overrides: Dict[str, Any] = {}
        super().__init__(config_path, model_name, cache)
        self._apply_openrouter_config()

    def load_config(self, path: str) -> None:
        """Load the OpenRouter config file (overrides the base-class loader)."""
        if path == "":
            # Default to a config shipped alongside this module.
            path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "config.openrouter.yaml",
            )
        self.config_path = path
        self.config = _load_config_file(path)
        self.logger.debug("Loaded OpenRouter config from %s", path)

    def _apply_openrouter_config(self) -> None:
        """Validate the loaded config and copy its values onto the instance."""
        cfg = self.config
        self.base_url: str = cfg.get("base_url", "https://openrouter.ai/api/v1")
        keys = cfg.get("api_keys") or []
        if isinstance(keys, str):
            # Allow a single key given as a plain string.
            keys = [keys]
        self.api_keys: List[str] = [k for k in keys if k]
        if not self.api_keys:
            raise ValueError("OpenRouter config must define non-empty 'api_keys'")

        models = cfg.get("models") or []
        if isinstance(models, str):
            models = [models]
        self.models: List[str] = [m for m in models if m]
        if not self.models:
            raise ValueError("OpenRouter config must define non-empty 'models'")

        self.temperature: float = float(cfg.get("temperature", 1.0))
        self.max_tokens: int = int(cfg.get("max_tokens", 4096))
        self.stop: Union[str, List[str], None] = cfg.get("stop")
        # Optional per-1k-token cost accounting; 0.0 effectively disables it.
        self.prompt_token_cost: float = float(cfg.get("prompt_token_cost", 0.0))
        self.response_token_cost: float = float(cfg.get("response_token_cost", 0.0))

        self.max_retries_429: int = int(cfg.get("max_retries_429", 8))
        self.max_retries_400: int = int(cfg.get("max_retries_400", 3))
        self.base_backoff_seconds: float = float(cfg.get("base_backoff_seconds", 1.0))

        # Optional OpenRouter attribution headers; env vars act as fallbacks.
        self.http_referer: str = cfg.get("http_referer", "") or os.getenv(
            "OPENROUTER_HTTP_REFERER", ""
        )
        self.x_title: str = cfg.get("x_title", "") or os.getenv("OPENROUTER_X_TITLE", "")

        self.model_name = self._rotation_model_name
        # Model ids observed at runtime (reported back by the HTTP API).
        self.last_model_id: Optional[str] = None
        self.generation_model_id: Optional[str] = None

    def set_request_overrides(self, **kwargs: Any) -> None:
        """Optional per-request parameters (used by the HTTP API). Cleared with :meth:`clear_request_overrides`."""
        # None values mean "no override" and are dropped.
        self._request_overrides = {k: v for k, v in kwargs.items() if v is not None}

    def clear_request_overrides(self) -> None:
        """Remove any overrides installed by :meth:`set_request_overrides`."""
        self._request_overrides = {}

    def _pick_key(self) -> str:
        """Return a uniformly random API key from the configured list."""
        return random.choice(self.api_keys)

    def _pick_model(self, override: Optional[str]) -> str:
        """Model id for this attempt: explicit override, then request override, then random."""
        if override:
            return override
        o = self._request_overrides.get("model")
        if o:
            return str(o)
        return random.choice(self.models)

    def _effective_temperature(self) -> float:
        """Temperature for this request (per-request override wins over config)."""
        t = self._request_overrides.get("temperature")
        return float(t) if t is not None else self.temperature

    def _effective_max_tokens(self) -> int:
        """max_tokens for this request (per-request override wins over config)."""
        m = self._request_overrides.get("max_tokens")
        return int(m) if m is not None else self.max_tokens

    def _client_for_key(self, api_key: str) -> OpenAI:
        """Build an OpenAI SDK client pointed at OpenRouter using *api_key*."""
        headers: Dict[str, str] = {}
        if self.http_referer:
            headers["HTTP-Referer"] = self.http_referer
        if self.x_title:
            headers["X-Title"] = self.x_title
        return OpenAI(
            base_url=self.base_url,
            api_key=api_key,
            default_headers=headers or None,
        )

    def _sleep_backoff(self, attempt: int) -> None:
        """Sleep with jittered exponential backoff, capped at 60 seconds."""
        cap = 60.0
        delay = min(
            self.base_backoff_seconds * (2**attempt) + random.random(),
            cap,
        )
        self.logger.warning("Backing off %.2fs (attempt %d)", delay, attempt + 1)
        time.sleep(delay)

    def chat(
        self,
        messages: List[Dict[str, str]],
        num_responses: int = 1,
        model_override: Optional[str] = None,
    ) -> ChatCompletion:
        """
        Call OpenRouter chat completions with rotation and retries for 429/400.

        :param messages: OpenAI-style message dicts.
        :param num_responses: Number of choices to request (``n``).
        :param model_override: Model id to use instead of the random rotation.
        :raises OpenRouterRateLimitError: after exhausting the 429 retry budget.
        :raises OpenRouterBadRequestError: after exhausting the 400 retry budget.
        """
        attempts_429 = 0
        attempts_400 = 0
        attempt = 0  # total attempt counter; drives the backoff exponent

        while True:
            # Rotate: every attempt gets a fresh random key/model pair.
            api_key = self._pick_key()
            model_id = self._pick_model(model_override)
            client = self._client_for_key(api_key)
            try:
                response = client.chat.completions.create(
                    model=model_id,
                    messages=messages,
                    temperature=self._effective_temperature(),
                    max_tokens=self._effective_max_tokens(),
                    n=num_responses,
                    stop=self.stop,
                )
                if response.usage is not None:
                    # Accumulate usage and refresh the running cost estimate.
                    self.prompt_tokens += response.usage.prompt_tokens or 0
                    self.completion_tokens += response.usage.completion_tokens or 0
                    pt_k = float(self.prompt_tokens) / 1000.0
                    ct_k = float(self.completion_tokens) / 1000.0
                    self.cost = (
                        self.prompt_token_cost * pt_k
                        + self.response_token_cost * ct_k
                    )
                self.last_model_id = model_id
                if self.generation_model_id is None:
                    self.generation_model_id = model_id
                self.logger.info(
                    "OpenRouter response model=%s id=%s", model_id, response.id
                )
                return response
            except APIStatusError as e:
                # Note: a previously kept `last_exc` local was never read and
                # has been removed (dead store).
                code = e.status_code
                if code == 429:
                    if attempts_429 >= self.max_retries_429:
                        raise OpenRouterRateLimitError(
                            f"OpenRouter rate limited after {attempts_429} retries: {e.message}"
                        ) from e
                    attempts_429 += 1
                    self._sleep_backoff(attempt)
                    attempt += 1
                    continue
                if code == 400:
                    self.logger.warning(
                        "OpenRouter HTTP 400 (will retry with rotated key/model if allowed): %s body=%s",
                        e.message,
                        e.body,
                    )
                    if attempts_400 >= self.max_retries_400:
                        raise OpenRouterBadRequestError(
                            f"OpenRouter bad request after {attempts_400} retries: {e.message}"
                        ) from e
                    attempts_400 += 1
                    attempt += 1
                    time.sleep(random.uniform(0.2, 0.8))
                    continue
                # Any other HTTP status is not retried here.
                raise
            except Exception:
                self.logger.exception("Unexpected error calling OpenRouter")
                raise

    def query(
        self, query: str, num_responses: int = 1
    ) -> Union[List[ChatCompletion], ChatCompletion]:
        """
        Query the model, splitting ``num_responses`` across calls on failure.

        Returns a single ChatCompletion when ``num_responses == 1``, otherwise
        a list of completions (possibly shorter than requested if the retry
        budget runs out). Results are cached when caching is enabled.
        """
        if self.cache and query in self.response_cache:
            return self.response_cache[query]

        messages = [{"role": "user", "content": query}]
        model_ov = self._request_overrides.get("model")
        model_override = str(model_ov) if model_ov else None

        if num_responses == 1:
            response = self.chat(messages, 1, model_override=model_override)
        else:
            response = []
            next_try = num_responses
            total_num_attempts = num_responses
            remaining = num_responses
            while remaining > 0 and total_num_attempts > 0:
                try:
                    assert next_try > 0
                    res = self.chat(
                        messages, next_try, model_override=model_override
                    )
                    response.append(res)
                    remaining -= next_try
                    next_try = min(remaining, next_try)
                except Exception as e:
                    # Halve the batch size and spend one unit of retry budget.
                    next_try = max(1, (next_try + 1) // 2)
                    self.logger.warning(
                        "Error in OpenRouter query: %s, retrying with n=%s",
                        e,
                        next_try,
                    )
                    time.sleep(random.uniform(0.5, 2.0))
                    total_num_attempts -= 1

        if self.cache:
            self.response_cache[query] = response
        return response

    def get_response_texts(
        self, query_response: Union[List[ChatCompletion], ChatCompletion]
    ) -> List[str]:
        """Flatten one or more ChatCompletion objects into a list of choice texts."""
        if not isinstance(query_response, list):
            query_response = [query_response]
        texts: List[str] = []
        for response in query_response:
            for choice in response.choices:
                c = choice.message.content
                texts.append(c if c is not None else "")
        return texts
|
||||
@ -22,6 +22,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
"backoff>=2.2.1,<3.0.0",
|
||||
"openai>=1.0.0,<2.0.0",
|
||||
"pyyaml>=6.0.1,<7.0.0",
|
||||
"matplotlib>=3.7.1,<4.0.0",
|
||||
"numpy>=1.24.3,<2.0.0",
|
||||
"pandas>=2.0.3,<3.0.0",
|
||||
@ -33,7 +34,14 @@ dependencies = [
|
||||
"scipy>=1.10.1,<2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = [
|
||||
"fastapi>=0.109.0,<1.0.0",
|
||||
"uvicorn[standard]>=0.27.0,<1.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/spcl/graph-of-thoughts"
|
||||
|
||||
[project.scripts]
|
||||
got-openrouter-api = "graph_of_thoughts.api.app:run"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user