Spaces:
Running
Running
import matplotlib.pyplot as plt | |
import numpy as np | |
import matplotlib.cm as cm | |
from utils.util_data import load_dataset | |
def get_cmap(num_colors):
    """Return a qualitative colormap large enough for ``num_colors`` colors.

    ``tab10`` is used for up to 10 colors and ``tab20`` for up to 20.

    Args:
        num_colors: number of distinct colors the caller needs.

    Returns:
        A Matplotlib colormap; callers index it as ``cmap(i)``.

    Raises:
        ValueError: if ``num_colors`` exceeds 20, the largest supported size.
    """
    if num_colors <= 10:
        cm_name = "tab10"
    elif num_colors <= 20:
        cm_name = "tab20"
    else:
        # Raise explicitly rather than `assert False`: asserts are stripped
        # under `python -O`, which would leave `cm_name` unbound.
        raise ValueError(
            f"get_cmap supports at most 20 colors, got {num_colors}"
        )
    # `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9;
    # `pyplot.get_cmap` is the supported lookup by name.
    return plt.get_cmap(cm_name)
def _collect_steps_per_label(dataset, num_labels):
    """Group the step of every (step, label) pair in ``dataset`` by label.

    Args:
        dataset: sequence of instances; the last element of each instance is
            iterated as (step, label) pairs.
        num_labels: number of label classes; labels must lie in
            ``[0, num_labels)``.

    Returns:
        A list of ``num_labels`` lists; entry ``i`` holds every step at which
        label ``i`` occurs, across all instances.

    Raises:
        ValueError: if a label falls outside ``[0, num_labels)``.
    """
    steps_per_label = [[] for _ in range(num_labels)]
    for instance in dataset:
        for step, label in instance[-1]:
            # Check the range explicitly: a negative label would otherwise be
            # silently counted under the last class via negative indexing.
            if not 0 <= label < num_labels:
                raise ValueError(f"label {label} is outside [0, {num_labels})")
            steps_per_label[label].append(step)
    return steps_per_label


def _plot_step_histogram(steps_per_label, num_instances, max_steps,
                         label_names, path):
    """Save an overlaid per-label histogram of steps to ``path``.

    Each occurrence is weighted by ``1 / num_instances``, so a bar's height is
    the average number of occurrences of that label at that step per instance.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        # Unit-width bins; with align="left" each bar sits on its step value.
        bins = np.arange(0, max_steps + 1)
        cmap = get_cmap(len(steps_per_label))
        for i, steps in enumerate(steps_per_label):
            ax.hist(
                steps,
                bins=bins,
                alpha=0.5,
                weights=np.ones(len(steps)) / num_instances,
                ec=cmap(i),
                color=cmap(i),
                label=label_names[i],
                align="left",
            )
        ax.set_xlabel("Steps")
        ax.set_ylabel("Frequency (density)")
        if max_steps <= 20:
            # Tick every step only while that stays legible.
            ax.set_xticks(np.arange(0, max_steps + 1))
        ax.set_title(f"# of samples = {num_instances}\n# of nodes = {max_steps}")
        ax.legend()
        fig.savefig(path, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


def _write_label_ratios(steps_per_label, path):
    """Write each label's share of all (step, label) pairs to ``path``.

    One ``label<i>, <ratio>`` line is written per label. If there are no
    pairs at all, the ratios are NaN (0 / 0).
    """
    counts = np.array([len(steps) for steps in steps_per_label])
    ratios = counts / counts.sum()
    with open(path, "w", encoding="utf-8") as f:
        for i, ratio in enumerate(ratios):
            print(f"label{i}, {ratio}", file=f)


def analyze_dataset(dataset_path, output_dir, num_labels=2, label_names=None):
    """Analyze at which steps each label occurs and save the results.

    Writes two files into the existing directory ``output_dir``:

    * ``hist.png``: overlaid per-label histogram of the steps at which each
      label occurs, weighted so bar heights are averages per instance.
    * ``ratio.dat``: each label's share of all (step, label) pairs.

    Args:
        dataset_path: path handed to ``load_dataset``.
        output_dir: existing directory the outputs are written into.
        num_labels: number of label classes (labels are ``0..num_labels-1``).
        label_names: optional legend name per label; defaults to
            ``label0``, ``label1``, ... to match the names in ``ratio.dat``.

    Raises:
        ValueError: if ``label_names`` has the wrong length, the dataset is
            empty, or a label is outside ``[0, num_labels)``.
    """
    if label_names is None:
        label_names = [f"label{i}" for i in range(num_labels)]
    elif len(label_names) != num_labels:
        raise ValueError(
            f"expected {num_labels} label names, got {len(label_names)}"
        )

    dataset = load_dataset(dataset_path)
    if len(dataset) == 0:
        raise ValueError(f"dataset at {dataset_path!r} is empty")

    # Length of the first instance's first element bounds the step axis
    # (presumably the number of nodes; assumed equal across instances).
    max_steps = len(dataset[0][0])  # num_nodes

    steps_per_label = _collect_steps_per_label(dataset, num_labels)
    _plot_step_histogram(
        steps_per_label,
        len(dataset),
        max_steps,
        label_names,
        f"{output_dir}/hist.png",
    )
    _write_label_ratios(steps_per_label, f"{output_dir}/ratio.dat")
def resolve_output_dir(dataset_path, output_dir=None):
    """Return the directory the analysis artifacts are written into.

    The result is ``<base>/analysis``, where ``<base>`` is ``output_dir`` when
    given and otherwise the directory containing ``dataset_path``.
    """
    import os

    if output_dir is None:
        output_dir = os.path.dirname(dataset_path)
    # os.path.join maps a bare filename's empty dirname to a relative
    # "analysis"; string concatenation would yield the root path "/analysis".
    return os.path.join(output_dir, "analysis")


def main(argv=None):
    """Parse command-line arguments and run ``analyze_dataset``."""
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Plot per-label step histograms and label ratios of a dataset."
    )
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="dataset file passed to load_dataset")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="base directory for outputs (default: the "
                             "dataset's directory); results go to "
                             "<output_dir>/analysis")
    args = parser.parse_args(argv)

    output_dir = resolve_output_dir(args.dataset_path, args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    analyze_dataset(args.dataset_path, output_dir)


if __name__ == "__main__":
    main()