2023-12-02 17:37:03 +01:00

768 lines
28 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach
import os
import re
import logging
import datetime
import json
import csv
from statistics import fmean
from typing import Dict, List, Callable, Set, Union
from graph_of_thoughts import controller, language_models, operations, prompter, parser
class DocMergePrompter(prompter.Prompter):
    """
    DocMergePrompter provides the generation of prompts specific to the document
    merge example for the language models.

    Inherits from the Prompter class and implements its abstract methods.
    """

    # Header of the plain merge prompt; expects `num` (number of documents).
    merge_doc_prompt_start = """Merge the following {num} NDA documents <Doc1> - <Doc{num}> into a single NDA, maximizing retained information and minimizing redundancy. Output only the created NDA between the tags <Merged> and </Merged>, without any additional text.
Here are NDAs <Doc1> - <Doc{num}>
"""

    # One <Doc{num}> block per input document of a merge prompt.
    merge_doc_prompt_block = """
<Doc{num}>
{document}
</Doc{num}>
"""

    # Header of the merge prompt that invites intermediate reasoning steps.
    merge_doc_prompt_cot_start = """Merge the following {num} NDA documents <Doc1> - <Doc{num}> into a single NDA, maximizing retained information and minimizing redundancy.
You can generate any intermediate thoughts and documents you want, but the final output should be the merged NDA, placed between the two tags <Merged> and </Merged>.
For instance you might want to follow this approach:
1. Split each NDA into their logical subparts.
2. Merge the subparts of the {num} NDAs.
3. Combine the merged subparts into a single NDA.
4. Place the merged NDA between the tags <Merged> and </Merged>.
Here are NDAs <Doc1> - <Doc{num}>:
"""

    # Header of the prompt that asks to improve an existing merged NDA <S>.
    improve_summary_prompt_start = """The following NDA <S> merges initial NDAs <Doc1> - <Doc{num}>.
Please improve the summary NDA <S> by adding more information and removing redundancy. Output only the improved NDA, placed between the two tags <Merged> and </Merged>, without any additional text.
Here are NDAs <Doc1> - <Doc{num}>:
"""

    # One <Doc{num}> block per input document of an improve prompt.
    improve_summary_prompt_block = """
<Doc{num}>
{document}
</Doc{num}>
"""

    # Tail of the improve prompt; expects `summary` (the current merged NDA).
    improve_summary_prompt_end = """
Here is the summary NDA <S>:
<S>
{summary}
</S>
"""

    # Header of the scoring prompt; expects `num` (number of documents).
    score_prompt_base = """The following NDA <S> merges NDAs <Doc1> - <Doc{num}>.
Please score the merged NDA <S> in terms of how much redundant information is contained, independent of the original NDAs, as well as how much information is retained from the original NDAs.
A score of 10 for redundancy implies that absolutely no information is redundant, while a score of 0 implies that at least half of the information is redundant (so everything is at least mentioned twice).
A score of 10 for retained information implies that all information from the original NDAs is retained, while a score of 0 implies that no information is retained.
You may provide reasoning for your scoring, but the final score for redundancy should be between the tags <Redundancy> and </Redundancy>, and the final score for retained information should be between the tags <Retained> and </Retained>, without any additional text within any of those tags.
Here are NDAs <Doc1> - <Doc{num}>:
"""

    # One <Doc{num}> block per input document of a scoring prompt.
    score_prompt_block = """
<Doc{num}>
{document}
</Doc{num}>
"""

    # Tail of the scoring prompt; expects `summary` (the NDA to be scored).
    score_prompt_end = """
Here is the summary NDA <S>:
<S>
{summary}
</S>
"""

    # Header of the full aggregation prompt; expects `num_ndas` and
    # `num_ndas_summary`.
    aggregate_full_prompt_base = """The following NDAs <S1> - <S{num_ndas_summary}> each merge the initial NDAs <Doc1> - <Doc{num_ndas}>.
Combine the merged NDAs <S1> - <S{num_ndas_summary}> into a new one, maximizing their advantages and overall information retention, while minimizing redundancy.
Output only the new NDA between the tags <Merged> and </Merged>, without any additional text.
Here are the original NDAs <Doc1> - <Doc{num_ndas}>:
"""

    # One <Doc{num}> block per original document of a full aggregation prompt.
    aggregate_full_prompt_block1 = """
<Doc{num}>
{document}
</Doc{num}>
"""

    # Separator between the original documents and the summaries.
    aggregate_full_prompt_mid = """
Here are the summary NDAs <S1> - <S{num_ndas_summary}>:
"""

    # One <S{num}> block per summary of a full aggregation prompt.
    aggregate_full_prompt_block2 = """
<S{num}>
{summary}
</S{num}>
"""

    # Header of the sub-part aggregation prompt; expects `num_ndas`.
    aggregate_sub_prompt_base = """The following NDAs <S1> - <S{num_ndas}> are summaries of some other NDAs.
Combine them into a new one, make sure to maximize their advantages and overall information retention, while minimizing redundancy.
Output only the new NDA between the tags <Merged> and </Merged>, without any additional text.
Here are NDAs <S1> - <S{num_ndas}>:
"""

    # One block per summary of a sub-part aggregation prompt.
    aggregate_sub_prompt_generate = """
NDA <S{num}>:
{nda}
</S{num}>
"""

    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs) -> str:
        """
        Generate an aggregation prompt for the language model.

        If the first state covers a non-empty, strict subset of its documents
        (judged by the sizes of "parts" and "documents"), the sub-part prompt is
        produced, which lists only the current summaries. Otherwise the full
        prompt is produced, which lists the original documents of the first
        state followed by all current summaries.

        :param state_dicts: The thought states that should be aggregated.
        :type state_dicts: List[Dict]
        :param kwargs: Additional keyword arguments.
        :return: The aggregation prompt.
        :rtype: str
        """
        first_state = state_dicts[0]
        num_summaries = len(state_dicts)

        if 0 < len(first_state["parts"]) < len(first_state["documents"]):
            prompt = self.aggregate_sub_prompt_base.format(num_ndas=num_summaries)
            prompt += "".join(
                self.aggregate_sub_prompt_generate.format(
                    nda=state_dict["current"], num=index
                )
                for index, state_dict in enumerate(state_dicts, start=1)
            )
            return prompt

        documents = first_state["documents"]
        prompt = self.aggregate_full_prompt_base.format(
            num_ndas=len(documents),
            num_ndas_summary=num_summaries,
        )
        prompt += "".join(
            self.aggregate_full_prompt_block1.format(document=document, num=index)
            for index, document in enumerate(documents, start=1)
        )
        prompt += self.aggregate_full_prompt_mid.format(
            num_ndas_summary=num_summaries,
        )
        prompt += "".join(
            self.aggregate_full_prompt_block2.format(
                summary=state_dict["current"], num=index
            )
            for index, state_dict in enumerate(state_dicts, start=1)
        )
        return prompt

    def _merge_prompt(self, header: str, documents: List[str]) -> str:
        """
        Build an initial merge prompt: the formatted header followed by one
        document block per entry of `documents`, numbered from 1.

        :param header: Prompt header template with a `num` placeholder.
        :type header: str
        :param documents: The documents to be listed in the prompt.
        :type documents: List[str]
        :return: The merge prompt.
        :rtype: str
        """
        return header.format(num=len(documents)) + "".join(
            self.merge_doc_prompt_block.format(document=document, num=index)
            for index, document in enumerate(documents, start=1)
        )

    def _improve_summary_prompt(self, documents: List[str], current: str) -> str:
        """
        Build a prompt that asks to improve the merged NDA `current`, listing
        the given documents, numbered from 1, as its sources.

        :param documents: The documents to be listed in the prompt.
        :type documents: List[str]
        :param current: The merged NDA that should be improved.
        :type current: str
        :return: The improve-summary prompt.
        :rtype: str
        """
        prompt = self.improve_summary_prompt_start.format(num=len(documents))
        prompt += "".join(
            self.improve_summary_prompt_block.format(document=document, num=index)
            for index, document in enumerate(documents, start=1)
        )
        prompt += self.improve_summary_prompt_end.format(summary=current)
        return prompt

    def generate_prompt(
        self,
        num_branches: int,
        documents: List[str],
        method: str,
        parts: Set[int],
        current: str,
        **kwargs,
    ) -> str:
        """
        Generate a generate prompt for the language model.

        Methods starting with "io" or "cot" always get an initial merge prompt
        over all documents. Methods starting with "tot" or "got" get an initial
        merge prompt while `current` is empty and an improve prompt afterwards;
        for "got" only the documents selected by `parts` are listed (all
        documents if `parts` is empty).

        :param num_branches: The number of responses the prompt should ask the LM to generate.
        :type num_branches: int
        :param documents: The list of documents to be merged.
        :type documents: List[str]
        :param method: Method for which the generate prompt is generated.
        :type method: str
        :param parts: Indices (into `documents`) of the document parts to be processed.
        :type parts: Set[int]
        :param current: The intermediate solution.
        :type current: str
        :param kwargs: Additional keyword arguments.
        :return: The generate prompt.
        :rtype: str
        :raise AssertionError: If method is not implemented yet.
        """
        if method.startswith("io"):
            return self._merge_prompt(self.merge_doc_prompt_start, documents)
        if method.startswith("cot"):
            return self._merge_prompt(self.merge_doc_prompt_cot_start, documents)

        if method.startswith("tot"):
            selected = documents
        elif method.startswith("got"):
            indices = sorted(parts) if len(parts) > 0 else range(len(documents))
            selected = [documents[index] for index in indices]
        else:
            # Raised explicitly instead of `assert False`, which is removed
            # when Python runs with -O and would let the call return None.
            raise AssertionError("Not implemented yet.")

        if current is None or current == "":
            return self._merge_prompt(self.merge_doc_prompt_start, selected)
        return self._improve_summary_prompt(selected, current)

    def score_prompt(self, state_dicts: List[Dict], **kwargs) -> str:
        """
        Generate a score prompt for the language model.

        The prompt lists the documents selected by the "parts" of the single
        thought state (all documents if "parts" is empty), followed by the
        state's current merged NDA.

        :param state_dicts: The thought states that should be scored,
                            if more than one, they should be scored together.
        :type state_dicts: List[Dict]
        :param kwargs: Additional keyword arguments.
        :return: The score prompt.
        :rtype: str
        :raise AssertionError: If more than one thought state is supplied.
        """
        if len(state_dicts) > 1:
            # Raised explicitly instead of `assert False`, which is removed
            # when Python runs with -O.
            raise AssertionError("Not implemented yet.")

        # perform individual scoring
        state = state_dicts[0]
        if len(state["parts"]) > 0:
            scored_documents = [
                state["documents"][part] for part in sorted(state["parts"])
            ]
        else:
            scored_documents = state["documents"]

        prompt = self.score_prompt_base.format(num=len(scored_documents))
        prompt += "".join(
            self.score_prompt_block.format(document=document, num=index)
            for index, document in enumerate(scored_documents, start=1)
        )
        prompt += self.score_prompt_end.format(summary=state["current"])
        return prompt

    def improve_prompt(self, **kwargs) -> str:
        """
        Generate an improve prompt for the language model.

        Not implemented for this example; the body is empty and the call
        returns None.

        :param kwargs: Additional keyword arguments.
        :return: The improve prompt.
        :rtype: str
        """
        pass

    def validation_prompt(self, **kwargs) -> str:
        """
        Generate a validation prompt for the language model.

        Not implemented for this example; the body is empty and the call
        returns None.

        :param kwargs: Additional keyword arguments.
        :return: The validation prompt.
        :rtype: str
        """
        pass
