2023-08-21 03:33:46 +02:00

100 lines
3.1 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# The source code is adapted from the sorting source code written by
# Nils Blach.
#
# main author: Robert Gerstenberger
from typing import Dict, List, Set
def string_to_list(string: str) -> List[int]:
    """
    Helper function to convert a list encoded inside a string into a Python
    list object of integer elements.

    An empty encoded list (``"[]"``, optionally with only whitespace between
    the brackets) is converted into an empty Python list.

    :param string: Input string containing a list.
    :type string: str
    :return: List of integer elements.
    :rtype: List[int]
    :raise AssertionError: If input string does not contain a list.
    :raise ValueError: If an element of the list is not an integer.
    """
    # Raise explicitly instead of using an assert statement, so the check is
    # not stripped when Python runs with -O. Comparing slices (rather than
    # indexing) also covers the empty string, which would otherwise raise
    # IndexError instead of the documented AssertionError.
    if string[:1] != "[" or string[-1:] != "]":
        raise AssertionError("String is not a list.")

    content = string[1:-1]
    if not content.strip():
        # "".split(",") yields [""], which int() cannot parse, so an empty
        # list has to be handled separately.
        return []
    return [int(num) for num in content.split(",")]
def string_to_set(string: str) -> Set[int]:
    """
    Helper function to convert a list encoded inside a string into a Python
    set object of integer elements.

    An empty encoded list (``"[]"``, optionally with only whitespace between
    the brackets) is converted into an empty Python set. Duplicate elements
    of the encoded list collapse into a single set element.

    :param string: Input string containing a list.
    :type string: str
    :return: Set of integer elements.
    :rtype: Set[int]
    :raise AssertionError: If input string does not contain a list.
    :raise ValueError: If an element of the list is not an integer.
    """
    # Raise explicitly instead of using an assert statement, so the check is
    # not stripped when Python runs with -O. Comparing slices (rather than
    # indexing) also covers the empty string, which would otherwise raise
    # IndexError instead of the documented AssertionError.
    if string[:1] != "[" or string[-1:] != "]":
        raise AssertionError("String is not a list.")

    content = string[1:-1]
    if not content.strip():
        # "".split(",") yields [""], which int() cannot parse, so an empty
        # list has to be handled separately.
        return set()
    return {int(num) for num in content.split(",")}
def test_set_intersection(state: Dict) -> bool:
    """
    Function to test whether the final solution matches ground truth.

    The solution stored under ``state["current"]`` is parsed and sorted, then
    compared element-wise against the parsed ``state["result"]``, which is
    used in the order given.

    :param state: Thought state that represents the final solution.
    :type state: Dict
    :return: Returns whether the solution matches the ground truth. False is
             also returned if either entry is missing or cannot be parsed.
    :rtype: bool
    """
    # Only the parsing can fail; catch Exception rather than using a bare
    # except, so that KeyboardInterrupt and SystemExit still propagate.
    try:
        correct_list = string_to_list(state["result"])
        sorted_list = sorted(string_to_list(state["current"]))
    except Exception:
        return False
    return sorted_list == correct_list
def num_errors(state: Dict) -> float:
    """
    Function to locally count the number of errors that serves as a score.

    The expected intersection is computed from ``state["set1"]`` and
    ``state["set2"]``; if ``state["subset"]`` is present and neither None nor
    the empty string, it is used in place of ``state["set2"]``. Every element
    of the expected intersection that is missing from ``state["current"]`` and
    every element of ``state["current"]`` that is not matched by the expected
    intersection (including surplus duplicates) counts as one error.

    :param state: Thought state to be scored.
    :type state: Dict
    :return: Number of errors, or 1000 if the state cannot be evaluated.
    :rtype: float
    """
    # Catch Exception rather than using a bare except, so that
    # KeyboardInterrupt and SystemExit still propagate.
    try:
        set1 = string_to_set(state["set1"])
        set2 = string_to_set(state["set2"])
        subset = state.get("subset")
        if subset is not None and subset != "":
            set2 = string_to_set(subset)

        common = sorted(set1 & set2)
        llm_solution = sorted(string_to_list(state["current"]))

        # Walk both sorted sequences in lockstep; whenever the heads differ,
        # the smaller one has no counterpart in the other sequence.
        error_count = 0
        common_idx = 0
        llm_idx = 0
        while common_idx < len(common) and llm_idx < len(llm_solution):
            expected = common[common_idx]
            actual = llm_solution[llm_idx]
            if expected == actual:
                common_idx += 1
                llm_idx += 1
            elif expected < actual:
                # Expected element is missing from the solution.
                common_idx += 1
                error_count += 1
            else:
                # Solution contains an element that is not expected.
                llm_idx += 1
                error_count += 1

        # Whatever is left over in either sequence is unmatched.
        error_count += (len(common) - common_idx) + (len(llm_solution) - llm_idx)
        return error_count
    except Exception:
        return 1000