graph-of-thoughts/examples/set_intersection/dataset_gen_intersection.py
2023-08-21 03:33:46 +02:00

93 lines
2.8 KiB
Python

# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Robert Gerstenberger
import csv
import numpy as np
def scramble(array: np.ndarray, rng: np.random.Generator) -> None:
    """
    Helper function to change the order of the elements in an array randomly.

    The array is permuted in place along its first axis. All partner indices
    are drawn up front with a single ``rng.integers(0, size, size)`` call;
    then, for every position ``i``, the entries at ``i`` and at its partner
    index are swapped.

    Note: this "swap with any position" scheme does not produce a uniformly
    distributed permutation (unlike Fisher-Yates / ``rng.shuffle``). It is
    kept as is because changing it would change the datasets generated for
    a given seed.

    :param array: Array to be scrambled; it is modified in place.
    :type array: numpy.ndarray
    :param rng: Random number generator.
    :type rng: numpy.random.Generator
    """
    size = array.shape[0]
    if size == 0:
        # Nothing to permute.
        return
    index_array = rng.integers(0, size, size)
    for i, j in enumerate(index_array):
        # Fancy indexing on the right-hand side yields a copy before the
        # assignment, so the swap stays correct even when array[i] would be
        # a row view (ndim > 1); a temp-variable swap would alias that view
        # and duplicate one row while losing the other.
        array[[i, j]] = array[[j, i]]
if __name__ == "__main__":
    """
    Input(u) : Set size.
    Input(v) : Range of the integer numbers in the sets: 0..v (exclusive)
    Input(w) : Seed for the random number generator.
    Input(x) : Number of samples to be generated.
    Input(y) : Filename for the output CSV file.

    Output(z) : Input sets and intersected set written to a file in the CSV
                format. The file contains the sample ID, input set 1,
                input set 2 and the intersection set.
    """

    set_size = 32  # size of the generated sets
    int_value_ubound = 64  # (exclusive) upper limit of generated numbers
    seed = 42  # seed of the random number generator
    num_sample = 100  # number of samples
    filename = "set_intersection_032.csv"  # output filename

    # set1 and set2 are cut from the first 2 * set_size entries of a permuted
    # 0..int_value_ubound-1 range (see the slicing below), so that range must
    # hold at least 2 * set_size values. Raised explicitly instead of via
    # `assert`, which is stripped when Python runs with -O.
    if 2 * set_size > int_value_ubound:
        raise ValueError(
            f"int_value_ubound ({int_value_ubound}) must be at least "
            f"2 * set_size ({2 * set_size})"
        )

    rng = np.random.default_rng(seed)

    # One intersection size per sample, drawn from
    # [set_size // 4, 3 * set_size // 4) (upper bound exclusive).
    intersection_sizes = rng.integers(set_size // 4, 3 * set_size // 4, num_sample)

    # newline="" is required by the csv module so it controls line endings
    # itself; otherwise Windows output gets an extra blank line per row.
    with open(filename, "w", newline="") as f:
        fieldnames = ["ID", "SET1", "SET2", "INTERSECTION"]
        writer = csv.DictWriter(f, delimiter=",", fieldnames=fieldnames)
        writer.writeheader()

        for i, intersection_size in enumerate(intersection_sizes):
            # Fresh random ordering of all candidate values for this sample.
            full_set = np.arange(0, int_value_ubound, dtype=np.int16)
            scramble(full_set, rng)

            # Layout of the permuted values:
            #   [0, intersection_size)                          -> in both sets
            #   [intersection_size, set_size)                   -> only in set1
            #   [set_size, 2 * set_size - intersection_size)    -> only in set2
            intersection = full_set[:intersection_size].copy()
            sorted_intersection = np.sort(intersection)
            set1 = full_set[:set_size].copy()
            set2 = np.concatenate(
                [intersection, full_set[set_size : 2 * set_size - intersection_size]]
            )

            # Reorder each set so the shared elements are not at the front.
            scramble(set1, rng)
            scramble(set2, rng)

            # .tolist() yields plain Python lists, so each field is written
            # as str(list), e.g. "[3, 17, 42]".
            writer.writerow(
                {
                    "ID": i,
                    "SET1": set1.tolist(),
                    "SET2": set2.tolist(),
                    "INTERSECTION": sorted_intersection.tolist(),
                }
            )