# Copyright (c) 2023 ETH Zurich.
|
|
# All rights reserved.
|
|
#
|
|
# Use of this source code is governed by a BSD-style license that can be
|
|
# found in the LICENSE file.
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from typing import Any, Dict, List, Union
|
|
|
|
from graph_of_thoughts.operations import GraphOfOperations, operations
|
|
from graph_of_thoughts.operations.thought import Thought
|
|
from graph_of_thoughts.parser import Parser
|
|
from graph_of_thoughts.prompter import Prompter
|
|
|
|
|
|
def format_chat_messages(messages: List[Dict[str, str]]) -> str:
    """Render a list of chat messages as a single plain-text transcript.

    Each message becomes ``ROLE:\\n<content>`` (role upper-cased); messages
    are separated by a blank line. A missing ``role`` defaults to ``user``,
    a missing ``content`` to the empty string, and non-string content is
    coerced with ``str``.
    """

    def render(message: Dict[str, str]) -> str:
        # One "ROLE:\ncontent" section per message.
        role = message.get("role", "user")
        content = message.get("content", "")
        text = content if isinstance(content, str) else str(content)
        return f"{role.upper()}:\n{text}"

    return "\n\n".join(render(message) for message in messages)
|
|
|
|
|
|
class ChatCompletionPrompter(Prompter):
    """Prompter for a small generate → score → keep-best Graph of Operations.

    Only ``generate_prompt`` and ``score_prompt`` are exercised by the chat
    pipeline; the remaining ``Prompter`` hooks raise to make accidental use
    loud rather than silently wrong.
    """

    def generate_prompt(self, num_branches: int, **kwargs: Any) -> str:
        """Build the prompt asking the LM for one candidate answer.

        The conversation text is taken from ``kwargs["input"]``;
        ``num_branches`` is accepted for interface compatibility but unused.
        """
        conversation = kwargs.get("input", "")
        prompt = (
            "You are a careful assistant. Read the conversation below and produce "
            "one candidate answer for the USER's latest needs.\n\n"
        )
        prompt += f"{conversation}\n\n"
        prompt += "Reply with your answer only, no preamble."
        return prompt

    def score_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        """Build a prompt asking the LM to score all candidates at once.

        The LM is instructed to reply with a bare JSON array of numbers, one
        per candidate, in order.
        """
        header = [
            "You evaluate candidate answers for the same problem. "
            "Score each candidate from 0 (worst) to 10 (best) on correctness, "
            "completeness, and relevance.",
            "",
            "Return ONLY a JSON array of numbers, one score per candidate in order, e.g. [7, 5, 9].",
            "",
        ]
        body = [
            f"Candidate {idx}:\n{state.get('candidate', '')}\n"
            for idx, state in enumerate(state_dicts)
        ]
        return "\n".join(header + body)

    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs: Any) -> str:
        """Not part of this pipeline; always raises."""
        raise RuntimeError("aggregation_prompt is not used by the chat completion pipeline")

    def improve_prompt(self, **kwargs: Any) -> str:
        """Not part of this pipeline; always raises."""
        raise RuntimeError("improve_prompt is not used by the chat completion pipeline")

    def validation_prompt(self, **kwargs: Any) -> str:
        """Not part of this pipeline; always raises."""
        raise RuntimeError("validation_prompt is not used by the chat completion pipeline")
|
|
|
|
|
|
class ChatCompletionParser(Parser):
    """Parser for the generate → score → keep-best chat pipeline.

    Turns raw LM generations into candidate state dicts and extracts numeric
    scores from the (combined) scoring response. The aggregation / improve /
    validation hooks are unused and raise on call.
    """

    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """Map each raw generation to a ``{"candidate", "branch_index"}`` state."""
        return [
            {"candidate": (text or "").strip(), "branch_index": idx}
            for idx, text in enumerate(texts)
        ]

    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """Extract one score per candidate from the combined scoring reply.

        Missing scores are padded with 0.0 and surplus scores are dropped so
        the result length always equals ``len(states)``.
        """
        reply = texts[0] if texts else ""
        parsed = self._scores_from_text(reply, len(states))
        shortfall = len(states) - len(parsed)
        if shortfall > 0:
            parsed += [0.0] * shortfall
        return parsed[: len(states)]

    def _scores_from_text(self, raw: str, n: int) -> List[float]:
        """Parse scores from *raw*: prefer a JSON array, fall back to a regex scan.

        The regex fallback pulls the first *n* numbers found anywhere in the
        text; the JSON path returns all parsed values (the caller truncates).
        """
        text = raw.strip()
        # First attempt: the reply is exactly the JSON array we asked for.
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            try:
                return [float(value) for value in parsed]
            except (TypeError, ValueError):
                pass  # e.g. a list containing non-numeric entries
        # Fallback: scrape plain numbers out of free-form text.
        matches = re.findall(r"-?\d+(?:\.\d+)?", text)
        return [float(m) for m in matches[:n]]

    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        """Not part of this pipeline; always raises."""
        raise RuntimeError("parse_aggregation_answer is not used")

    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        """Not part of this pipeline; always raises."""
        raise RuntimeError("parse_improve_answer is not used")

    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        """Not part of this pipeline; always raises."""
        raise RuntimeError("parse_validation_answer is not used")
|
|
|
|
|
|
def build_default_chat_graph(num_candidates: int = 3) -> GraphOfOperations:
    """Construct the default generate → score → keep-best Graph of Operations.

    The graph asks for ``num_candidates`` responses from a single prompt,
    scores them all in one combined call, and keeps only the best one.
    """
    graph = GraphOfOperations()
    pipeline = (
        operations.Generate(
            num_branches_prompt=1, num_branches_response=num_candidates
        ),
        operations.Score(combined_scoring=True),
        operations.KeepBestN(1),
    )
    for op in pipeline:
        graph.append_operation(op)
    return graph
|
|
|
|
|
|
def extract_assistant_text(final_thoughts_list: List[List[Thought]]) -> str:
    """Pull the winning candidate's text out of ``get_final_thoughts`` output.

    ``get_final_thoughts`` returns one list of thoughts per leaf operation;
    only the first leaf's first thought is inspected. Returns "" when there
    are no leaves, the first leaf is empty, or the thought carries no
    ``candidate`` entry.
    """
    if not final_thoughts_list or not final_thoughts_list[0]:
        return ""
    winner = final_thoughts_list[0][0]
    # state may legitimately be None; treat it as an empty mapping.
    state = winner.state or {}
    return str(state.get("candidate", ""))
|