100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
# Copyright (c) 2023 ETH Zurich.
#                    All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# The source code is adapted from the sorting source code written by
# Nils Blach.
#
# main author: Robert Gerstenberger
|
|
|
|
from typing import Dict, List, Set
|
|
|
|
|
|
def string_to_list(string: str) -> List[int]:
    """
    Helper function to convert a list encoded inside a string into a Python
    list object of integer elements.

    An empty list encoding (e.g. "[]") yields an empty list.

    :param string: Input string containing a list.
    :type string: str
    :return: List of integer elements.
    :rtype: List[int]
    :raise AssertionError: If input string does not contain a list.
    :raise ValueError: If an element cannot be parsed as an integer.
    """

    # Raise explicitly instead of using an assert statement, so the check is
    # still performed when Python runs with -O. startswith/endswith also make
    # the empty string fail this check instead of raising an IndexError.
    if not (string.startswith("[") and string.endswith("]")):
        raise AssertionError("String is not a list.")

    content = string[1:-1]
    # "".split(",") yields [""], which int() rejects, so the empty list has
    # to be handled separately.
    if not content.strip():
        return []
    return [int(num) for num in content.split(",")]
|
|
|
|
|
|
def string_to_set(string: str) -> Set[int]:
    """
    Helper function to convert a list encoded inside a string into a Python
    set object of integer elements.

    An empty list encoding (e.g. "[]") yields an empty set.

    :param string: Input string containing a list.
    :type string: str
    :return: Set of integer elements.
    :rtype: Set[int]
    :raise AssertionError: If input string does not contain a list.
    :raise ValueError: If an element cannot be parsed as an integer.
    """

    # Raise explicitly instead of using an assert statement, so the check is
    # still performed when Python runs with -O. startswith/endswith also make
    # the empty string fail this check instead of raising an IndexError.
    if not (string.startswith("[") and string.endswith("]")):
        raise AssertionError("String is not a list.")

    content = string[1:-1]
    # "".split(",") yields [""], which int() rejects, so the empty set has
    # to be handled separately.
    if not content.strip():
        return set()
    return {int(num) for num in content.split(",")}
|
|
|
|
|
|
def test_set_intersection(state: Dict) -> bool:
    """
    Function to test whether the final solution matches ground truth.

    The solution stored under "current" is sorted before it is compared to
    the ground truth stored under "result"; the ground truth itself is
    compared as is.

    :param state: Thought state that represents the final solution.
    :type state: Dict
    :return: Returns whether the solution matches the ground truth. False is
             also returned if one of the entries is missing or cannot be
             converted into a list of integers.
    :rtype: bool
    """

    # convert string to list
    try:
        correct_list = string_to_list(state["result"])
        sorted_list = sorted(string_to_list(state["current"]))
    except Exception:
        # Malformed or missing entries count as a wrong solution. Exception
        # (rather than a bare except) lets KeyboardInterrupt and SystemExit
        # propagate.
        return False
    return sorted_list == correct_list
|
|
|
|
|
|
def num_errors(state: Dict) -> float:
    """
    Function to locally count the number of errors that serves as a score.

    The expected result is the intersection of the set stored under "set1"
    with the set stored under "subset" if that entry is present and neither
    None nor the empty string, otherwise with the set stored under "set2".
    Every expected element missing from the solution stored under "current"
    and every element of the solution that is not expected (including
    repeated elements) counts as one error.

    :param state: Thought state to be scored.
    :type state: Dict
    :return: Number of errors. 1000 is returned if one of the entries is
             missing or cannot be parsed.
    :rtype: float
    """

    try:
        set1 = string_to_set(state["set1"])
        # "set2" is always parsed, even if "subset" replaces it below, so a
        # malformed "set2" entry still results in the fallback score.
        set2 = string_to_set(state["set2"])
        subset = state.get("subset")
        if subset is not None and subset != "":
            set2 = string_to_set(subset)
        common = sorted(set1 & set2)
        llm_solution = sorted(string_to_list(state["current"]))

        # Walk both sorted sequences in lockstep; an element present in only
        # one of them is an error.
        errors = 0
        common_idx = 0
        llm_idx = 0
        while common_idx < len(common) and llm_idx < len(llm_solution):
            expected = common[common_idx]
            proposed = llm_solution[llm_idx]
            if expected == proposed:
                common_idx += 1
                llm_idx += 1
            elif expected < proposed:
                # expected element is missing from the solution
                common_idx += 1
                errors += 1
            else:
                # solution contains an element that is not expected
                llm_idx += 1
                errors += 1

        # Whatever remains in either sequence has no counterpart.
        errors += len(common) - common_idx + len(llm_solution) - llm_idx
        return errors
    except Exception:
        # Unparsable states receive a fixed score of 1000. Exception (rather
        # than a bare except) lets KeyboardInterrupt and SystemExit propagate.
        return 1000
|