# Snapshot metadata (non-code paste residue, kept as a comment so the file parses):
# 2026-03-19 23:18:12 +00:00 — 193 lines — 5.2 KiB — Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import logging
import os
import time
import uuid
from typing import Any, Dict, List, Optional
from graph_of_thoughts.api.got_openai_pipeline import (
ChatCompletionParser,
ChatCompletionPrompter,
build_default_chat_graph,
extract_assistant_text,
format_chat_messages,
)
from graph_of_thoughts.controller import Controller
from graph_of_thoughts.language_models.openrouter import (
OpenRouter,
OpenRouterBadRequestError,
OpenRouterRateLimitError,
)
try:
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
except ImportError as e:
raise ImportError(
"FastAPI and Pydantic are required for the HTTP API. "
'Install with: pip install "graph_of_thoughts[api]"'
) from e
class ChatMessage(BaseModel):
    """A single chat message in the OpenAI chat-completions wire format."""

    # Speaker role, e.g. "system" / "user" / "assistant" (not validated here).
    role: str
    # Plain-text message content.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""

    # Optional model id override; when None the server-side OpenRouter config decides.
    model: Optional[str] = None
    # Conversation history; formatted into a single prompt via format_chat_messages.
    messages: List[ChatMessage]
    # Sampling overrides forwarded to the language model; None keeps config defaults.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # Streaming is not supported; the handler rejects stream=true with HTTP 400.
    stream: Optional[bool] = False
    # Only a single completion is supported (ge/le pin the value to exactly 1).
    n: Optional[int] = Field(default=1, ge=1, le=1)
def _get_config_path() -> str:
return os.environ.get(
"OPENROUTER_CONFIG",
os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"language_models",
"config.openrouter.yaml",
),
)
def _run_controller(lm: OpenRouter, user_text: str) -> str:
    """Execute one Graph-of-Operations run and return the assistant reply text."""
    operations_graph = build_default_chat_graph(num_candidates=3)
    controller = Controller(
        lm,
        operations_graph,
        ChatCompletionPrompter(),
        ChatCompletionParser(),
        {"input": user_text},
    )
    controller.run()
    final_thoughts = controller.get_final_thoughts()
    return extract_assistant_text(final_thoughts)
# Module-level ASGI application instance; served by uvicorn (see run()).
app = FastAPI(
    title="Graph of Thoughts (OpenRouter)",
    version="0.1.0",
    description="OpenAI-compatible chat completions backed by Graph of Operations + OpenRouter.",
)
@app.on_event("startup")
def _startup() -> None:
    """Configure root logging once at server startup (level from LOG_LEVEL env)."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
    # favor of lifespan handlers — consider migrating when bumping FastAPI.
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
@app.get("/v1/models")
def list_models() -> Dict[str, Any]:
    """List the models configured for OpenRouter (OpenAI /v1/models shape).

    Returns an empty list rather than an error when the config file is
    missing, so clients can probe the endpoint safely.
    """
    path = _get_config_path()
    if not os.path.isfile(path):
        return {"object": "list", "data": []}
    # Imported lazily so the module stays importable without the optional deps.
    from graph_of_thoughts.language_models.openrouter import load_openrouter_config

    cfg = load_openrouter_config(path)
    models = cfg.get("models") or []
    if isinstance(models, str):
        models = [models]
    # Hoisted out of the comprehension: one timestamp for the whole listing
    # (the original called int(time.time()) once per model).
    created = int(time.time())
    data = [
        {
            "id": m,
            "object": "model",
            "created": created,
            "owned_by": "openrouter",
        }
        for m in models
    ]
    return {"object": "list", "data": data}
@app.post("/v1/chat/completions")
def chat_completions(body: ChatCompletionRequest) -> JSONResponse:
    """OpenAI-compatible chat completion backed by a Graph-of-Operations run.

    Raises:
        HTTPException 400: streaming requested, n != 1, or upstream bad request.
        HTTPException 429: upstream rate limit exceeded.
        HTTPException 500: OpenRouter config file not found.
    """
    if body.stream:
        raise HTTPException(
            status_code=400,
            detail="stream=true is not supported; use stream=false.",
        )
    # Fix: an explicit `n: null` yields body.n == None, which the previous
    # `body.n != 1` check rejected even though the documented default is 1.
    # Treat None the same as the default instead of returning 400.
    if body.n is not None and body.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported.")
    path = _get_config_path()
    if not os.path.isfile(path):
        raise HTTPException(
            status_code=500,
            detail=f"OpenRouter config not found at {path}. Set OPENROUTER_CONFIG.",
        )
    lm = OpenRouter(config_path=path)
    # Flatten the chat history into a single prompt for the graph run.
    user_text = format_chat_messages(
        [{"role": m.role, "content": m.content} for m in body.messages]
    )
    try:
        lm.set_request_overrides(
            model=body.model,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
        try:
            answer = _run_controller(lm, user_text)
        finally:
            # Always drop per-request overrides, even when the run fails.
            lm.clear_request_overrides()
    except OpenRouterRateLimitError as e:
        raise HTTPException(status_code=429, detail=str(e)) from e
    except OpenRouterBadRequestError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # Prefer the explicitly requested model id, then whatever the LM reports.
    model_id = (
        body.model
        or lm.generation_model_id
        or lm.last_model_id
        or (lm.models[0] if lm.models else "openrouter")
    )
    resp_id = f"chatcmpl-{uuid.uuid4().hex}"
    now = int(time.time())
    payload = {
        "id": resp_id,
        "object": "chat.completion",
        "created": now,
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": answer},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": lm.prompt_tokens,
            "completion_tokens": lm.completion_tokens,
            "total_tokens": lm.prompt_tokens + lm.completion_tokens,
        },
    }
    return JSONResponse(content=payload)
def run() -> None:
    """Launch the API under uvicorn, configured via HOST/PORT/RELOAD env vars."""
    import uvicorn

    reload_flag = os.environ.get("RELOAD", "").lower() in ("1", "true", "yes")
    uvicorn.run(
        "graph_of_thoughts.api.app:app",
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        reload=reload_flag,
    )
# Allow launching the server by executing this module directly.
if __name__ == "__main__":
    run()