{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0625f0a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "from ase.db import connect\n",
    "\n",
    "# Fixed seed so the same structures are sampled on every fresh run.\n",
    "random.seed(0)\n",
    "\n",
    "# Directory containing c2db.db; later cells also read <model>.parquet from here.\n",
    "DATA_DIR = Path(\".\")\n",
    "# Upper bound on the number of database rows sampled for the benchmark.\n",
    "N_SAMPLES = 1000\n",
    "\n",
    "db = connect(DATA_DIR / \"c2db.db\")\n",
    "\n",
    "# Sample row ids from 1..len(db) without replacement. The sample size is capped\n",
    "# at len(db) because random.sample raises ValueError when asked for more items\n",
    "# than the population holds.\n",
    "# NOTE(review): assumes row ids are contiguous from 1 to len(db) -- confirm.\n",
    "random_indices = random.sample(range(1, len(db) + 1), min(N_SAMPLES, len(db)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005708b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import pandas as pd\n",
    "import phonopy\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from mlip_arena.models import MLIPEnum\n",
    "\n",
    "# Schema of every per-model parquet file written by this cell.\n",
    "RESULT_COLUMNS = [\"model\", \"uid\", \"eigenvalues\", \"frequencies\"]\n",
    "\n",
    "# Scan the database once and reuse the sampled rows for every model.\n",
    "sample_ids = set(random_indices)\n",
    "sampled_rows = list(db.select(filter=lambda r: r[\"id\"] in sample_ids))\n",
    "\n",
    "# (model name, uid or None, error repr) for everything that had to be skipped,\n",
    "# so failures stay inspectable instead of being silently swallowed.\n",
    "failures = []\n",
    "\n",
    "for model in tqdm(MLIPEnum, desc=\"models\"):\n",
    "    parquet_path = DATA_DIR / f\"{model.name}.parquet\"\n",
    "\n",
    "    # Read the existing results once per model rather than once per structure.\n",
    "    if parquet_path.exists():\n",
    "        cached = pd.read_parquet(parquet_path)\n",
    "    else:\n",
    "        cached = pd.DataFrame(columns=RESULT_COLUMNS)\n",
    "    known_uids = set(cached[\"uid\"])\n",
    "\n",
    "    records = []\n",
    "    for row in sampled_rows:\n",
    "        uid = row[\"uid\"]\n",
    "        if uid in known_uids:\n",
    "            continue  # already stored for this model\n",
    "\n",
    "        path = Path(model.name) / uid\n",
    "        try:\n",
    "            phonon = phonopy.load(path / \"phonopy.yaml\")\n",
    "            frequencies = phonon.get_frequencies(q=(0, 0, 0))\n",
    "\n",
    "            # The context manager closes the .npz archive once the array is read.\n",
    "            with np.load(path / \"elastic.npz\") as data:\n",
    "                eigenvalues = data[\"eigenvalues\"]\n",
    "        except Exception as exc:\n",
    "            # Best effort: a structure without usable results is skipped, but\n",
    "            # the reason is recorded.\n",
    "            failures.append((model.name, uid, repr(exc)))\n",
    "            continue\n",
    "\n",
    "        records.append(\n",
    "            {\n",
    "                \"model\": model.name,\n",
    "                \"uid\": uid,\n",
    "                \"eigenvalues\": eigenvalues,\n",
    "                \"frequencies\": frequencies,\n",
    "            }\n",
    "        )\n",
    "        known_uids.add(uid)\n",
    "\n",
    "    if not records:\n",
    "        continue\n",
    "\n",
    "    new_rows = pd.DataFrame(records, columns=RESULT_COLUMNS)\n",
    "    # Skip the concat when nothing was cached: concatenating with an empty\n",
    "    # frame is deprecated pandas behaviour.\n",
    "    if cached.empty:\n",
    "        merged = new_rows\n",
    "    else:\n",
    "        merged = pd.concat([cached, new_rows], ignore_index=True)\n",
    "    merged = merged.drop_duplicates(subset=[\"model\", \"uid\"], keep=\"last\")\n",
    "\n",
    "    try:\n",
    "        merged.to_parquet(parquet_path, index=False)\n",
    "    except Exception as exc:\n",
    "        # Keep processing the remaining models, but record the failed write.\n",
    "        failures.append((model.name, None, repr(exc)))\n",
    "\n",
    "print(f\"{len(failures)} entries skipped; inspect `failures` for details\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b8d87638",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the lower-cased `dyn_stab` label to a flag: None when unknown, True when\n",
    "# \"yes\"; every other label is treated as unstable (False).\n",
    "_STABILITY_FLAGS = {\"unknown\": None, \"yes\": True}\n",
    "\n",
    "uids = []\n",
    "stabilities = []\n",
    "for entry in db.select(filter=lambda r: r[\"id\"] in random_indices):\n",
    "    kvp = entry.key_value_pairs\n",
    "    label = kvp[\"dyn_stab\"].lower()\n",
    "    uids.append(kvp[\"uid\"])\n",
    "    stabilities.append(_STABILITY_FLAGS.get(label, False))\n",
    "\n",
    "stabilities = np.array(stabilities)\n",
    "\n",
    "# Counts of (stable, unstable, unknown) reference labels.\n",
    "tuple((stabilities == flag).sum() for flag in (True, False, None))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c516a7",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "0052d0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.metrics import (\n",
    "    ConfusionMatrixDisplay,\n",
    "    classification_report,\n",
    "    confusion_matrix,\n",
    ")\n",
    "\n",
    "from mlip_arena.models import MLIPEnum\n",
    "\n",
    "# A structure is predicted unstable when its smallest elastic eigenvalue or its\n",
    "# smallest phonon frequency lies below this threshold.\n",
    "thres = -1e-7\n",
    "\n",
    "# Models shown in the figure, one panel each.\n",
    "select_models = [\n",
    "    \"ALIGNN\",\n",
    "    \"CHGNet\",\n",
    "    \"M3GNet\",\n",
    "    \"MACE-MP(M)\",\n",
    "    \"MACE-MPA\",\n",
    "    \"MatterSim\",\n",
    "    \"ORBv2\",\n",
    "    \"SevenNet\",\n",
    "]\n",
    "\n",
    "# The reference labels do not depend on the model, so sort them by uid once\n",
    "# instead of on every loop iteration.\n",
    "arg = np.argsort(uids)\n",
    "uids_sorted = np.array(uids)[arg]\n",
    "stabilities_sorted = stabilities[arg]\n",
    "# Exclude structures whose reference stability is unknown (None).\n",
    "mask = ~(stabilities_sorted == None)\n",
    "y_true = stabilities_sorted[mask].astype(\"int\")\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "\n",
    "    SMALL_SIZE = 8\n",
    "    MEDIUM_SIZE = 10\n",
    "    BIGGER_SIZE = 12\n",
    "\n",
    "    plt.rcParams.update(\n",
    "        {\n",
    "            \"font.size\": SMALL_SIZE,\n",
    "            \"axes.titlesize\": MEDIUM_SIZE,\n",
    "            \"axes.labelsize\": MEDIUM_SIZE,\n",
    "            \"xtick.labelsize\": MEDIUM_SIZE,\n",
    "            \"ytick.labelsize\": MEDIUM_SIZE,\n",
    "            \"legend.fontsize\": SMALL_SIZE,\n",
    "            \"figure.titlesize\": BIGGER_SIZE,\n",
    "        }\n",
    "    )\n",
    "\n",
    "    # Size the grid from the models that are actually plotted, so the number of\n",
    "    # rows and the figure height are derived from the same quantity.\n",
    "    ncols = 4\n",
    "    nrows = int(np.ceil(len(select_models) / ncols))\n",
    "\n",
    "    fig, axs = plt.subplots(\n",
    "        nrows=nrows,\n",
    "        ncols=ncols,\n",
    "        figsize=(6, 3 * nrows),\n",
    "        sharey=True,\n",
    "        sharex=True,\n",
    "        layout=\"constrained\",\n",
    "    )\n",
    "    axs = axs.flatten()\n",
    "    plot_idx = 0\n",
    "\n",
    "    for model in MLIPEnum:\n",
    "        fpath = DATA_DIR / f\"{model.name}.parquet\"\n",
    "        if not fpath.exists():\n",
    "            continue\n",
    "\n",
    "        if model.name not in select_models:\n",
    "            continue\n",
    "\n",
    "        df = pd.read_parquet(fpath)\n",
    "        # Entries that are not purely real fall back to `thres`; since `thres`\n",
    "        # is not `< thres`, they are not flagged unstable by that criterion.\n",
    "        # NOTE(review): confirm this fallback is intended.\n",
    "        df[\"eigval_min\"] = df[\"eigenvalues\"].apply(\n",
    "            lambda x: x.min() if np.isreal(x).all() else thres\n",
    "        )\n",
    "        df[\"freq_min\"] = df[\"frequencies\"].apply(\n",
    "            lambda x: x.min() if np.isreal(x).all() else thres\n",
    "        )\n",
    "        df[\"dyn_stab\"] = ~np.logical_or(\n",
    "            df[\"eigval_min\"] < thres, df[\"freq_min\"] < thres\n",
    "        )\n",
    "\n",
    "        # Align predictions with the sorted reference uids; structures without\n",
    "        # a stored result become NaN rows.\n",
    "        sorted_df = (\n",
    "            df[df[\"uid\"].isin(uids_sorted)].set_index(\"uid\").reindex(uids_sorted)\n",
    "        )\n",
    "\n",
    "        # Encode predictions as 1 (stable), 0 (unstable), -1 (no result).\n",
    "        y_pred = sorted_df[\"dyn_stab\"][mask].fillna(-1).astype(\"int\")\n",
    "        cm = confusion_matrix(y_true, y_pred, labels=[1, 0, -1])\n",
    "\n",
    "        ax = axs[plot_idx]\n",
    "        ConfusionMatrixDisplay(\n",
    "            cm, display_labels=[\"stable\", \"unstable\", \"missing\"]\n",
    "        ).plot(ax=ax, cmap=\"Blues\", colorbar=False)\n",
    "\n",
    "        ax.set_title(model.name)\n",
    "        ax.set_xlabel(\"Predicted\")\n",
    "        ax.set_ylabel(\"True\")\n",
    "        ax.set_xticks([0, 1, 2])\n",
    "        ax.set_xticklabels([\"stable\", \"unstable\", \"missing\"])\n",
    "        ax.set_yticks([0, 1, 2])\n",
    "        ax.set_yticklabels([\"stable\", \"unstable\", \"missing\"])\n",
    "\n",
    "        plot_idx += 1\n",
    "\n",
    "    # Hide unused subplots (selected models without a parquet file).\n",
    "    for i in range(plot_idx, len(axs)):\n",
    "        fig.delaxes(axs[i])\n",
    "\n",
    "    fig.savefig(\"c2db-confusion_matrices.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "573b3c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "from mlip_arena.models import MLIPEnum\n",
    "\n",
    "# A structure is predicted unstable when its smallest elastic eigenvalue or its\n",
    "# smallest phonon frequency lies below this threshold.\n",
    "thres = -1e-7\n",
    "\n",
    "# Columns of the summary table. Declaring them up front guarantees that all of\n",
    "# them exist (including \"Macro F1\") even when no model produced a row.\n",
    "SUMMARY_COLUMNS = [\"Model\", \"Stable F1\", \"Unstable F1\", \"Weighted F1\", \"Macro F1\"]\n",
    "\n",
    "# Sort the reference labels by uid once, into local copies, so the notebook-level\n",
    "# `uids` and `stabilities` are not overwritten by this cell.\n",
    "arg = np.argsort(uids)\n",
    "uids_sorted = np.array(uids)[arg]\n",
    "stabilities_sorted = stabilities[arg]\n",
    "# Exclude structures whose reference stability is unknown (None).\n",
    "mask = ~(stabilities_sorted == None)\n",
    "y_true = stabilities_sorted[mask].astype(\"int\")\n",
    "\n",
    "summary_records = []\n",
    "\n",
    "for model in MLIPEnum:\n",
    "    fpath = DATA_DIR / f\"{model.name}.parquet\"\n",
    "\n",
    "    if not fpath.exists() or model.name not in select_models:\n",
    "        continue\n",
    "    df = pd.read_parquet(fpath)\n",
    "\n",
    "    df[\"eigval_min\"] = df[\"eigenvalues\"].apply(\n",
    "        lambda x: x.min() if np.isreal(x).all() else thres\n",
    "    )\n",
    "    df[\"freq_min\"] = df[\"frequencies\"].apply(\n",
    "        lambda x: x.min() if np.isreal(x).all() else thres\n",
    "    )\n",
    "    df[\"dyn_stab\"] = ~np.logical_or(df[\"eigval_min\"] < thres, df[\"freq_min\"] < thres)\n",
    "\n",
    "    # Align predictions with the sorted reference uids; structures without a\n",
    "    # stored result become NaN rows.\n",
    "    sorted_df = df[df[\"uid\"].isin(uids_sorted)].set_index(\"uid\").reindex(uids_sorted)\n",
    "\n",
    "    # Encode predictions as 1 (stable), 0 (unstable), -1 (no result).\n",
    "    y_pred = sorted_df[\"dyn_stab\"][mask].fillna(-1).astype(\"int\")\n",
    "\n",
    "    # Only the stable/unstable classes are scored; a missing prediction (-1)\n",
    "    # therefore counts as a miss for the true class.\n",
    "    report = classification_report(\n",
    "        y_true,\n",
    "        y_pred,\n",
    "        labels=[1, 0],\n",
    "        target_names=[\"stable\", \"unstable\"],\n",
    "        digits=3,\n",
    "        output_dict=True,\n",
    "    )\n",
    "\n",
    "    summary_records.append(\n",
    "        {\n",
    "            \"Model\": model.name,\n",
    "            \"Stable F1\": report[\"stable\"][\"f1-score\"],\n",
    "            \"Unstable F1\": report[\"unstable\"][\"f1-score\"],\n",
    "            \"Macro F1\": report[\"macro avg\"][\"f1-score\"],\n",
    "            \"Weighted F1\": report[\"weighted avg\"][\"f1-score\"],\n",
    "        }\n",
    "    )\n",
    "\n",
    "# Build the frame once from the collected records instead of concatenating onto\n",
    "# an empty frame inside the loop.\n",
    "summary_df = pd.DataFrame(summary_records, columns=SUMMARY_COLUMNS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "df660870",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank models best-first by macro F1, breaking ties on weighted F1, then export\n",
    "# the ranked table to LaTeX with three decimals and no index column.\n",
    "ranking_keys = [\"Macro F1\", \"Weighted F1\"]\n",
    "summary_df = summary_df.sort_values(by=ranking_keys, ascending=False)\n",
    "summary_df.to_latex(\n",
    "    \"c2db_summary_table.tex\",\n",
    "    index=False,\n",
    "    float_format=\"%.3f\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "18f4a59b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "\n",
    "# Metrics and bar settings\n",
    "metrics = [\"Stable F1\", \"Unstable F1\", \"Macro F1\", \"Weighted F1\"]\n",
    "bar_width = 0.2\n",
    "x = np.arange(len(summary_df))\n",
    "\n",
    "# One colour per metric, taken from the first entries of the tab20 colormap.\n",
    "cmap = plt.get_cmap(\"tab20\")\n",
    "colors = {metric: cmap(i) for i, metric in enumerate(metrics)}\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "    # SMALL_SIZE / MEDIUM_SIZE / BIGGER_SIZE are the font sizes defined in the\n",
    "    # confusion-matrix cell.\n",
    "    plt.rcParams.update(\n",
    "        {\n",
    "            \"font.size\": SMALL_SIZE,\n",
    "            \"axes.titlesize\": MEDIUM_SIZE,\n",
    "            \"axes.labelsize\": MEDIUM_SIZE,\n",
    "            \"xtick.labelsize\": MEDIUM_SIZE,\n",
    "            \"ytick.labelsize\": MEDIUM_SIZE,\n",
    "            \"legend.fontsize\": SMALL_SIZE,\n",
    "            \"figure.titlesize\": BIGGER_SIZE,\n",
    "        }\n",
    "    )\n",
    "\n",
    "    # Constrained layout is the only layout engine used here; calling\n",
    "    # tight_layout() on top of it would replace it and emit a warning.\n",
    "    fig, ax = plt.subplots(figsize=(4, 3), layout=\"constrained\")\n",
    "\n",
    "    # Centre each group of bars on its tick: with four metrics the offsets are\n",
    "    # -1.5, -0.5, +0.5 and +1.5 bar widths.\n",
    "    positions = {\n",
    "        metric: x + (i - (len(metrics) - 1) / 2) * bar_width\n",
    "        for i, metric in enumerate(metrics)\n",
    "    }\n",
    "\n",
    "    # Plot each metric with its assigned colour.\n",
    "    for metric, pos in positions.items():\n",
    "        ax.bar(\n",
    "            pos, summary_df[metric], width=bar_width, label=metric, color=colors[metric]\n",
    "        )\n",
    "\n",
    "    ax.set_xlabel(\"Model\")\n",
    "    ax.set_ylabel(\"F1 Score\")\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(summary_df[\"Model\"], rotation=45, ha=\"right\")\n",
    "    ax.legend(ncols=2, bbox_to_anchor=(0.5, 1), loc=\"upper center\", fontsize=SMALL_SIZE)\n",
    "    ax.spines[[\"top\", \"right\"]].set_visible(False)\n",
    "    ax.set_ylim(0, 0.9)\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.6)\n",
    "\n",
    "    fig.savefig(\"c2db_f1_bar.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c50f705",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlip-arena",
   "language": "python",
   "name": "mlip-arena"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}