graph-of-thoughts/graph_of_thoughts/api/got_openai_pipeline.py
2026-03-19 23:18:12 +00:00

124 lines
4.6 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import json
import re
from typing import Any, Dict, List, Union
from graph_of_thoughts.operations import GraphOfOperations, operations
from graph_of_thoughts.operations.thought import Thought
from graph_of_thoughts.parser import Parser
from graph_of_thoughts.prompter import Prompter
def format_chat_messages(messages: List[Dict[str, str]]) -> str:
    """Flatten chat messages into a single plain-text transcript.

    Each message is rendered as ``"ROLE:\\n<content>"`` (role upper-cased) and
    consecutive messages are separated by a blank line.

    Args:
        messages: Dicts with optional ``role`` and ``content`` keys. A missing
            or ``None`` role is treated as ``"user"``. ``content`` may be a
            string, ``None`` (rendered as empty text), or a list of content
            parts whose ``{"type": "text", "text": ...}`` entries (and plain
            string entries) are joined with newlines; any other value is
            converted with ``str``.

    Returns:
        The transcript; ``""`` for an empty message list.
    """
    parts: List[str] = []
    for m in messages:
        # ``or`` also covers an explicit ``"role": None``, which would
        # otherwise crash on ``.upper()``.
        role = m.get("role") or "user"
        content = _message_text(m.get("content"))
        parts.append(f"{str(role).upper()}:\n{content}")
    return "\n\n".join(parts)


def _message_text(content: Any) -> str:
    """Return the plain text carried by one message's ``content`` value."""
    if content is None:
        # e.g. assistant turns that only carry tool calls; avoid emitting "None".
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # List-of-parts payloads: keep text parts, drop non-text ones
        # (images, audio, ...) that cannot be represented in a text prompt.
        texts: List[str] = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)
class ChatCompletionPrompter(Prompter):
    """Prompter for a small generate → score → keep-best Graph of Operations.

    Only the generation and scoring prompts are implemented; the aggregation,
    improvement and validation hooks raise ``RuntimeError``.
    """

    @staticmethod
    def _unused(hook: str) -> RuntimeError:
        # Build the shared error raised by hooks this pipeline never calls.
        return RuntimeError(f"{hook} is not used by the chat completion pipeline")

    def generate_prompt(self, num_branches: int, **kwargs: Any) -> str:
        """Ask the model for one candidate answer to ``kwargs["input"]``.

        ``kwargs["input"]`` (default ``""``) is embedded verbatim between the
        instructions. ``num_branches`` does not affect the prompt text.
        """
        conversation = kwargs.get("input", "")
        intro = (
            "You are a careful assistant. Read the conversation below and produce "
            "one candidate answer for the USER's latest needs."
        )
        outro = "Reply with your answer only, no preamble."
        return f"{intro}\n\n{conversation}\n\n{outro}"

    def score_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        """Ask the model to score every candidate in one reply.

        Candidates are listed in order as ``Candidate <index>`` (0-based),
        using each state's ``candidate`` text (``""`` when absent); the model
        is told to return a JSON array with one 0-10 score per candidate.
        """
        preamble = [
            "You evaluate candidate answers for the same problem. "
            "Score each candidate from 0 (worst) to 10 (best) on correctness, "
            "completeness, and relevance.",
            "",
            "Return ONLY a JSON array of numbers, one score per candidate in order, e.g. [7, 5, 9].",
            "",
        ]
        candidate_blocks = [
            f"Candidate {idx}:\n{state.get('candidate', '')}\n"
            for idx, state in enumerate(state_dicts)
        ]
        return "\n".join(preamble + candidate_blocks)

    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        """Not supported by this pipeline; always raises ``RuntimeError``."""
        raise self._unused("aggregation_prompt")

    def improve_prompt(self, **kwargs: Any) -> str:
        """Not supported by this pipeline; always raises ``RuntimeError``."""
        raise self._unused("improve_prompt")

    def validation_prompt(self, **kwargs: Any) -> str:
        """Not supported by this pipeline; always raises ``RuntimeError``."""
        raise self._unused("validation_prompt")
