79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach
|
|
|
|
from collections import Counter
from typing import Dict, List
|
|
|
|
|
|
def string_to_list(string: str) -> List[int]:
    """
    Helper function to convert a list encoded inside a string into a Python
    list object of integer elements.

    :param string: Input string containing a list, e.g. "[3, 1, 2]".
    :type string: str
    :return: List of integer elements; an empty list for "[]".
    :rtype: List[int]
    :raise AssertionError: If input string does not contain a list.
    :raise ValueError: If an element cannot be parsed as an integer.
    """
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped under ``python -O``, while keeping the AssertionError contract.
    # ``startswith``/``endswith`` also make an empty string fail this check
    # rather than raising IndexError.
    if not (string.startswith("[") and string.endswith("]")):
        raise AssertionError("String is not a list.")
    inner = string[1:-1]
    # An empty list would otherwise split into [""], which int() rejects.
    if not inner.strip():
        return []
    return [int(num) for num in inner.split(",")]
|
|
|
|
|
|
def test_sorting(state: Dict) -> bool:
    """
    Function to test whether the final solution matches ground truth.

    The ground truth is the list parsed from ``state["original"]`` sorted in
    ascending order; the solution is the list parsed from ``state["current"]``.

    :param state: Thought state that represents the final solution.
    :type state: Dict
    :return: True if the solution equals the sorted original list; False if it
        differs, or if either entry is missing or cannot be parsed.
    :rtype: bool
    """
    try:
        correct_list = sorted(string_to_list(state["original"]))
        current_list = string_to_list(state["current"])
    except Exception:
        # A missing or malformed state counts as a wrong solution. Catching
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return current_list == correct_list
|
|
|
|
|
|
def num_errors(state: Dict) -> float:
    """
    Function to locally count the number of errors that serves as a score.

    The score is the sum of two parts:

    * for every value, the absolute difference between how often it occurs in
      the current list and in the correctly sorted reference list (i.e. missing
      or extra elements), and
    * the number of adjacent pairs in the current list that are out of order.

    :param state: Thought state to be scored. Reads the keys "original",
        "current" and, if present, "unsorted_sublist".
    :type state: Dict
    :return: Number of errors, or 300 if the state cannot be scored (e.g. a
        missing key or a string that does not parse as a list).
    :rtype: float
    """
    try:
        unsorted_list = state["original"]
        sublist = state.get("unsorted_sublist")
        # Score against the sublist only when it is set (not None / "") and its
        # len() is more than 5 below that of the original.
        # NOTE(review): presumably used when a thought handles only part of the
        # input - confirm the margin of 5 against the callers.
        if sublist and len(sublist) < len(unsorted_list) - 5:
            unsorted_list = sublist
        correct_list = sorted(string_to_list(unsorted_list))
        current_list = string_to_list(state["current"])

        # Multiset difference over all values (not only 0-9). Counter
        # subtraction drops non-positive counts, so both directions are summed
        # to obtain the total absolute count difference.
        current_counts = Counter(current_list)
        correct_counts = Counter(correct_list)
        errors = sum((current_counts - correct_counts).values()) + sum(
            (correct_counts - current_counts).values()
        )
        # Ordering errors: adjacent pairs that are in descending order.
        errors += sum(
            1 for prev, nxt in zip(current_list, current_list[1:]) if prev > nxt
        )
        return errors
    except Exception:
        # Any failure to read or parse the state gets a fixed large penalty.
        return 300
|