93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
# Copyright (c) 2023 ETH Zurich.
# All rights reserved.
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# main author: Robert Gerstenberger
|
|
|
|
import csv
|
|
import numpy as np
|
|
|
|
|
|
def scramble(array: np.ndarray, rng: np.random.Generator) -> None:
    """
    Helper function to change the order of the elements in an array randomly.

    The array is modified in place: for every position i along the first axis,
    the entry at i is swapped with the entry at a randomly drawn position.
    Nothing is returned.

    :param array: Array to be scrambled.
    :type array: numpy.ndarray
    :param rng: Random number generator.
    :type rng: numpy.random.Generator
    """

    size = array.shape[0]

    # One swap partner per position, drawn up front in a single call.
    index_array = rng.integers(0, size, size)

    for i, j in enumerate(index_array):
        # Fancy indexing copies the right-hand side before assigning, so the
        # swap stays correct even when array[i] would be a view (ndim > 1).
        array[[i, j]] = array[[j, i]]
|
|
|
|
|
|
if __name__ == "__main__":
    """
    Input(u)  : Set size.
    Input(v)  : Range of the integer numbers in the sets: 0..v (exclusive)
    Input(w)  : Seed for the random number generator.
    Input(x)  : Number of samples to be generated.
    Input(y)  : Filename for the output CSV file.
    Output(z) : Input sets and intersected set written to a file in the CSV format.
                File contains the sample ID, input set 1, input set 2,
                intersection set.
    """

    set_size = 32  # size of the generated sets
    int_value_ubound = 64  # (exclusive) upper limit of generated numbers
    seed = 42  # seed of the random number generator
    num_sample = 100  # number of samples
    filename = "set_intersection_032.csv"  # output filename

    # Both sets are cut from one pool of distinct values, using slices that
    # reach up to index 2 * set_size, so the pool must be at least that large.
    # Raised explicitly (instead of assert) so the check survives `python -O`.
    if 2 * set_size > int_value_ubound:
        raise ValueError(
            f"int_value_ubound ({int_value_ubound}) must be at least "
            f"2 * set_size ({2 * set_size})"
        )

    rng = np.random.default_rng(seed)

    # Size of the intersection for every sample, drawn from
    # [set_size // 4, 3 * set_size // 4).
    intersection_sizes = rng.integers(set_size // 4, 3 * set_size // 4, num_sample)

    np.set_printoptions(
        linewidth=np.inf
    )  # no wrapping in the array fields in the output file

    # newline="" is required by the csv module; without it the writer's "\r\n"
    # row terminator is translated again on platforms such as Windows.
    with open(filename, "w", newline="") as f:
        fieldnames = ["ID", "SET1", "SET2", "INTERSECTION"]
        writer = csv.DictWriter(f, delimiter=",", fieldnames=fieldnames)
        writer.writeheader()

        for i, intersection_size in enumerate(intersection_sizes):
            # Fresh pool of distinct values 0..int_value_ubound-1, randomly
            # reordered in place.
            full_set = np.arange(0, int_value_ubound, dtype=np.int16)
            scramble(full_set, rng)

            # The first intersection_size values are shared by both sets.
            intersection = full_set[:intersection_size].copy()
            sorted_intersection = np.sort(intersection)

            # set1: the shared values plus the rest of the first set_size values.
            set1 = full_set[:set_size].copy()
            # set2: the shared values plus values taken from beyond set1's
            # slice, filling set2 up to set_size elements.
            set2 = np.concatenate(
                [intersection, full_set[set_size : 2 * set_size - intersection_size]]
            )

            # Reorder both sets so the shared values are not grouped in front.
            scramble(set1, rng)
            scramble(set2, rng)

            writer.writerow(
                {
                    "ID": i,
                    "SET1": set1.tolist(),
                    "SET2": set2.tolist(),
                    "INTERSECTION": sorted_intersection.tolist(),
                }
            )
|