class DocMergeParser(parser.Parser):
    """
    DocMergeParser provides the parsing of language model responses specific to the
    document merge example.

    Inherits from the Parser class and implements its abstract methods.
    """

    def __init__(self) -> None:
        """
        Inits the response cache.
        """
        self.cache = {}

    def strip_answer_helper(self, text: str, tag: str = "") -> str:
        """
        Helper function to remove tags from a text.

        Everything up to and including the first "Output:" marker is dropped.
        If `tag` is given, the text between the last `<tag>` and the last
        `</tag>` is returned. If only one of the two tags is present, or the
        last closing tag precedes the last opening tag, the text after the
        opening tag (or before the closing tag) is returned and a warning is
        logged. If neither tag is present, the full text is returned.

        :param text: The input text.
        :type text: str
        :param tag: The tag to be stripped. Defaults to "".
        :type tag: str
        :return: The stripped text.
        :rtype: str
        """
        text = text.strip()
        if "Output:" in text:
            text = text[text.index("Output:") + len("Output:") :].strip()
        if tag == "":
            return text

        open_tag = f"<{tag}>"
        close_tag = f"</{tag}>"
        start = text.rfind(open_tag)
        end = text.rfind(close_tag)

        if start != -1 and end != -1:
            if start < end:
                return text[start + len(open_tag) : end].strip()
            # The closing tag precedes the last opening tag. Slicing
            # [start:end] would yield an empty string, so fall back to the
            # text after the opening tag, like the unmatched-start case below.
            logging.warning(
                f"Found the end tag </{tag}> before the last start tag <{tag}> in answer: {text}. Returning everything after the start tag."
            )
            return text[start + len(open_tag) :].strip()
        if start != -1:
            logging.warning(
                f"Only found the start tag <{tag}> in answer: {text}. Returning everything after the tag."
            )
            return text[start + len(open_tag) :].strip()
        if end != -1:
            logging.warning(
                f"Only found the end tag </{tag}> in answer: {text}. Returning everything before the tag."
            )
            return text[:end].strip()

        logging.warning(
            f"Could not find any tag {tag} in answer: {text}. Returning the full answer."
        )
        return text

    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        """
        Parse the response from the language model for an aggregation prompt.

        Each response yields a shallow copy of the first state whose "current"
        is the content of the <Merged> tag. If the first state covers fewer
        parts than it has documents, "parts" of the new state becomes the union
        of the "parts" of all aggregated states.

        :param states: The thought states used to generate the prompt.
        :type states: List[Dict]
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought states after parsing the responses from the language model.
        :rtype: Union[Dict, List[Dict]]
        """
        new_states = []
        for text in texts:
            new_state = states[0].copy()
            new_state["current"] = self.strip_answer_helper(text, "Merged")
            if len(states[0]["parts"]) < len(states[0]["documents"]):
                # subpart aggregation: a fresh set per new state, so the
                # resulting states never share a mutable "parts" object
                new_state["parts"] = set().union(
                    *(state["parts"] for state in states)
                )
            new_states.append(new_state)
        return new_states

    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """
        Parse the response from the language model for a generate prompt.

        Each response yields a shallow copy of `state` whose "current" is the
        content of the <Merged> tag.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought states after parsing the responses from the language model.
        :rtype: List[Dict]
        """
        new_states = []
        for text in texts:
            new_state = state.copy()
            new_state["current"] = self.strip_answer_helper(text, "Merged")
            new_states.append(new_state)
        return new_states

    def _extract_score(self, text: str, tag: str, label: str) -> Union[float, None]:
        """
        Extract a single numeric score from the given tag of a response.

        :param text: The full response of the language model.
        :type text: str
        :param tag: The tag that is expected to enclose the score.
        :type tag: str
        :param label: Name of the score, used in warning messages only.
        :type label: str
        :return: The score; the last one if several numbers are found, or None
                 if no number is found.
        :rtype: Union[float, None]
        """
        answer = self.strip_answer_helper(text, tag)
        numbers = re.findall(r"\d+\.?\d*", answer)
        if len(numbers) == 0:
            logging.warning(
                f"Could not find any {label} score in answer: {text}. Ignoring this answer."
            )
            return None
        if len(numbers) > 1:
            logging.warning(
                f"Found multiple {label} scores in answer: {text}. Returning the last one."
            )
        return float(numbers[-1])

    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """
        Parse the response from the language model for a score prompt.

        The redundancy and retained scores are averaged separately over all
        responses; the result is the harmonic mean (F1) of the two averages.
        If either kind of score is missing from every response, or both
        averages are zero, the result is 0.0.

        :param states: The thought states used to generate the prompt.
        :type states: List[Dict]
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The scores for the thought states.
        :rtype: List[float]
        :raise AssertionError: If the number of thought states is not one.
        """
        if len(states) != 1:
            # Raised explicitly so that the check survives `python -O`.
            raise AssertionError("Only one state is allowed for scoring.")

        # individual scoring
        redundancy_scores = []
        retain_scores = []
        for text in texts:
            redundancy = self._extract_score(text, "Redundancy", "redundancy")
            if redundancy is not None:
                redundancy_scores.append(redundancy)
            retained = self._extract_score(text, "Retained", "retained")
            if retained is not None:
                retain_scores.append(retained)

        if len(redundancy_scores) == 0 or len(retain_scores) == 0:
            logging.warning(
                "Could not find any valid score in any answer. Returning 0.0."
            )
            return [0.0]

        mean_redundancy = fmean(redundancy_scores)
        mean_retain = fmean(retain_scores)
        total = mean_redundancy + mean_retain
        if total == 0:
            # Both means are zero (the regex only matches non-negative
            # numbers); the harmonic mean is defined as 0 here, which avoids a
            # ZeroDivisionError.
            return [0.0]
        f1 = 2 * mean_redundancy * mean_retain / total
        return [f1]

    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        """
        Parse the response from the language model for an improve prompt.

        Not implemented for this example; the body is empty and the call
        returns None.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought state after parsing the responses from the language model.
        :rtype: Dict
        """
        pass

    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        """
        Parse the response from the language model for a validation prompt.

        Not implemented for this example; the body is empty and the call
        returns None.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: Whether the thought state is valid or not.
        :rtype: bool
        """
        pass
