add API and Openrouter
This commit is contained in:
parent
363421c61c
commit
4baaaa438c
@ -31,6 +31,15 @@ pip install -e .
|
|||||||
In order to use the framework, you need to have access to an LLM.
|
In order to use the framework, you need to have access to an LLM.
|
||||||
Please follow the instructions in the [Controller README](graph_of_thoughts/controller/README.md) to configure the LLM of your choice.
|
Please follow the instructions in the [Controller README](graph_of_thoughts/controller/README.md) to configure the LLM of your choice.
|
||||||
|
|
||||||
|
### OpenRouter and OpenAI-compatible HTTP API
|
||||||
|
|
||||||
|
1. Install API extras: `pip install "graph_of_thoughts[api]"` (or `pip install -e ".[api]"` from a source checkout).
|
||||||
|
2. Copy [`graph_of_thoughts/language_models/config.openrouter.example.yaml`](graph_of_thoughts/language_models/config.openrouter.example.yaml) to `config.openrouter.yaml`, add your [OpenRouter](https://openrouter.ai/) keys and model ids, and either place it in `graph_of_thoughts/language_models/` or point `OPENROUTER_CONFIG` at your file.
|
||||||
|
3. Run the server: `got-openrouter-api` (or `python -m graph_of_thoughts.api`).
|
||||||
|
4. Call `POST /v1/chat/completions` with a standard OpenAI-style JSON body (`messages`, optional `model`, `temperature`, `max_tokens`). The server runs a small Graph of Operations (generate multiple candidates, score, keep the best) via OpenRouter.
|
||||||
|
|
||||||
|
Details: [Language models README](graph_of_thoughts/language_models/README.md).
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
The following code snippet shows how to use the framework to solve the sorting problem for a list of 32 numbers using a CoT-like approach.
|
The following code snippet shows how to use the framework to solve the sorting problem for a list of 32 numbers using a CoT-like approach.
|
||||||
|
|||||||
1
graph_of_thoughts/api/__init__.py
Normal file
1
graph_of_thoughts/api/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""HTTP API helpers (FastAPI) for running Graph of Thoughts with OpenRouter."""
|
||||||
4
graph_of_thoughts/api/__main__.py
Normal file
4
graph_of_thoughts/api/__main__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from graph_of_thoughts.api.app import run
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
192
graph_of_thoughts/api/app.py
Normal file
192
graph_of_thoughts/api/app.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
# Copyright (c) 2023 ETH Zurich.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Use of this source code is governed by a BSD-style license that can be
|
||||||
|
# found in the LICENSE file.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from graph_of_thoughts.api.got_openai_pipeline import (
|
||||||
|
ChatCompletionParser,
|
||||||
|
ChatCompletionPrompter,
|
||||||
|
build_default_chat_graph,
|
||||||
|
extract_assistant_text,
|
||||||
|
format_chat_messages,
|
||||||
|
)
|
||||||
|
from graph_of_thoughts.controller import Controller
|
||||||
|
from graph_of_thoughts.language_models.openrouter import (
|
||||||
|
OpenRouter,
|
||||||
|
OpenRouterBadRequestError,
|
||||||
|
OpenRouterRateLimitError,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"FastAPI and Pydantic are required for the HTTP API. "
|
||||||
|
'Install with: pip install "graph_of_thoughts[api]"'
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
model: Optional[str] = None
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
n: Optional[int] = Field(default=1, ge=1, le=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config_path() -> str:
|
||||||
|
return os.environ.get(
|
||||||
|
"OPENROUTER_CONFIG",
|
||||||
|
os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||||
|
"language_models",
|
||||||
|
"config.openrouter.yaml",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_controller(lm: OpenRouter, user_text: str) -> str:
|
||||||
|
graph = build_default_chat_graph(num_candidates=3)
|
||||||
|
ctrl = Controller(
|
||||||
|
lm,
|
||||||
|
graph,
|
||||||
|
ChatCompletionPrompter(),
|
||||||
|
ChatCompletionParser(),
|
||||||
|
{"input": user_text},
|
||||||
|
)
|
||||||
|
ctrl.run()
|
||||||
|
return extract_assistant_text(ctrl.get_final_thoughts())
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Graph of Thoughts (OpenRouter)",
|
||||||
|
version="0.1.0",
|
||||||
|
description="OpenAI-compatible chat completions backed by Graph of Operations + OpenRouter.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def _startup() -> None:
|
||||||
|
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
def list_models() -> Dict[str, Any]:
|
||||||
|
path = _get_config_path()
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
return {"object": "list", "data": []}
|
||||||
|
from graph_of_thoughts.language_models.openrouter import load_openrouter_config
|
||||||
|
|
||||||
|
cfg = load_openrouter_config(path)
|
||||||
|
models = cfg.get("models") or []
|
||||||
|
if isinstance(models, str):
|
||||||
|
models = [models]
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"id": m,
|
||||||
|
"object": "model",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"owned_by": "openrouter",
|
||||||
|
}
|
||||||
|
for m in models
|
||||||
|
]
|
||||||
|
return {"object": "list", "data": data}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
def chat_completions(body: ChatCompletionRequest) -> JSONResponse:
|
||||||
|
if body.stream:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="stream=true is not supported; use stream=false.",
|
||||||
|
)
|
||||||
|
if body.n != 1:
|
||||||
|
raise HTTPException(status_code=400, detail="Only n=1 is supported.")
|
||||||
|
|
||||||
|
path = _get_config_path()
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"OpenRouter config not found at {path}. Set OPENROUTER_CONFIG.",
|
||||||
|
)
|
||||||
|
|
||||||
|
lm = OpenRouter(config_path=path)
|
||||||
|
user_text = format_chat_messages(
|
||||||
|
[{"role": m.role, "content": m.content} for m in body.messages]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
lm.set_request_overrides(
|
||||||
|
model=body.model,
|
||||||
|
temperature=body.temperature,
|
||||||
|
max_tokens=body.max_tokens,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
answer = _run_controller(lm, user_text)
|
||||||
|
finally:
|
||||||
|
lm.clear_request_overrides()
|
||||||
|
except OpenRouterRateLimitError as e:
|
||||||
|
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||||
|
except OpenRouterBadRequestError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||||
|
|
||||||
|
model_id = (
|
||||||
|
body.model
|
||||||
|
or lm.generation_model_id
|
||||||
|
or lm.last_model_id
|
||||||
|
or (lm.models[0] if lm.models else "openrouter")
|
||||||
|
)
|
||||||
|
resp_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"id": resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": now,
|
||||||
|
"model": model_id,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": answer},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": lm.prompt_tokens,
|
||||||
|
"completion_tokens": lm.completion_tokens,
|
||||||
|
"total_tokens": lm.prompt_tokens + lm.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return JSONResponse(content=payload)
|
||||||
|
|
||||||
|
|
||||||
|
def run() -> None:
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
host = os.environ.get("HOST", "0.0.0.0")
|
||||||
|
port = int(os.environ.get("PORT", "8000"))
|
||||||
|
uvicorn.run(
|
||||||
|
"graph_of_thoughts.api.app:app",
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
reload=os.environ.get("RELOAD", "").lower() in ("1", "true", "yes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
123
graph_of_thoughts/api/got_openai_pipeline.py
Normal file
123
graph_of_thoughts/api/got_openai_pipeline.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
# Copyright (c) 2023 ETH Zurich.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Use of this source code is governed by a BSD-style license that can be
|
||||||
|
# found in the LICENSE file.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from graph_of_thoughts.operations import GraphOfOperations, operations
|
||||||
|
from graph_of_thoughts.operations.thought import Thought
|
||||||
|
from graph_of_thoughts.parser import Parser
|
||||||
|
from graph_of_thoughts.prompter import Prompter
|
||||||
|
|
||||||
|
|
||||||
|
def format_chat_messages(messages: List[Dict[str, str]]) -> str:
|
||||||
|
parts: List[str] = []
|
||||||
|
for m in messages:
|
||||||
|
role = m.get("role", "user")
|
||||||
|
content = m.get("content", "")
|
||||||
|
if not isinstance(content, str):
|
||||||
|
content = str(content)
|
||||||
|
parts.append(f"{role.upper()}:\n{content}")
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionPrompter(Prompter):
|
||||||
|
"""Prompter for a small generate → score → keep-best Graph of Operations."""
|
||||||
|
|
||||||
|
def generate_prompt(self, num_branches: int, **kwargs: Any) -> str:
|
||||||
|
problem = kwargs.get("input", "")
|
||||||
|
return (
|
||||||
|
"You are a careful assistant. Read the conversation below and produce "
|
||||||
|
"one candidate answer for the USER's latest needs.\n\n"
|
||||||
|
f"{problem}\n\n"
|
||||||
|
"Reply with your answer only, no preamble."
|
||||||
|
)
|
||||||
|
|
||||||
|
def score_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
|
||||||
|
lines = [
|
||||||
|
"You evaluate candidate answers for the same problem. "
|
||||||
|
"Score each candidate from 0 (worst) to 10 (best) on correctness, "
|
||||||
|
"completeness, and relevance.",
|
||||||
|
"",
|
||||||
|
"Return ONLY a JSON array of numbers, one score per candidate in order, e.g. [7, 5, 9].",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
for i, st in enumerate(state_dicts):
|
||||||
|
cand = st.get("candidate", "")
|
||||||
|
lines.append(f"Candidate {i}:\n{cand}\n")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def aggregation_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
|
||||||
|
raise RuntimeError("aggregation_prompt is not used by the chat completion pipeline")
|
||||||
|
|
||||||
|
def improve_prompt(self, **kwargs: Any) -> str:
|
||||||
|
raise RuntimeError("improve_prompt is not used by the chat completion pipeline")
|
||||||
|
|
||||||
|
def validation_prompt(self, **kwargs: Any) -> str:
|
||||||
|
raise RuntimeError("validation_prompt is not used by the chat completion pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionParser(Parser):
|
||||||
|
def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
|
||||||
|
out: List[Dict] = []
|
||||||
|
for i, t in enumerate(texts):
|
||||||
|
out.append({"candidate": (t or "").strip(), "branch_index": i})
|
||||||
|
return out
|
||||||
|
|
||||||
|
def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
|
||||||
|
raw = texts[0] if texts else ""
|
||||||
|
scores = self._scores_from_text(raw, len(states))
|
||||||
|
if len(scores) < len(states):
|
||||||
|
scores.extend([0.0] * (len(states) - len(scores)))
|
||||||
|
return scores[: len(states)]
|
||||||
|
|
||||||
|
def _scores_from_text(self, raw: str, n: int) -> List[float]:
|
||||||
|
raw = raw.strip()
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
if isinstance(data, list):
|
||||||
|
return [float(x) for x in data]
|
||||||
|
except (json.JSONDecodeError, ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
nums = re.findall(r"-?\d+(?:\.\d+)?", raw)
|
||||||
|
return [float(x) for x in nums[:n]]
|
||||||
|
|
||||||
|
def parse_aggregation_answer(
|
||||||
|
self, states: List[Dict], texts: List[str]
|
||||||
|
) -> Union[Dict, List[Dict]]:
|
||||||
|
raise RuntimeError("parse_aggregation_answer is not used")
|
||||||
|
|
||||||
|
def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
|
||||||
|
raise RuntimeError("parse_improve_answer is not used")
|
||||||
|
|
||||||
|
def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
|
||||||
|
raise RuntimeError("parse_validation_answer is not used")
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_chat_graph(num_candidates: int = 3) -> GraphOfOperations:
|
||||||
|
g = GraphOfOperations()
|
||||||
|
g.append_operation(
|
||||||
|
operations.Generate(
|
||||||
|
num_branches_prompt=1, num_branches_response=num_candidates
|
||||||
|
)
|
||||||
|
)
|
||||||
|
g.append_operation(operations.Score(combined_scoring=True))
|
||||||
|
g.append_operation(operations.KeepBestN(1))
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def extract_assistant_text(final_thoughts_list: List[List[Thought]]) -> str:
|
||||||
|
"""``get_final_thoughts`` returns one list per leaf operation; we take the first leaf's first thought."""
|
||||||
|
if not final_thoughts_list:
|
||||||
|
return ""
|
||||||
|
thoughts = final_thoughts_list[0]
|
||||||
|
if not thoughts:
|
||||||
|
return ""
|
||||||
|
state = thoughts[0].state or {}
|
||||||
|
return str(state.get("candidate", ""))
|
||||||
@ -4,6 +4,7 @@ The Language Models module is responsible for managing the large language models
|
|||||||
|
|
||||||
Currently, the framework supports the following LLMs:
|
Currently, the framework supports the following LLMs:
|
||||||
- GPT-4 / GPT-3.5 (Remote - OpenAI API)
|
- GPT-4 / GPT-3.5 (Remote - OpenAI API)
|
||||||
|
- OpenRouter (Remote - [OpenRouter](https://openrouter.ai/) OpenAI-compatible API, multi-key / multi-model rotation)
|
||||||
- LLaMA-2 (Local - HuggingFace Transformers)
|
- LLaMA-2 (Local - HuggingFace Transformers)
|
||||||
|
|
||||||
The following sections describe how to instantiate individual LLMs and how to add new LLMs to the framework.
|
The following sections describe how to instantiate individual LLMs and how to add new LLMs to the framework.
|
||||||
@ -28,12 +29,26 @@ The following sections describe how to instantiate individual LLMs and how to ad
|
|||||||
|
|
||||||
- Instantiate the language model based on the selected configuration key (predefined / custom).
|
- Instantiate the language model based on the selected configuration key (predefined / custom).
|
||||||
```python
|
```python
|
||||||
lm = controller.ChatGPT(
|
from graph_of_thoughts.language_models import ChatGPT
|
||||||
|
|
||||||
|
lm = ChatGPT(
|
||||||
"path/to/config.json",
|
"path/to/config.json",
|
||||||
model_name=<configuration key>
|
model_name=<configuration key>
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### OpenRouter
|
||||||
|
- Copy `config.openrouter.example.yaml` (or `.json`) to `config.openrouter.yaml` next to this module, or pass an explicit path.
|
||||||
|
- Set `api_keys` (list) and `models` (list). Each request picks a **random** key and a **random** model (uniform over the lists). If the HTTP API passes a `model` field, that model id is used for that request instead of a random one.
|
||||||
|
- Optional: `http_referer` and `x_title` for OpenRouter attribution headers (see [OpenRouter docs](https://openrouter.ai/docs)).
|
||||||
|
- HTTP **429** responses trigger exponential backoff and further rotation; **400** responses are retried a limited number of times with a new key/model pair, then surfaced as an error.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from graph_of_thoughts.language_models import OpenRouter
|
||||||
|
|
||||||
|
lm = OpenRouter("/path/to/config.openrouter.yaml")
|
||||||
|
```
|
||||||
|
|
||||||
### LLaMA-2
|
### LLaMA-2
|
||||||
- Requires local hardware to run inference and a HuggingFace account.
|
- Requires local hardware to run inference and a HuggingFace account.
|
||||||
- Adjust the predefined `llama7b-hf`, `llama13b-hf` or `llama70b-hf` configurations or create a new configuration with an unique key.
|
- Adjust the predefined `llama7b-hf`, `llama13b-hf` or `llama70b-hf` configurations or create a new configuration with an unique key.
|
||||||
@ -50,7 +65,9 @@ lm = controller.ChatGPT(
|
|||||||
|
|
||||||
- Instantiate the language model based on the selected configuration key (predefined / custom).
|
- Instantiate the language model based on the selected configuration key (predefined / custom).
|
||||||
```python
|
```python
|
||||||
lm = controller.Llama2HF(
|
from graph_of_thoughts.language_models import Llama2HF
|
||||||
|
|
||||||
|
lm = Llama2HF(
|
||||||
"path/to/config.json",
|
"path/to/config.json",
|
||||||
model_name=<configuration key>
|
model_name=<configuration key>
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,3 +1,10 @@
|
|||||||
from .abstract_language_model import AbstractLanguageModel
|
from .abstract_language_model import AbstractLanguageModel
|
||||||
from .chatgpt import ChatGPT
|
from .chatgpt import ChatGPT
|
||||||
from .llamachat_hf import Llama2HF
|
from .llamachat_hf import Llama2HF
|
||||||
|
from .openrouter import (
|
||||||
|
OpenRouter,
|
||||||
|
OpenRouterBadRequestError,
|
||||||
|
OpenRouterError,
|
||||||
|
OpenRouterRateLimitError,
|
||||||
|
load_openrouter_config,
|
||||||
|
)
|
||||||
|
|||||||
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"base_url": "https://openrouter.ai/api/v1",
|
||||||
|
"api_keys": [
|
||||||
|
"sk-or-v1-replace-me-1",
|
||||||
|
"sk-or-v1-replace-me-2"
|
||||||
|
],
|
||||||
|
"models": [
|
||||||
|
"openai/gpt-4o-mini",
|
||||||
|
"anthropic/claude-3.5-haiku"
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"stop": null,
|
||||||
|
"prompt_token_cost": 0.0,
|
||||||
|
"response_token_cost": 0.0,
|
||||||
|
"max_retries_429": 8,
|
||||||
|
"max_retries_400": 3,
|
||||||
|
"base_backoff_seconds": 1.0,
|
||||||
|
"http_referer": "",
|
||||||
|
"x_title": ""
|
||||||
|
}
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
# Copy to config.openrouter.yaml (or set path explicitly) and fill in keys.
|
||||||
|
# Per chat request, an API key and model are chosen at random (uniform) from the lists.
|
||||||
|
|
||||||
|
base_url: https://openrouter.ai/api/v1
|
||||||
|
|
||||||
|
api_keys:
|
||||||
|
- sk-or-v1-replace-me-1
|
||||||
|
- sk-or-v1-replace-me-2
|
||||||
|
|
||||||
|
models:
|
||||||
|
- openai/gpt-4o-mini
|
||||||
|
- anthropic/claude-3.5-haiku
|
||||||
|
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 4096
|
||||||
|
stop: null
|
||||||
|
|
||||||
|
# Optional cost accounting (set to 0 if unknown)
|
||||||
|
prompt_token_cost: 0.0
|
||||||
|
response_token_cost: 0.0
|
||||||
|
|
||||||
|
# Retries after HTTP 429 / 400 (each retry uses a fresh random key + model)
|
||||||
|
max_retries_429: 8
|
||||||
|
max_retries_400: 3
|
||||||
|
base_backoff_seconds: 1.0
|
||||||
|
|
||||||
|
# Optional OpenRouter attribution headers (recommended by OpenRouter)
|
||||||
|
http_referer: ""
|
||||||
|
x_title: ""
|
||||||
287
graph_of_thoughts/language_models/openrouter.py
Normal file
287
graph_of_thoughts/language_models/openrouter.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
# Copyright (c) 2023 ETH Zurich.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Use of this source code is governed by a BSD-style license that can be
|
||||||
|
# found in the LICENSE file.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from openai import APIStatusError, OpenAI
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
|
from .abstract_language_model import AbstractLanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterError(Exception):
|
||||||
|
"""Base error for OpenRouter integration."""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterBadRequestError(OpenRouterError):
|
||||||
|
"""Raised when OpenRouter returns HTTP 400 after retries."""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterRateLimitError(OpenRouterError):
|
||||||
|
"""Raised when OpenRouter returns HTTP 429 after retries."""
|
||||||
|
|
||||||
|
|
||||||
|
def load_openrouter_config(path: str) -> Dict[str, Any]:
|
||||||
|
"""Load a YAML or JSON OpenRouter configuration file."""
|
||||||
|
return _load_config_file(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config_file(path: str) -> Dict[str, Any]:
|
||||||
|
ext = os.path.splitext(path)[1].lower()
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
if ext in (".yaml", ".yml"):
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
else:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError(f"Config at {path} must be a JSON/YAML object")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouter(AbstractLanguageModel):
|
||||||
|
"""
|
||||||
|
OpenRouter-backed language model with per-request rotation of API keys and models.
|
||||||
|
|
||||||
|
Configuration is loaded from YAML or JSON (see ``config.openrouter.example.yaml``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_path: str = "",
|
||||||
|
model_name: str = "openrouter",
|
||||||
|
cache: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self._rotation_model_name = model_name
|
||||||
|
self._request_overrides: Dict[str, Any] = {}
|
||||||
|
super().__init__(config_path, model_name, cache)
|
||||||
|
self._apply_openrouter_config()
|
||||||
|
|
||||||
|
def load_config(self, path: str) -> None:
|
||||||
|
if path == "":
|
||||||
|
path = os.path.join(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)),
|
||||||
|
"config.openrouter.yaml",
|
||||||
|
)
|
||||||
|
self.config_path = path
|
||||||
|
self.config = _load_config_file(path)
|
||||||
|
self.logger.debug("Loaded OpenRouter config from %s", path)
|
||||||
|
|
||||||
|
def _apply_openrouter_config(self) -> None:
|
||||||
|
cfg = self.config
|
||||||
|
self.base_url: str = cfg.get("base_url", "https://openrouter.ai/api/v1")
|
||||||
|
keys = cfg.get("api_keys") or []
|
||||||
|
if isinstance(keys, str):
|
||||||
|
keys = [keys]
|
||||||
|
self.api_keys: List[str] = [k for k in keys if k]
|
||||||
|
if not self.api_keys:
|
||||||
|
raise ValueError("OpenRouter config must define non-empty 'api_keys'")
|
||||||
|
|
||||||
|
models = cfg.get("models") or []
|
||||||
|
if isinstance(models, str):
|
||||||
|
models = [models]
|
||||||
|
self.models: List[str] = [m for m in models if m]
|
||||||
|
if not self.models:
|
||||||
|
raise ValueError("OpenRouter config must define non-empty 'models'")
|
||||||
|
|
||||||
|
self.temperature: float = float(cfg.get("temperature", 1.0))
|
||||||
|
self.max_tokens: int = int(cfg.get("max_tokens", 4096))
|
||||||
|
self.stop: Union[str, List[str], None] = cfg.get("stop")
|
||||||
|
self.prompt_token_cost: float = float(cfg.get("prompt_token_cost", 0.0))
|
||||||
|
self.response_token_cost: float = float(cfg.get("response_token_cost", 0.0))
|
||||||
|
|
||||||
|
self.max_retries_429: int = int(cfg.get("max_retries_429", 8))
|
||||||
|
self.max_retries_400: int = int(cfg.get("max_retries_400", 3))
|
||||||
|
self.base_backoff_seconds: float = float(cfg.get("base_backoff_seconds", 1.0))
|
||||||
|
|
||||||
|
self.http_referer: str = cfg.get("http_referer", "") or os.getenv(
|
||||||
|
"OPENROUTER_HTTP_REFERER", ""
|
||||||
|
)
|
||||||
|
self.x_title: str = cfg.get("x_title", "") or os.getenv("OPENROUTER_X_TITLE", "")
|
||||||
|
|
||||||
|
self.model_name = self._rotation_model_name
|
||||||
|
self.last_model_id: Optional[str] = None
|
||||||
|
self.generation_model_id: Optional[str] = None
|
||||||
|
|
||||||
|
def set_request_overrides(self, **kwargs: Any) -> None:
|
||||||
|
"""Optional per-request parameters (used by the HTTP API). Cleared with :meth:`clear_request_overrides`."""
|
||||||
|
self._request_overrides = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
|
def clear_request_overrides(self) -> None:
|
||||||
|
self._request_overrides = {}
|
||||||
|
|
||||||
|
def _pick_key(self) -> str:
|
||||||
|
return random.choice(self.api_keys)
|
||||||
|
|
||||||
|
def _pick_model(self, override: Optional[str]) -> str:
|
||||||
|
if override:
|
||||||
|
return override
|
||||||
|
o = self._request_overrides.get("model")
|
||||||
|
if o:
|
||||||
|
return str(o)
|
||||||
|
return random.choice(self.models)
|
||||||
|
|
||||||
|
def _effective_temperature(self) -> float:
|
||||||
|
t = self._request_overrides.get("temperature")
|
||||||
|
return float(t) if t is not None else self.temperature
|
||||||
|
|
||||||
|
def _effective_max_tokens(self) -> int:
|
||||||
|
m = self._request_overrides.get("max_tokens")
|
||||||
|
return int(m) if m is not None else self.max_tokens
|
||||||
|
|
||||||
|
def _client_for_key(self, api_key: str) -> OpenAI:
|
||||||
|
headers: Dict[str, str] = {}
|
||||||
|
if self.http_referer:
|
||||||
|
headers["HTTP-Referer"] = self.http_referer
|
||||||
|
if self.x_title:
|
||||||
|
headers["X-Title"] = self.x_title
|
||||||
|
return OpenAI(
|
||||||
|
base_url=self.base_url,
|
||||||
|
api_key=api_key,
|
||||||
|
default_headers=headers or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sleep_backoff(self, attempt: int) -> None:
|
||||||
|
cap = 60.0
|
||||||
|
delay = min(
|
||||||
|
self.base_backoff_seconds * (2**attempt) + random.random(),
|
||||||
|
cap,
|
||||||
|
)
|
||||||
|
self.logger.warning("Backing off %.2fs (attempt %d)", delay, attempt + 1)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
num_responses: int = 1,
|
||||||
|
model_override: Optional[str] = None,
|
||||||
|
) -> ChatCompletion:
|
||||||
|
"""
|
||||||
|
Call OpenRouter chat completions with rotation and retries for 429/400.
|
||||||
|
"""
|
||||||
|
attempts_429 = 0
|
||||||
|
attempts_400 = 0
|
||||||
|
attempt = 0
|
||||||
|
last_exc: Optional[Exception] = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
api_key = self._pick_key()
|
||||||
|
model_id = self._pick_model(model_override)
|
||||||
|
client = self._client_for_key(api_key)
|
||||||
|
try:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_id,
|
||||||
|
messages=messages,
|
||||||
|
temperature=self._effective_temperature(),
|
||||||
|
max_tokens=self._effective_max_tokens(),
|
||||||
|
n=num_responses,
|
||||||
|
stop=self.stop,
|
||||||
|
)
|
||||||
|
if response.usage is not None:
|
||||||
|
self.prompt_tokens += response.usage.prompt_tokens or 0
|
||||||
|
self.completion_tokens += response.usage.completion_tokens or 0
|
||||||
|
pt_k = float(self.prompt_tokens) / 1000.0
|
||||||
|
ct_k = float(self.completion_tokens) / 1000.0
|
||||||
|
self.cost = (
|
||||||
|
self.prompt_token_cost * pt_k
|
||||||
|
+ self.response_token_cost * ct_k
|
||||||
|
)
|
||||||
|
self.last_model_id = model_id
|
||||||
|
if self.generation_model_id is None:
|
||||||
|
self.generation_model_id = model_id
|
||||||
|
self.logger.info(
|
||||||
|
"OpenRouter response model=%s id=%s", model_id, response.id
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except APIStatusError as e:
|
||||||
|
last_exc = e
|
||||||
|
code = e.status_code
|
||||||
|
if code == 429:
|
||||||
|
if attempts_429 >= self.max_retries_429:
|
||||||
|
raise OpenRouterRateLimitError(
|
||||||
|
f"OpenRouter rate limited after {attempts_429} retries: {e.message}"
|
||||||
|
) from e
|
||||||
|
attempts_429 += 1
|
||||||
|
self._sleep_backoff(attempt)
|
||||||
|
attempt += 1
|
||||||
|
continue
|
||||||
|
if code == 400:
|
||||||
|
self.logger.warning(
|
||||||
|
"OpenRouter HTTP 400 (will retry with rotated key/model if allowed): %s body=%s",
|
||||||
|
e.message,
|
||||||
|
e.body,
|
||||||
|
)
|
||||||
|
if attempts_400 >= self.max_retries_400:
|
||||||
|
raise OpenRouterBadRequestError(
|
||||||
|
f"OpenRouter bad request after {attempts_400} retries: {e.message}"
|
||||||
|
) from e
|
||||||
|
attempts_400 += 1
|
||||||
|
attempt += 1
|
||||||
|
time.sleep(random.uniform(0.2, 0.8))
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception("Unexpected error calling OpenRouter")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self, query: str, num_responses: int = 1
|
||||||
|
) -> Union[List[ChatCompletion], ChatCompletion]:
|
||||||
|
if self.cache and query in self.response_cache:
|
||||||
|
return self.response_cache[query]
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": query}]
|
||||||
|
model_ov = self._request_overrides.get("model")
|
||||||
|
model_override = str(model_ov) if model_ov else None
|
||||||
|
|
||||||
|
if num_responses == 1:
|
||||||
|
response = self.chat(messages, 1, model_override=model_override)
|
||||||
|
else:
|
||||||
|
response = []
|
||||||
|
next_try = num_responses
|
||||||
|
total_num_attempts = num_responses
|
||||||
|
remaining = num_responses
|
||||||
|
while remaining > 0 and total_num_attempts > 0:
|
||||||
|
try:
|
||||||
|
assert next_try > 0
|
||||||
|
res = self.chat(
|
||||||
|
messages, next_try, model_override=model_override
|
||||||
|
)
|
||||||
|
response.append(res)
|
||||||
|
remaining -= next_try
|
||||||
|
next_try = min(remaining, next_try)
|
||||||
|
except Exception as e:
|
||||||
|
next_try = max(1, (next_try + 1) // 2)
|
||||||
|
self.logger.warning(
|
||||||
|
"Error in OpenRouter query: %s, retrying with n=%s",
|
||||||
|
e,
|
||||||
|
next_try,
|
||||||
|
)
|
||||||
|
time.sleep(random.uniform(0.5, 2.0))
|
||||||
|
total_num_attempts -= 1
|
||||||
|
|
||||||
|
if self.cache:
|
||||||
|
self.response_cache[query] = response
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_response_texts(
|
||||||
|
self, query_response: Union[List[ChatCompletion], ChatCompletion]
|
||||||
|
) -> List[str]:
|
||||||
|
if not isinstance(query_response, list):
|
||||||
|
query_response = [query_response]
|
||||||
|
texts: List[str] = []
|
||||||
|
for response in query_response:
|
||||||
|
for choice in response.choices:
|
||||||
|
c = choice.message.content
|
||||||
|
texts.append(c if c is not None else "")
|
||||||
|
return texts
|
||||||
@ -22,6 +22,7 @@ classifiers = [
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"backoff>=2.2.1,<3.0.0",
|
"backoff>=2.2.1,<3.0.0",
|
||||||
"openai>=1.0.0,<2.0.0",
|
"openai>=1.0.0,<2.0.0",
|
||||||
|
"pyyaml>=6.0.1,<7.0.0",
|
||||||
"matplotlib>=3.7.1,<4.0.0",
|
"matplotlib>=3.7.1,<4.0.0",
|
||||||
"numpy>=1.24.3,<2.0.0",
|
"numpy>=1.24.3,<2.0.0",
|
||||||
"pandas>=2.0.3,<3.0.0",
|
"pandas>=2.0.3,<3.0.0",
|
||||||
@ -33,7 +34,14 @@ dependencies = [
|
|||||||
"scipy>=1.10.1,<2.0.0",
|
"scipy>=1.10.1,<2.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
api = [
|
||||||
|
"fastapi>=0.109.0,<1.0.0",
|
||||||
|
"uvicorn[standard]>=0.27.0,<1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/spcl/graph-of-thoughts"
|
Homepage = "https://github.com/spcl/graph-of-thoughts"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
got-openrouter-api = "graph_of_thoughts.api.app:run"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user