2023-08-21 03:33:46 +02:00

153 lines
6.0 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach
import json
import logging
from typing import List
from .abstract_language_model import AbstractLanguageModel
from graph_of_thoughts.operations import GraphOfOperations, Thought
from graph_of_thoughts.prompter import Prompter
from graph_of_thoughts.parser import Parser
class Controller:
"""
Controller class to manage the execution flow of the Graph of Operations,
generating the Graph Reasoning State.
This involves language models, graph operations, prompting, and parsing.
"""
def __init__(
self,
lm: AbstractLanguageModel,
graph: GraphOfOperations,
prompter: Prompter,
parser: Parser,
problem_parameters: dict,
) -> None:
"""
Initialize the Controller instance with the language model,
operations graph, prompter, parser, and problem parameters.
:param lm: An instance of the AbstractLanguageModel.
:type lm: AbstractLanguageModel
:param graph: The Graph of Operations to be executed.
:type graph: OperationsGraph
:param prompter: An instance of the Prompter class, used to generate prompts.
:type prompter: Prompter
:param parser: An instance of the Parser class, used to parse responses.
:type parser: Parser
:param problem_parameters: Initial parameters/state of the problem.
:type problem_parameters: dict
"""
self.logger = logging.getLogger(self.__class__.__module__)
self.lm = lm
self.graph = graph
self.prompter = prompter
self.parser = parser
self.problem_parameters = problem_parameters
self.run_executed = False
def run(self) -> None:
"""
Run the controller and execute the operations from the Graph of
Operations based on their readiness.
Ensures the program is in a valid state before execution.
:raises AssertionError: If the Graph of Operation has no roots.
:raises AssertionError: If the successor of an operation is not in the Graph of Operations.
"""
self.logger.debug("Checking that the program is in a valid state")
assert self.graph.roots is not None, "The operations graph has no root"
self.logger.debug("The program is in a valid state")
execution_queue = [
operation
for operation in self.graph.operations
if operation.can_be_executed()
]
while len(execution_queue) > 0:
current_operation = execution_queue.pop(0)
self.logger.info("Executing operation %s", current_operation.operation_type)
current_operation.execute(
self.lm, self.prompter, self.parser, **self.problem_parameters
)
self.logger.info("Operation %s executed", current_operation.operation_type)
for operation in current_operation.successors:
assert (
operation in self.graph.operations
), "The successor of an operation is not in the operations graph"
if operation.can_be_executed():
execution_queue.append(operation)
self.logger.info("All operations executed")
self.run_executed = True
def get_final_thoughts(self) -> List[List[Thought]]:
"""
Retrieve the final thoughts after all operations have been executed.
:return: List of thoughts for each operation in the graph's leaves.
:rtype: List[List[Thought]]
:raises AssertionError: If the `run` method hasn't been executed yet.
"""
assert self.run_executed, "The run method has not been executed"
return [operation.get_thoughts() for operation in self.graph.leaves]
def output_graph(self, path: str) -> None:
"""
Serialize the state and results of the operations graph to a JSON file.
:param path: The path to the output file.
:type path: str
"""
output = []
for operation in self.graph.operations:
operation_serialized = {
"operation": operation.operation_type.name,
"thoughts": [thought.state for thought in operation.get_thoughts()],
}
if any([thought.scored for thought in operation.get_thoughts()]):
operation_serialized["scored"] = [
thought.scored for thought in operation.get_thoughts()
]
operation_serialized["scores"] = [
thought.score for thought in operation.get_thoughts()
]
if any([thought.validated for thought in operation.get_thoughts()]):
operation_serialized["validated"] = [
thought.validated for thought in operation.get_thoughts()
]
operation_serialized["validity"] = [
thought.valid for thought in operation.get_thoughts()
]
if any(
[
thought.compared_to_ground_truth
for thought in operation.get_thoughts()
]
):
operation_serialized["compared_to_ground_truth"] = [
thought.compared_to_ground_truth
for thought in operation.get_thoughts()
]
operation_serialized["problem_solved"] = [
thought.solved for thought in operation.get_thoughts()
]
output.append(operation_serialized)
output.append(
{
"prompt_tokens": self.lm.prompt_tokens,
"completion_tokens": self.lm.completion_tokens,
"cost": self.lm.cost,
}
)
with open(path, "w") as file:
file.write(json.dumps(output, indent=2))