def io() -> operations.GraphOfOperations:
    """
    Build the Graph of Operations for the IO method: a Generate(1, 1)
    operation followed by a Score(3, False) operation.

    :return: Graph of Operations
    :rtype: GraphOfOperations
    """
    graph = operations.GraphOfOperations()

    generate = operations.Generate(1, 1)
    graph.append_operation(generate)

    score = operations.Score(3, False)
    graph.append_operation(score)

    return graph
def cot() -> operations.GraphOfOperations:
    """
    Build the Graph of Operations for the CoT method: a Generate(1, 1)
    operation followed by a Score(3, False) operation.

    :return: Graph of Operations
    :rtype: GraphOfOperations
    """
    graph = operations.GraphOfOperations()

    generate = operations.Generate(1, 1)
    graph.append_operation(generate)

    score = operations.Score(3, False)
    graph.append_operation(score)

    return graph
def tot() -> operations.GraphOfOperations:
    """
    Build the Graph of Operations for the ToT method.

    The graph consists of three rounds of Generate -> Score -> KeepBestN. From
    the second round on, each KeepBestN additionally gets the KeepBestN of the
    previous round as a predecessor.

    :return: Graph of Operations
    :rtype: GraphOfOperations
    """
    graph = operations.GraphOfOperations()
    branch_factor = 10
    rounds = 3

    previous_keep = None
    for _ in range(rounds):
        graph.append_operation(operations.Generate(1, branch_factor))
        graph.append_operation(operations.Score(3, False))

        keep = operations.KeepBestN(1, True)
        if previous_keep is not None:
            # link to the winner of the previous round before appending
            keep.add_predecessor(previous_keep)
        graph.append_operation(keep)
        previous_keep = keep

    return graph
