# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Chat-completion adapter for Graph of Thoughts.

Provides a prompter/parser pair plus a small generate → score → keep-best
Graph of Operations that turns a chat transcript into one assistant answer.
"""

from __future__ import annotations

import json
import math
import re
from typing import Any, Dict, List, Optional, Union

from graph_of_thoughts.operations import GraphOfOperations, operations
from graph_of_thoughts.operations.thought import Thought
from graph_of_thoughts.parser import Parser
from graph_of_thoughts.prompter import Prompter

# A flat bracketed array such as "[7, 5, 9]" embedded anywhere in a reply.
_JSON_ARRAY_RE = re.compile(r"\[[^\[\]]*\]")
# An optionally negative integer or decimal number.
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def _content_to_text(content: Any) -> str:
    """Render a chat message ``content`` value as plain text.

    ``None`` becomes ``""``; strings pass through unchanged; a list of content
    parts keeps string parts and the ``"text"`` of dict parts (OpenAI-style
    multimodal messages), dropping non-text parts such as images; anything
    else is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: List[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                text = part.get("text")
                if isinstance(text, str):
                    pieces.append(text)
            else:
                pieces.append(str(part))
        return "\n".join(pieces)
    return str(content)


def format_chat_messages(messages: List[Dict[str, Any]]) -> str:
    """Flatten chat messages into a single prompt string.

    Each message is rendered as ``"ROLE:\\n<content>"`` and messages are
    separated by a blank line. A missing or empty ``role`` defaults to
    ``"user"``; ``None`` content renders as empty text.
    """
    parts: List[str] = []
    for m in messages:
        role = m.get("role") or "user"
        content = _content_to_text(m.get("content"))
        parts.append(f"{str(role).upper()}:\n{content}")
    return "\n\n".join(parts)


class ChatCompletionPrompter(Prompter):
    """Prompter for a small generate → score → keep-best Graph of Operations.

    Only ``generate_prompt`` and ``score_prompt`` are used by
    ``build_default_chat_graph``; the other ``Prompter`` hooks raise
    ``RuntimeError``.
    """

    def generate_prompt(self, num_branches: int, **kwargs: Any) -> str:
        """Build the prompt asking for one candidate answer.

        ``num_branches`` is not used: the default graph asks for several
        completions of this single prompt instead. The thought state's
        ``input`` value is inserted verbatim (presumably the output of
        ``format_chat_messages`` — confirm against the caller).
        """
        problem = kwargs.get("input", "")
        return (
            "You are a careful assistant. Read the conversation below and produce "
            "one candidate answer for the USER's latest needs.\n\n"
            f"{problem}\n\n"
            "Reply with your answer only, no preamble."
        )

    def score_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        """Build one prompt that scores every candidate in ``state_dicts``.

        Candidates are numbered from 0 in the order given, and the model is
        asked for a JSON array with one score per candidate in that order.
        """
        lines = [
            "You evaluate candidate answers for the same problem. "
            "Score each candidate from 0 (worst) to 10 (best) on correctness, "
            "completeness, and relevance.",
            "",
            "Return ONLY a JSON array of numbers, one score per candidate in order, "
            "e.g. [7, 5, 9].",
            "",
        ]
        for i, st in enumerate(state_dicts):
            cand = st.get("candidate", "")
            lines.append(f"Candidate {i}:\n{cand}\n")
        return "\n".join(lines)

    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        raise RuntimeError("aggregation_prompt is not used by the chat completion pipeline")

    def improve_prompt(self, **kwargs: Any) -> str:
        raise RuntimeError("improve_prompt is not used by the chat completion pipeline")

    def validation_prompt(self, **kwargs: Any) -> str:
        raise RuntimeError("validation_prompt is not used by the chat completion pipeline")


class ChatCompletionParser(Parser):
    """Parser matching ``ChatCompletionPrompter``: candidates in, scores out."""

    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """Turn each completion into a new state with a stripped ``candidate``.

        ``branch_index`` records the completion's position in ``texts``;
        ``None`` completions become empty candidates.
        """
        return [
            {"candidate": (t or "").strip(), "branch_index": i}
            for i, t in enumerate(texts)
        ]

    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """Parse a combined scoring reply into exactly ``len(states)`` scores.

        Only the first reply in ``texts`` is used. Missing scores are padded
        with 0.0 and surplus scores are dropped so every state gets a score.
        """
        n = len(states)
        raw = texts[0] if texts else ""
        scores = self._scores_from_text(raw, n)
        scores.extend([0.0] * (n - len(scores)))  # no-op when already >= n
        return scores[:n]

    def _scores_from_text(self, raw: str, n: int) -> List[float]:
        """Extract candidate scores from a scorer reply.

        Tried in order: the whole reply as a JSON list of numbers; the last
        non-empty bracketed array in the reply that parses as such (handles
        code fences and surrounding prose); finally the first ``n`` bare
        numbers. Non-finite values (``json`` accepts ``NaN``/``Infinity``)
        are mapped to 0.0 so they cannot corrupt the keep-best ordering.
        """
        raw = (raw or "").strip()
        scores = self._json_number_list(raw)
        if scores is None:
            # Prefer the last array: models that add reasoning usually put
            # the requested answer at the end.
            for snippet in reversed(_JSON_ARRAY_RE.findall(raw)):
                parsed = self._json_number_list(snippet)
                if parsed:
                    scores = parsed
                    break
        if scores is None:
            scores = [float(x) for x in _NUMBER_RE.findall(raw)[:n]]
        return [s if math.isfinite(s) else 0.0 for s in scores]

    @staticmethod
    def _json_number_list(text: str) -> Optional[List[float]]:
        """Return ``text`` parsed as a JSON list of floats, or ``None``."""
        try:
            data = json.loads(text)
        except ValueError:  # includes json.JSONDecodeError
            return None
        if not isinstance(data, list):
            return None
        try:
            return [float(x) for x in data]
        except (TypeError, ValueError, OverflowError):
            return None

    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        raise RuntimeError("parse_aggregation_answer is not used")

    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        raise RuntimeError("parse_improve_answer is not used")

    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        raise RuntimeError("parse_validation_answer is not used")


def build_default_chat_graph(num_candidates: int = 3) -> GraphOfOperations:
    """Build the generate → score → keep-best graph.

    One prompt is sampled ``num_candidates`` times, all candidates are scored
    together in a single combined scoring call, and the best one is kept.

    Raises:
        ValueError: if ``num_candidates`` is less than 1.
    """
    if num_candidates < 1:
        raise ValueError(f"num_candidates must be >= 1, got {num_candidates}")
    g = GraphOfOperations()
    g.append_operation(
        operations.Generate(
            num_branches_prompt=1, num_branches_response=num_candidates
        )
    )
    g.append_operation(operations.Score(combined_scoring=True))
    g.append_operation(operations.KeepBestN(1))
    return g


def extract_assistant_text(final_thoughts_list: List[List[Thought]]) -> str:
    """Return the ``candidate`` of the first thought of the first leaf.

    ``get_final_thoughts`` returns one list per leaf operation; with
    ``build_default_chat_graph`` the only leaf is ``KeepBestN(1)``. Returns
    ``""`` when there is no leaf, no thought, or no candidate.
    """
    if not final_thoughts_list or not final_thoughts_list[0]:
        return ""
    state = final_thoughts_list[0][0].state or {}
    candidate = state.get("candidate")
    return "" if candidate is None else str(candidate)