# Copyright (c) 2023 ETH Zurich. # All rights reserved. # # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. from __future__ import annotations import logging import os import time import uuid from typing import Any, Dict, List, Optional from graph_of_thoughts.api.got_openai_pipeline import ( ChatCompletionParser, ChatCompletionPrompter, build_default_chat_graph, extract_assistant_text, format_chat_messages, ) from graph_of_thoughts.controller import Controller from graph_of_thoughts.language_models.openrouter import ( OpenRouter, OpenRouterBadRequestError, OpenRouterRateLimitError, ) try: from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel, Field except ImportError as e: raise ImportError( "FastAPI and Pydantic are required for the HTTP API. " 'Install with: pip install "graph_of_thoughts[api]"' ) from e class ChatMessage(BaseModel): role: str content: str class ChatCompletionRequest(BaseModel): model: Optional[str] = None messages: List[ChatMessage] temperature: Optional[float] = None max_tokens: Optional[int] = None stream: Optional[bool] = False n: Optional[int] = Field(default=1, ge=1, le=1) def _get_config_path() -> str: return os.environ.get( "OPENROUTER_CONFIG", os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "language_models", "config.openrouter.yaml", ), ) def _run_controller(lm: OpenRouter, user_text: str) -> str: graph = build_default_chat_graph(num_candidates=3) ctrl = Controller( lm, graph, ChatCompletionPrompter(), ChatCompletionParser(), {"input": user_text}, ) ctrl.run() return extract_assistant_text(ctrl.get_final_thoughts()) app = FastAPI( title="Graph of Thoughts (OpenRouter)", version="0.1.0", description="OpenAI-compatible chat completions backed by Graph of Operations + OpenRouter.", ) @app.on_event("startup") def _startup() -> None: logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO")) 
@app.get("/v1/models")
def list_models() -> Dict[str, Any]:
    """Return the models named in the OpenRouter config in OpenAI list format."""
    config_path = _get_config_path()
    if not os.path.isfile(config_path):
        # Missing config is not an error here: report an empty model list.
        return {"object": "list", "data": []}

    from graph_of_thoughts.language_models.openrouter import load_openrouter_config

    config = load_openrouter_config(config_path)
    model_ids = config.get("models") or []
    if isinstance(model_ids, str):
        # A single model may be configured as a bare string.
        model_ids = [model_ids]

    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openrouter",
            }
            for model_id in model_ids
        ],
    }


@app.post("/v1/chat/completions")
def chat_completions(body: ChatCompletionRequest) -> JSONResponse:
    """OpenAI-compatible chat completion endpoint (non-streaming, n=1 only).

    Maps upstream OpenRouter rate-limit and bad-request failures to HTTP 429
    and 400 respectively; a missing config file yields HTTP 500.
    """
    if body.stream:
        raise HTTPException(
            status_code=400,
            detail="stream=true is not supported; use stream=false.",
        )
    # Defense in depth: the pydantic Field already constrains n to 1.
    if body.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported.")

    config_path = _get_config_path()
    if not os.path.isfile(config_path):
        raise HTTPException(
            status_code=500,
            detail=f"OpenRouter config not found at {config_path}. Set OPENROUTER_CONFIG.",
        )

    lm = OpenRouter(config_path=config_path)
    flattened_prompt = format_chat_messages(
        [{"role": message.role, "content": message.content} for message in body.messages]
    )

    try:
        # Per-request overrides must be cleared even if the run fails.
        lm.set_request_overrides(
            model=body.model,
            temperature=body.temperature,
            max_tokens=body.max_tokens,
        )
        try:
            answer = _run_controller(lm, flattened_prompt)
        finally:
            lm.clear_request_overrides()
    except OpenRouterRateLimitError as e:
        raise HTTPException(status_code=429, detail=str(e)) from e
    except OpenRouterBadRequestError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Prefer the explicitly requested model, then whatever the LM reports.
    model_id = (
        body.model
        or lm.generation_model_id
        or lm.last_model_id
        or (lm.models[0] if lm.models else "openrouter")
    )

    payload = {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": answer},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": lm.prompt_tokens,
            "completion_tokens": lm.completion_tokens,
            "total_tokens": lm.prompt_tokens + lm.completion_tokens,
        },
    }
    return JSONResponse(content=payload)


def run() -> None:
    """Serve the app with uvicorn; HOST/PORT/RELOAD are read from the environment."""
    import uvicorn

    uvicorn.run(
        "graph_of_thoughts.api.app:app",
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        reload=os.environ.get("RELOAD", "").lower() in ("1", "true", "yes"),
    )


if __name__ == "__main__":
    run()