def got() -> operations.GraphOfOperations:
    """
    Build the Graph of Operations for the GoT method, where full documents
    are merged.

    The graph has three stages, each ending in a KeepBestN: Generate(1, 5),
    then Aggregate(5), then Generate(1, 10), each followed by Score(3, False).
    The second and third KeepBestN additionally get the preceding KeepBestN as
    a predecessor.

    :return: Graph of Operations
    :rtype: GraphOfOperations
    """
    graph = operations.GraphOfOperations()

    # Stage 1: generate candidates and keep the three best ones.
    graph.append_operation(operations.Generate(1, 5))
    graph.append_operation(operations.Score(3, False))
    best_generated = operations.KeepBestN(3, True)
    graph.append_operation(best_generated)

    # Stage 2: aggregate the kept candidates and keep the best result.
    graph.append_operation(operations.Aggregate(5))
    graph.append_operation(operations.Score(3, False))
    best_aggregated = operations.KeepBestN(1, True)
    best_aggregated.add_predecessor(best_generated)
    graph.append_operation(best_aggregated)

    # Stage 3: generate again from the best result and keep the best one.
    graph.append_operation(operations.Generate(1, 10))
    graph.append_operation(operations.Score(3, False))
    best_final = operations.KeepBestN(1, True)
    best_final.add_predecessor(best_aggregated)
    graph.append_operation(best_final)

    return graph