class ChatCompletionParser(Parser):
    """Parser for the generate → score → keep-best chat pipeline.

    Only generation answers and combined scoring answers are supported; the
    aggregation, improve and validation hooks raise ``RuntimeError``.
    """

    # A flat (non-nested) bracketed list such as "[7, 5, 9]" anywhere in the
    # reply, e.g. inside a ```json fence or after a short preamble.
    _ARRAY_RE = re.compile(r"\[[^\[\]]*\]")
    # Bare integers/decimals, used only as a last-resort fallback.
    _NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """Turn each generated response into a candidate state.

        Args:
            state: The parent thought's state (not used).
            texts: Raw model responses; ``None`` entries become empty candidates.

        Returns:
            One dict per response holding the stripped ``candidate`` text and
            its ``branch_index`` (position within ``texts``).
        """
        return [
            {"candidate": (t or "").strip(), "branch_index": i}
            for i, t in enumerate(texts)
        ]

    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """Extract exactly one score per state from a combined-scoring reply.

        Only the first response in ``texts`` is used (a missing or ``None``
        response counts as empty). Missing scores are padded with 0.0 and
        surplus scores are dropped. Non-finite values (``json`` accepts
        ``NaN``/``Infinity``) are replaced by 0.0 so ranking stays well-defined.
        """
        import math

        n = len(states)
        raw = (texts[0] if texts else "") or ""
        scores = [
            s if math.isfinite(s) else 0.0
            for s in self._scores_from_text(raw, n)[:n]
        ]
        scores.extend([0.0] * (n - len(scores)))
        return scores

    def _scores_from_text(self, raw: str, n: int) -> List[float]:
        """Best-effort parse of a list of scores from free-form model output.

        Tries, in order: the whole reply as JSON, each bracketed ``[...]``
        substring as JSON (handles code fences and preambles), then the first
        ``n`` bare numbers in the text. A JSON list is returned in full; the
        caller trims it.
        """
        raw = (raw or "").strip()
        for chunk in [raw, *self._ARRAY_RE.findall(raw)]:
            try:
                data = json.loads(chunk)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(data, list):
                continue
            try:
                return [float(x) for x in data]
            except (TypeError, ValueError):
                # e.g. a list of strings/objects; try the next candidate chunk.
                continue
        return [float(x) for x in self._NUMBER_RE.findall(raw)[:n]]

    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        """Not supported by this pipeline; always raises ``RuntimeError``."""
        raise RuntimeError("parse_aggregation_answer is not used")

    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        """Not supported by this pipeline; always raises ``RuntimeError``."""
        raise RuntimeError("parse_improve_answer is not used")

    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        """Not supported by this pipeline; always raises ``RuntimeError``."""
        raise RuntimeError("parse_validation_answer is not used")
def build_default_chat_graph(num_candidates: int = 3) -> GraphOfOperations:
    """Build the generate → score → keep-best Graph of Operations.

    Args:
        num_candidates: Passed as ``num_branches_response`` to the single
            ``Generate`` operation (which uses one prompt).

    Returns:
        A graph with three operations appended in order: ``Generate``,
        ``Score(combined_scoring=True)`` and ``KeepBestN(1)``.
    """
    graph = GraphOfOperations()
    pipeline = (
        operations.Generate(
            num_branches_prompt=1, num_branches_response=num_candidates
        ),
        operations.Score(combined_scoring=True),
        operations.KeepBestN(1),
    )
    for op in pipeline:
        graph.append_operation(op)
    return graph
def extract_assistant_text(final_thoughts_list: List[List[Thought]]) -> str:
    """Return the ``candidate`` text of the first leaf's first thought.

    ``get_final_thoughts`` returns one list of thoughts per leaf operation.
    An empty outer list, an empty first leaf, a falsy thought state, or a
    missing ``candidate`` key all yield ``""``; other values go through ``str``.
    """
    first_leaf = final_thoughts_list[0] if final_thoughts_list else []
    if not first_leaf:
        return ""
    state = first_leaf[0].state or {}
    return str(state.get("candidate", ""))