pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
# LOAD FINAL RESULTS:
datasets = ["x_div_y", "x_minus_y", "x_plus_y", "permutation"]


def mean_and_stderr(curves):
    """Return the element-wise mean and standard error of equal-length curves.

    Args:
        curves: list of per-seed sequences (one per matching key in a run's
            ``all_results.npy``); all sequences must have the same length.

    Returns:
        ``(mean, stderr)`` as NumPy arrays. The standard error is the
        population standard deviation (``np.std`` default, ``ddof=0``)
        divided by ``sqrt(len(curves))``. An empty ``curves`` yields two empty
        arrays instead of a NaN mean plus a "Mean of empty slice" warning.
    """
    if not curves:
        return np.zeros(0), np.zeros(0)
    mean = np.mean(curves, axis=0)
    stderr = np.std(curves, axis=0) / np.sqrt(len(curves))
    return mean, stderr


def summarize_run(results_dict):
    """Aggregate one run's per-seed training/validation curves for every dataset.

    Args:
        results_dict: mapping loaded from a run's ``all_results.npy``. Keys that
            contain a dataset name plus ``"val_info"`` (or ``"train_info"``) are
            read; their values are sequences of dicts carrying ``"step"``,
            ``"val_loss"``, ``"val_accuracy"`` (val keys) or ``"train_loss"``,
            ``"train_accuracy"`` (train keys).

    Returns:
        ``{dataset: info}`` where ``info`` holds ``"step"`` (from the last
        matching ``val_info`` key, ``[]`` if there is none), the mean curves
        ``"val_loss"``, ``"train_loss"``, ``"val_acc"``, ``"train_acc"`` and
        their standard errors under the matching ``"*_sterr"`` keys.
    """
    run_info = {}
    for dataset in datasets:
        steps = []
        val_losses, val_accs = [], []
        train_losses, train_accs = [], []
        for key, infos in results_dict.items():
            # Substring match: a key counts for every dataset name it contains.
            if dataset not in key:
                continue
            if "val_info" in key:
                steps = [info["step"] for info in infos]
                val_losses.append([info["val_loss"] for info in infos])
                val_accs.append([info["val_accuracy"] for info in infos])
            if "train_info" in key:
                train_losses.append([info["train_loss"] for info in infos])
                train_accs.append([info["train_accuracy"] for info in infos])

        # Each series is summarized independently, so a missing train (or val)
        # series no longer affects the other's statistics.
        val_loss, val_loss_sterr = mean_and_stderr(val_losses)
        train_loss, train_loss_sterr = mean_and_stderr(train_losses)
        val_acc, val_acc_sterr = mean_and_stderr(val_accs)
        train_acc, train_acc_sterr = mean_and_stderr(train_accs)
        run_info[dataset] = {
            "step": steps,
            "val_loss": val_loss,
            "train_loss": train_loss,
            "val_loss_sterr": val_loss_sterr,
            "train_loss_sterr": train_loss_sterr,
            "val_acc": val_acc,
            "train_acc": train_acc,
            "val_acc_sterr": val_acc_sterr,
            "train_acc_sterr": train_acc_sterr,
        }
    return run_info


# Sorted so the load/print order is deterministic across filesystems.
folders = sorted(os.listdir("./"))
final_results = {}
results_info = {}
for folder in folders:
    if not (folder.startswith("run") and osp.isdir(folder)):
        continue
    with open(osp.join(folder, "final_info.json"), "r", encoding="utf-8") as f:
        final_results[folder] = json.load(f)
    # allow_pickle is required because the .npy file wraps a pickled dict;
    # only load run output produced locally (unpickling is not safe for
    # untrusted files).
    results_dict = np.load(
        osp.join(folder, "all_results.npy"), allow_pickle=True
    ).item()
    print(results_dict.keys())
    results_info[folder] = summarize_run(results_dict)
# CREATE LEGEND -- ADD RUNS HERE THAT WILL BE PLOTTED
# Run directory name -> human-readable legend label. Insertion order is kept,
# so the entries below are listed in the order the runs will be plotted.
labels = dict(
    run_0="Baseline",
    run_1="Operand Reversal",
    run_2="Negation",
    run_3="Combined (15%)",
    run_4="Combined (30%)",
    run_5="Combined (20%)",
)
# Create a programmatic color palette
def generate_color_palette(n, cmap_name="tab20"):
    """Return ``n`` hex colour strings sampled evenly across a Matplotlib colormap.

    Args:
        n: number of colours to produce (``0`` yields an empty list).
        cmap_name: name of the colormap to sample. Defaults to ``"tab20"``,
            the palette this script has always used.

    Returns:
        List of ``"#rrggbb"`` strings, one per point of ``np.linspace(0, 1, n)``.
        A 20-colour map such as ``tab20`` must repeat colours once ``n > 20``.
    """
    cmap = plt.get_cmap(cmap_name)
    return [mcolors.rgb2hex(cmap(position)) for position in np.linspace(0, 1, n)]