def got2() -> operations.GraphOfOperations:
    """
    Build the Graph of Operations for the GoT2 method, where partial
    documents are merged.

    First, for each pair of document indices (0, 1) and (2, 3), a chain
    Selector -> Generate -> Score -> KeepBestN is created. The resulting
    KeepBestN operations are then merged pairwise, level by level, through
    Aggregate -> Score -> KeepBestN -> Generate -> Score -> KeepBestN until a
    single operation remains.

    :return: Graph of Operations
    :rtype: GraphOfOperations
    """
    graph = operations.GraphOfOperations()

    def wire(operation, *predecessors):
        # Register the predecessors first, then add the operation to the graph.
        for predecessor in predecessors:
            operation.add_predecessor(predecessor)
        graph.add_operation(operation)
        return operation

    frontier = []
    for first_index in range(0, 4, 2):  # should be at most 16 parts
        # list_id is bound as a default argument so that each selector keeps
        # its own index pair.
        selector = wire(
            operations.Selector(
                lambda thoughts, list_id=first_index: [
                    operations.Thought(
                        state={**thoughts[0].state, "parts": {list_id, list_id + 1}}
                    )
                ]
            )
        )
        generated = wire(operations.Generate(1, 5), selector)
        scored = wire(operations.Score(3, False), generated)
        frontier.append(wire(operations.KeepBestN(1, True), scored))

    while len(frontier) > 1:
        next_frontier = []
        for position in range(0, len(frontier), 2):
            if position + 1 == len(frontier):
                # odd one out: carried over unchanged to the next level
                next_frontier.append(frontier[position])
                continue

            aggregated = wire(
                operations.Aggregate(5), frontier[position], frontier[position + 1]
            )
            aggregated_scored = wire(operations.Score(3, False), aggregated)
            aggregated_best = wire(operations.KeepBestN(1, True), aggregated_scored)

            regenerated = wire(operations.Generate(1, 5), aggregated_best)
            regenerated_scored = wire(operations.Score(3, False), regenerated)
            next_frontier.append(
                wire(
                    operations.KeepBestN(1, True),
                    regenerated_scored,
                    aggregated_best,
                )
            )
        frontier = next_frontier

    return graph
