# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach

import json
import logging
from collections import deque
from typing import List

from graph_of_thoughts.language_models import AbstractLanguageModel
from graph_of_thoughts.operations import GraphOfOperations, Thought
from graph_of_thoughts.prompter import Prompter
from graph_of_thoughts.parser import Parser


class Controller:
    """
    Controller class to manage the execution flow of the Graph of Operations,
    generating the Graph Reasoning State.

    This involves language models, graph operations, prompting, and parsing.
    """

    def __init__(
        self,
        lm: AbstractLanguageModel,
        graph: GraphOfOperations,
        prompter: Prompter,
        parser: Parser,
        problem_parameters: dict,
    ) -> None:
        """
        Initialize the Controller instance with the language model,
        operations graph, prompter, parser, and problem parameters.

        :param lm: An instance of the AbstractLanguageModel.
        :type lm: AbstractLanguageModel
        :param graph: The Graph of Operations to be executed.
        :type graph: GraphOfOperations
        :param prompter: An instance of the Prompter class, used to generate prompts.
        :type prompter: Prompter
        :param parser: An instance of the Parser class, used to parse responses.
        :type parser: Parser
        :param problem_parameters: Initial parameters/state of the problem. They
            are forwarded as keyword arguments to every operation's ``execute``.
        :type problem_parameters: dict
        """
        self.logger = logging.getLogger(self.__class__.__module__)
        self.lm = lm
        self.graph = graph
        self.prompter = prompter
        self.parser = parser
        self.problem_parameters = problem_parameters
        # Flipped to True at the end of run(); get_final_thoughts() checks it.
        self.run_executed = False

    def run(self) -> None:
        """
        Run the controller and execute the operations from the Graph of
        Operations based on their readiness.

        Every operation whose ``can_be_executed()`` is true up front is queued.
        Operations are then executed in FIFO order, and after each execution
        its successors are queued as soon as they report they can be executed.

        :raises AssertionError: If the Graph of Operations has no roots
            (``roots`` is ``None`` or empty).
        :raises AssertionError: If the successor of an operation is not in the
            Graph of Operations.
        """
        self.logger.debug("Checking that the program is in a valid state")
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        # A falsy check also rejects an empty root list, not just ``None``.
        if not self.graph.roots:
            raise AssertionError("The operations graph has no root")
        self.logger.debug("The program is in a valid state")

        # deque: O(1) pops from the front, unlike list.pop(0) which is O(n).
        execution_queue = deque(
            operation
            for operation in self.graph.operations
            if operation.can_be_executed()
        )

        while execution_queue:
            current_operation = execution_queue.popleft()
            self.logger.info("Executing operation %s", current_operation.operation_type)
            current_operation.execute(
                self.lm, self.prompter, self.parser, **self.problem_parameters
            )
            self.logger.info("Operation %s executed", current_operation.operation_type)
            for operation in current_operation.successors:
                if operation not in self.graph.operations:
                    raise AssertionError(
                        "The successor of an operation is not in the operations graph"
                    )
                # A successor with several predecessors becomes ready only
                # after the last of them has executed.
                if operation.can_be_executed():
                    execution_queue.append(operation)

        self.logger.info("All operations executed")
        self.run_executed = True

    def get_final_thoughts(self) -> List[List[Thought]]:
        """
        Retrieve the final thoughts after all operations have been executed.

        :return: List of thoughts for each operation in the graph's leaves.
        :rtype: List[List[Thought]]
        :raises AssertionError: If the `run` method hasn't been executed yet.
        """
        # Explicit raise so the guard is not stripped under ``python -O``.
        if not self.run_executed:
            raise AssertionError("The run method has not been executed")
        return [operation.get_thoughts() for operation in self.graph.leaves]

    def output_graph(self, path: str) -> None:
        """
        Serialize the state and results of the operations graph to a JSON file.

        One entry is written per operation, holding its type name and thought
        states. Score, validation and ground-truth fields are added when at
        least one of its thoughts carries them. A final entry records the
        language model's prompt/completion token counts and cost.

        :param path: The path to the output file.
        :type path: str
        """
        output = []
        for operation in self.graph.operations:
            # Fetch the thoughts once rather than once per serialized field.
            thoughts = operation.get_thoughts()
            operation_serialized = {
                "operation": operation.operation_type.name,
                "thoughts": [thought.state for thought in thoughts],
            }
            if any(thought.scored for thought in thoughts):
                operation_serialized["scored"] = [thought.scored for thought in thoughts]
                operation_serialized["scores"] = [thought.score for thought in thoughts]
            if any(thought.validated for thought in thoughts):
                operation_serialized["validated"] = [
                    thought.validated for thought in thoughts
                ]
                operation_serialized["validity"] = [thought.valid for thought in thoughts]
            if any(thought.compared_to_ground_truth for thought in thoughts):
                operation_serialized["compared_to_ground_truth"] = [
                    thought.compared_to_ground_truth for thought in thoughts
                ]
                operation_serialized["problem_solved"] = [
                    thought.solved for thought in thoughts
                ]
            output.append(operation_serialized)

        output.append(
            {
                "prompt_tokens": self.lm.prompt_tokens,
                "completion_tokens": self.lm.completion_tokens,
                "cost": self.lm.cost,
            }
        )

        # Serialize before opening: a serialization error then leaves any
        # existing file untouched instead of truncated.
        serialized = json.dumps(output, indent=2)
        with open(path, "w", encoding="utf-8") as file:
            file.write(serialized)