# Get the list of runs and generate the color palette.
# Runs follow the legend's insertion order; colour i belongs to run i.
runs = [run_name for run_name in labels]
colors = generate_color_palette(len(labels))
# Plot 1: Line plot of validation accuracy for each dataset across the runs.
# One figure per dataset; each run is drawn as its mean curve plus a shaded
# band of +/- one standard error, in that run's palette colour.
for ds_name in datasets:
    fig, ax = plt.subplots(figsize=(12, 6))
    for run_idx, run_name in enumerate(runs):
        curve = results_info[run_name][ds_name]
        steps_axis = curve["step"]
        acc_mean = curve["val_acc"]
        acc_err = curve["val_acc_sterr"]
        shade = colors[run_idx]
        ax.plot(steps_axis, acc_mean, label=labels[run_name], color=shade)
        ax.fill_between(
            steps_axis, acc_mean - acc_err, acc_mean + acc_err, color=shade, alpha=0.2
        )
    ax.set_title(f"Validation Accuracy Across Runs for {ds_name} Dataset")
    ax.set_xlabel("Update Steps")
    ax.set_ylabel("Validation Accuracy")
    ax.legend()
    ax.grid(True, which="both", ls="-", alpha=0.2)
    # Slight headroom above 1.0 so curves at 100% accuracy stay visible.
    ax.set_ylim(0, 1.05)
    fig.tight_layout()
    fig.savefig(f"val_acc_{ds_name}.png")
    plt.close(fig)
# Fallback bar height for runs whose "steps to 99% val accuracy" is missing
# (the .get() below returns None), i.e. runs that never reached 99%.
# NOTE(review): 7500 presumably equals the training budget in update steps -- confirm.
STEP_VAL_ACC_99_FALLBACK = 7500


def plot_grouped_bars(values_by_dataset, ylabel, title, filename):
    """Draw one group of bars per dataset (one bar per run) and save the figure.

    Args:
        values_by_dataset: ``{dataset: [value for each run, in ``runs`` order]}``.
        ylabel: y-axis label.
        title: figure title.
        filename: output image path passed to ``plt.savefig``.

    Reads the module-level ``datasets``, ``runs``, ``labels`` and ``colors`` so
    each run's bar colour matches its line colour in the accuracy plots.
    """
    x = np.arange(len(datasets))
    # Fill 90% of each group's unit slot; with 6 runs this is the original 0.15.
    width = 0.9 / max(len(runs), 1)
    plt.figure(figsize=(12, 6))
    for i, run in enumerate(runs):
        plt.bar(
            x + i * width,
            [values_by_dataset[dataset][i] for dataset in datasets],
            width,
            label=labels[run],
            color=colors[i],
        )
    plt.xlabel("Datasets")
    plt.ylabel(ylabel)
    plt.title(title)
    # Bars sit at offsets 0..len(runs)-1 widths, so centre the tick on their
    # midpoint (2.5 widths for 6 runs) instead of a hard-coded offset.
    plt.xticks(x + width * (len(runs) - 1) / 2, datasets)
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()


# Plot 2: Bar plot of steps to 99% validation accuracy for each dataset and run
step_val_acc_99 = {dataset: [] for dataset in datasets}
for dataset in datasets:
    for run in runs:
        step = final_results[run][dataset]["means"].get("step_val_acc_99_mean")
        step_val_acc_99[dataset].append(
            step if step is not None else STEP_VAL_ACC_99_FALLBACK
        )
plot_grouped_bars(
    step_val_acc_99,
    "Steps to 99% Validation Accuracy",
    "Steps to 99% Validation Accuracy Across Runs and Datasets",
    "steps_to_99_acc.png",
)

# Plot 3: Bar plot of final validation accuracy for each dataset and run
final_val_acc = {
    dataset: [
        final_results[run][dataset]["means"]["final_val_acc_mean"] for run in runs
    ]
    for dataset in datasets
}
plot_grouped_bars(
    final_val_acc,
    "Final Validation Accuracy",
    "Final Validation Accuracy Across Runs and Datasets",
    "final_val_acc.png",
)