# Copyright (c) 2023 ETH Zurich.
|
|
# All rights reserved.
|
|
#
|
|
# Use of this source code is governed by a BSD-style license that can be
|
|
# found in the LICENSE file.
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import time
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from graph_of_thoughts.api.got_openai_pipeline import (
|
|
ChatCompletionParser,
|
|
ChatCompletionPrompter,
|
|
build_default_chat_graph,
|
|
extract_assistant_text,
|
|
format_chat_messages,
|
|
)
|
|
from graph_of_thoughts.controller import Controller
|
|
from graph_of_thoughts.language_models.openrouter import (
|
|
OpenRouter,
|
|
OpenRouterBadRequestError,
|
|
OpenRouterRateLimitError,
|
|
)
|
|
|
|
try:
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"FastAPI and Pydantic are required for the HTTP API. "
|
|
'Install with: pip install "graph_of_thoughts[api]"'
|
|
) from e
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """A single chat message in the OpenAI chat-completions wire format."""

    # Author of the message (e.g. "system", "user", "assistant").
    role: str
    # Text content of the message.
    content: str
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""

    # Model id to use; when None, the configured OpenRouter default applies.
    model: Optional[str] = None
    # Conversation history, oldest first.
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # Streaming is not supported by this server; must stay False.
    stream: Optional[bool] = False
    # Only a single completion is supported: ge=1/le=1 pins n to exactly 1.
    n: Optional[int] = Field(default=1, ge=1, le=1)
|
|
|
|
|
|
def _get_config_path() -> str:
|
|
return os.environ.get(
|
|
"OPENROUTER_CONFIG",
|
|
os.path.join(
|
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
|
"language_models",
|
|
"config.openrouter.yaml",
|
|
),
|
|
)
|
|
|
|
|
|
def _run_controller(lm: OpenRouter, user_text: str) -> str:
    """Execute the default chat graph for *user_text* and return the answer.

    Builds a fresh Graph of Operations with 3 candidate branches, drives it
    with *lm* through a Controller, and extracts the assistant text from the
    controller's final thoughts.
    """
    operations_graph = build_default_chat_graph(num_candidates=3)
    controller = Controller(
        lm,
        operations_graph,
        ChatCompletionPrompter(),
        ChatCompletionParser(),
        {"input": user_text},
    )
    controller.run()
    final_thoughts = controller.get_final_thoughts()
    return extract_assistant_text(final_thoughts)
|
|
|
|
|
|
# Module-level ASGI application object; uvicorn imports it by the dotted
# path "graph_of_thoughts.api.app:app".
app = FastAPI(
    title="Graph of Thoughts (OpenRouter)",
    version="0.1.0",
    description="OpenAI-compatible chat completions backed by Graph of Operations + OpenRouter.",
)
|
|
|
|
|
|
@app.on_event("startup")
def _startup() -> None:
    # Configure root logging once at server startup; the LOG_LEVEL
    # environment variable overrides the INFO default.
    # NOTE(review): @app.on_event is deprecated in newer FastAPI versions in
    # favor of lifespan handlers — consider migrating on a FastAPI upgrade.
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
|
|
|
|
|
|
@app.get("/v1/models")
def list_models() -> Dict[str, Any]:
    """List the configured OpenRouter models in OpenAI /v1/models format.

    Returns:
        An OpenAI-compatible model list. When the config file is missing,
        an empty list is returned (not an error) so clients can probe safely.
    """
    path = _get_config_path()
    if not os.path.isfile(path):
        return {"object": "list", "data": []}
    # Imported lazily, mirroring the original local import, so module import
    # does not depend on the config-loading helper.
    from graph_of_thoughts.language_models.openrouter import load_openrouter_config

    cfg = load_openrouter_config(path)
    models = cfg.get("models") or []
    if isinstance(models, str):
        # A single model may be configured as a bare string; normalize.
        models = [models]
    # Hoisted out of the comprehension: one consistent "created" timestamp
    # for all entries instead of calling time.time() per model.
    created = int(time.time())
    data = [
        {
            "id": m,
            "object": "model",
            "created": created,
            "owned_by": "openrouter",
        }
        for m in models
    ]
    return {"object": "list", "data": data}
|
|
|
|
|
|
def _resolve_model_id(body: ChatCompletionRequest, lm: OpenRouter) -> str:
    """Pick the model id to report in the response: first non-empty among the
    request's model, the LM's configured generation model, the model actually
    used on the last call, and the first configured model (generic fallback)."""
    return (
        body.model
        or lm.generation_model_id
        or lm.last_model_id
        or (lm.models[0] if lm.models else "openrouter")
    )


def _completion_payload(model_id: str, answer: str, lm: OpenRouter) -> Dict[str, Any]:
    """Assemble an OpenAI-compatible chat.completion response body from the
    answer text and the LM's token counters."""
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": answer},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": lm.prompt_tokens,
            "completion_tokens": lm.completion_tokens,
            "total_tokens": lm.prompt_tokens + lm.completion_tokens,
        },
    }


@app.post("/v1/chat/completions")
def chat_completions(body: ChatCompletionRequest) -> JSONResponse:
    """Handle POST /v1/chat/completions (non-streaming, n=1 only).

    Raises:
        HTTPException: 400 for unsupported options (stream, n != 1) or
            upstream bad requests, 429 when OpenRouter rate-limits, 500 when
            the OpenRouter config file is missing.
    """
    if body.stream:
        raise HTTPException(
            status_code=400,
            detail="stream=true is not supported; use stream=false.",
        )
    if body.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported.")

    path = _get_config_path()
    if not os.path.isfile(path):
        raise HTTPException(
            status_code=500,
            detail=f"OpenRouter config not found at {path}. Set OPENROUTER_CONFIG.",
        )

    lm = OpenRouter(config_path=path)
    user_text = format_chat_messages(
        [{"role": m.role, "content": m.content} for m in body.messages]
    )
    try:
        # Per-request overrides are applied to the LM object and must always
        # be cleared afterwards, even when the controller run fails.
        lm.set_request_overrides(
            model=body.model,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
        try:
            answer = _run_controller(lm, user_text)
        finally:
            lm.clear_request_overrides()
    except OpenRouterRateLimitError as e:
        # Map upstream throttling to the standard 429 status code.
        raise HTTPException(status_code=429, detail=str(e)) from e
    except OpenRouterBadRequestError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    model_id = _resolve_model_id(body, lm)
    return JSONResponse(content=_completion_payload(model_id, answer, lm))
|
|
|
|
|
|
def run() -> None:
    """Start the API server with uvicorn, configured from the environment.

    Environment variables: HOST (default "0.0.0.0"), PORT (default 8000),
    and RELOAD (enabled by "1", "true", or "yes", case-insensitive).
    """
    import uvicorn

    reload_enabled = os.environ.get("RELOAD", "").lower() in ("1", "true", "yes")
    uvicorn.run(
        "graph_of_thoughts.api.app:app",
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        reload=reload_enabled,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly as a development server entry point.
    run()
|