2023-08-21 03:33:46 +02:00

79 lines
2.3 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Nils Blach
from typing import Dict, List
def string_to_list(string: str) -> List[int]:
    """
    Helper function to convert a list encoded inside a string into a Python
    list object of integer elements.

    :param string: Input string containing a list, e.g. "[1, 2, 3]".
    :type string: str
    :return: List of integer elements; an empty list for "[]".
    :rtype: List[int]
    :raise AssertionError: If input string does not contain a list.
    :raise ValueError: If an element cannot be converted to an integer.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if len(string) < 2 or string[0] != "[" or string[-1] != "]":
        raise AssertionError("String is not a list.")
    inner = string[1:-1]
    # "[]" (or "[ ]") encodes an empty list; splitting it would yield [""],
    # which int() rejects.
    if not inner.strip():
        return []
    return [int(num) for num in inner.split(",")]
def test_sorting(state: Dict) -> bool:
    """
    Function to test whether the final solution matches ground truth.

    The ground truth is the sorted list of integers parsed from
    ``state["original"]``; the candidate solution is parsed from
    ``state["current"]``.

    :param state: Thought state that represents the final solution.
    :type state: Dict
    :return: Returns whether the solution matches the ground truth; False if
        either entry is missing or cannot be parsed as a list.
    :rtype: bool
    """
    try:
        correct_list = sorted(string_to_list(state["original"]))
        sorted_list = string_to_list(state["current"])
    except Exception:
        # Missing or malformed solutions count as incorrect. Catching
        # ``Exception`` instead of using a bare ``except`` lets
        # KeyboardInterrupt/SystemExit propagate.
        return False
    return sorted_list == correct_list
def num_errors(state: Dict) -> float:
    """
    Function to locally count the number of errors that serves as a score.

    The score compares ``state["current"]`` against the sorted reference
    list. It adds two terms:

    - for every value 0-9, the absolute difference between how often the
      value occurs in the current list and in the reference;
    - one error for every adjacent pair in the current list that is out of
      order.

    The reference is ``state["original"]``. It is replaced by
    ``state["unsorted_sublist"]`` when that entry is non-empty and its length
    is more than 5 less than the original's length.

    :param state: Thought state to be scored.
    :type state: Dict
    :return: Number of errors, or 300 if the state cannot be scored (missing
        keys or lists that cannot be parsed).
    :rtype: float
    """
    try:
        unsorted_list = state["original"]
        sublist = state.get("unsorted_sublist")
        # NOTE(review): the sublist is only used when it is noticeably
        # shorter than the original; the 5-unit margin looks heuristic —
        # confirm intent.
        if sublist and len(sublist) < len(unsorted_list) - 5:
            unsorted_list = sublist
        correct_list = sorted(string_to_list(unsorted_list))
        current_list = string_to_list(state["current"])
    except Exception:
        # Unscorable states receive a large fixed penalty. Catching
        # ``Exception`` instead of using a bare ``except`` lets
        # KeyboardInterrupt/SystemExit propagate.
        return 300

    # Frequency mismatches; only the values 0-9 are compared.
    error_count = sum(
        abs(current_list.count(value) - correct_list.count(value))
        for value in range(10)
    )
    # Each adjacent descending pair counts as one ordering error.
    error_count += sum(
        1 for left, right in zip(current_list, current_list[1:]) if left > right
    )
    return error_count