def run(
    data_ids: List[int],
    methods: List[Callable[[], operations.GraphOfOperations]],
    budget: float,
    lm_name: str,
) -> float:
    """
    Controller function that executes each specified method for each specified
    sample while the budget is not exhausted.

    Samples are read from "documents.csv" next to this file; the first row is
    skipped as header, column 0 is converted to int and used as sample id, and
    columns 2-5 are used as the four documents to merge. Results, the run
    configuration and the log are written to a new, timestamped folder below
    "results" next to this file.

    :param data_ids: Indices of the sample to be run. If None or empty, all samples are run.
    :type data_ids: List[int]
    :param methods: List of functions to generate Graphs of Operations.
    :type methods: Each function generates a Graph of Operation.
    :param budget: Language model budget for the execution in dollars.
    :type budget: float
    :param lm_name: Name of the language model to be used.
    :type lm_name: str
    :return: Spent budget in dollars.
    :rtype: float
    """
    orig_budget = budget
    base_dir = os.path.dirname(__file__)

    data = []
    with open(os.path.join(base_dir, "documents.csv"), "r", encoding="utf8") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            row[0] = int(row[0])
            data.append(row)

    if data_ids is None or len(data_ids) == 0:
        data_ids = list(range(len(data)))
    selected_data = [data[i] for i in data_ids]

    results_dir = os.path.join(base_dir, "results")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(results_dir, exist_ok=True)

    method_names = [method.__name__ for method in methods]
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    extra_info = f"{lm_name}_{'-'.join(method_names)}"
    folder_name = f"{extra_info}_{timestamp}"
    results_folder = os.path.join(results_dir, folder_name)
    os.makedirs(results_folder)

    config = {
        "data": selected_data,
        "methods": method_names,
        "lm": lm_name,
        "budget": budget,
    }
    with open(
        os.path.join(results_folder, "config.json"), "w", encoding="utf8"
    ) as f:
        json.dump(config, f)

    logging.basicConfig(
        filename=os.path.join(results_folder, "log.log"),
        filemode="w",
        format="%(name)s - %(levelname)s - %(message)s",
        level=logging.DEBUG,
    )

    for method in methods:
        os.makedirs(os.path.join(results_folder, method.__name__))

    # The loop variable is named `sample` so that it does not shadow `data`.
    for sample in selected_data:
        logging.info(f"Running data {sample[0]}: {sample[1]}")
        if budget <= 0.0:
            logging.error(
                f"Budget has been depleted, stopping. Data {sample[0]} has not been run."
            )
            break
        for method in methods:
            logging.info(f"Running method {method.__name__}")
            logging.info(f"Budget left: {budget}")
            if budget <= 0.0:
                logging.error(
                    f"Budget has been depleted, stopping. Method {method.__name__} has not been run."
                )
                break
            # A fresh language model per run, so lm.cost below is the cost of
            # this run only.
            lm = language_models.ChatGPT(
                os.path.join(
                    base_dir,
                    "../../graph_of_thoughts/language_models/config.json",
                ),
                model_name=lm_name,
                cache=True,
            )
            operations_graph = method()
            executor = controller.Controller(
                lm,
                operations_graph,
                DocMergePrompter(),
                DocMergeParser(),
                {
                    "documents": [sample[2], sample[3], sample[4], sample[5]],
                    "parts": set(),
                    "current": "",
                    "method": method.__name__,
                },
            )
            try:
                executor.run()
            except Exception as e:
                # Best effort: keep going and still write the output graph,
                # but record the traceback so that the failure can be
                # diagnosed.
                logging.error(f"Exception: {e}", exc_info=True)
            path = os.path.join(
                results_folder,
                method.__name__,
                f"{sample[0]}.json",
            )
            # "parts" is created as a set above; it is converted to a list
            # before the graph is written out.
            for operation in operations_graph.operations:
                for thought in operation.thoughts:
                    thought.state["parts"] = list(thought.state["parts"])
            executor.output_graph(path)
            budget -= lm.cost

    return orig_budget - budget
if __name__ == "__main__":
    # Input (x1, x2, x3, x4): Four NDAs
    # Output (y): A new combined NDA
    # Evaluation: According to information coverage without repetition (scored by the LLM)
    budget = 30
    samples = list(range(50))
    approaches = [io, cot, tot, got, got2]

    spent = run(samples, approaches, budget, "chatgpt")

    logging.info(f"Spent {spent} out of {budget} budget.")