# Spaces: Runtime error
"""Bar chart of validation perplexity for each model-ablation setting.

Draws one bar per training setting, annotates every bar with its value,
saves the figure to ``output/model_ablation.png`` at 300 dpi, then shows it.
"""
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Configure the theme in ONE call: every ``sns.set``/``sns.set_theme`` call
# resets the font, size and style rcParams it manages, so splitting these
# settings across several calls (or setting them via ``matplotlib.rc`` first)
# silently keeps only the last one. ``rc`` is applied after ``font``, so the
# serif preference list sticks; DejaVu Serif ships with matplotlib and is the
# fallback when Times New Roman is not installed.
sns.set_theme(
    style="whitegrid",
    font="serif",
    font_scale=3,
    rc={"font.serif": ["Times New Roman", "DejaVu Serif"]},
)

# Sample data for plotting (placeholder values -- replace with real results).
categories = ["Scratch", "Passive Pre-Train", "Pre-Train", "Pre-Train (Large)"]
values = [1.0, 1.0, 1.0, 1.0]

# One color per bar, in the same order as ``categories``.
colors = ["#4c72b0", "#55a868", "#c44e52", "#8172b2"]  # Adjust as needed

fig, ax = plt.subplots(figsize=(14, 12))
# Passing ``palette`` without ``hue`` is deprecated in seaborn >= 0.13.
# Mapping each category to its own hue level gives the same per-bar colors;
# ``legend=False`` suppresses the legend that would merely repeat the x labels.
ax = sns.barplot(
    x=categories,
    y=values,
    hue=categories,
    palette=colors,
    legend=False,
    alpha=0.9,
    edgecolor="black",
    ax=ax,
)

# Write each bar's height above it with two decimals.
for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")

# Axis labels and tick styling.
ax.set_xlabel("Setting", fontsize=40)
ax.set_ylabel("Validation Perplexity", fontsize=40)
ax.tick_params(axis="x", labelsize=30, labelrotation=15)
ax.tick_params(axis="y", labelsize=30)
# No legend: the x tick labels already name every bar.

# Remove the left and bottom spines (top/right are removed by default).
sns.despine(ax=ax, left=True, bottom=True)

fig.tight_layout()
# savefig does not create missing parent directories, so make sure it exists.
output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(output_dir / "model_ablation.png", dpi=300)  # high-resolution copy
plt.show()