{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "from pathlib import Path\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from mlip_arena.models import REGISTRY\n",
    "from mlip_arena.tasks.stability.input import get_atoms_from_db\n",
    "\n",
    "RUN_DIR = Path(\".\").resolve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chemical formulas of the benchmark structures, in database order.\n",
    "# Later cells take prefixes of this list (compositions[:N]), so the order matters.\n",
    "compositions = []\n",
    "for atoms in tqdm(get_atoms_from_db(\"random-mixture.db\")):\n",
    "    # Skip empty (zero-atom) entries.\n",
    "    if len(atoms) == 0:\n",
    "        continue\n",
    "    compositions.append(atoms.get_chemical_formula())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymatviz as pmv\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Only the first N_HEATMAP_COMPOSITIONS structures contribute to the element counts.\n",
    "N_HEATMAP_COMPOSITIONS = 1000\n",
    "FIGURE_DIR = Path(\"../figures\")\n",
    "# savefig does not create missing directories, so create the output folder first.\n",
    "FIGURE_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "fig = pmv.ptable_heatmap(\n",
    "    pmv.count_elements(compositions[:N_HEATMAP_COMPOSITIONS]),\n",
    "    colormap=\"GnBu\",\n",
    "    log=True,\n",
    "    return_type=\"figure\",\n",
    ")\n",
    "\n",
    "plt.savefig(FIGURE_DIR / \"stability-element-counts.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from ase import Atoms\n",
    "\n",
    "\n",
    "def get_runtime_stats(traj: list[Atoms], atoms0: Atoms):\n",
    "    \"\"\"Summarize throughput and per-frame observables of one MD trajectory.\n",
    "\n",
    "    Frames whose potential energy cannot be read or is not finite are skipped.\n",
    "    Throughput is accumulated per restart segment (frames sharing the same\n",
    "    ``atoms.info[\"restart\"]`` value), so wall-clock gaps between restarts are\n",
    "    not counted.\n",
    "\n",
    "    Args:\n",
    "        traj: Trajectory frames. Each frame must carry ``restart``, ``datetime``\n",
    "            (a datetime; differences are converted with ``total_seconds()``)\n",
    "            and ``step`` in ``atoms.info``; the first frame must also carry\n",
    "            ``target_steps``.\n",
    "        atoms0: Reference structure for the centre-of-mass drift.\n",
    "\n",
    "    Returns:\n",
    "        dict with scalar run statistics (``natoms``, ``total_time_seconds``,\n",
    "        ``total_steps``, ``steps_per_second``, ``seconds_per_step``,\n",
    "        ``seconds_per_step_per_atom``, ``target_steps``, ``final_step``) and\n",
    "        per-frame series (``energies``, ``kinetic_energies``, ``temperatures``,\n",
    "        ``pressures``, ``timestep``, ``com_drifts``) that all have one entry per\n",
    "        valid frame. ``pressures`` holds NaN where the stress is unavailable.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if a frame is not an ``ase.Atoms``.\n",
    "        KeyError / IndexError: if the metadata above is missing or ``traj``\n",
    "            is empty.\n",
    "    \"\"\"\n",
    "    restarts = []\n",
    "    steps, times = [], []\n",
    "    Ts, Ps, Es, KEs = [], [], [], []\n",
    "    com_drifts = []\n",
    "\n",
    "    for atoms in tqdm(traj):\n",
    "        assert isinstance(atoms, Atoms)\n",
    "        try:\n",
    "            energy = atoms.get_potential_energy()\n",
    "            assert np.isfinite(energy), f\"invalid energy: {energy}\"\n",
    "        except Exception:\n",
    "            # Frame without a usable energy: treat it as invalid and skip it.\n",
    "            continue\n",
    "\n",
    "        restarts.append(atoms.info[\"restart\"])\n",
    "        times.append(atoms.info[\"datetime\"])\n",
    "        steps.append(atoms.info[\"step\"])\n",
    "        Es.append(energy)\n",
    "        KEs.append(atoms.get_kinetic_energy())\n",
    "        Ts.append(atoms.get_temperature())\n",
    "        try:\n",
    "            # Mean of the first three (normal, in ASE's Voigt order) stress\n",
    "            # components. NOTE(review): with ASE's sign convention this may be\n",
    "            # the negative of the pressure -- confirm the intended sign.\n",
    "            Ps.append(atoms.get_stress()[:3].mean())\n",
    "        except Exception:\n",
    "            # Record NaN instead of dropping the entry so Ps stays aligned with\n",
    "            # the other per-frame lists; a shorter list makes the DataFrame\n",
    "            # built from these stats fail with a length mismatch.\n",
    "            Ps.append(np.nan)\n",
    "        com_drifts.append(\n",
    "            (atoms.get_center_of_mass() - atoms0.get_center_of_mass()).tolist()\n",
    "        )\n",
    "\n",
    "    restarts = np.array(restarts)\n",
    "    times = np.array(times)\n",
    "    steps = np.array(steps)\n",
    "\n",
    "    total_time_seconds = 0\n",
    "    total_steps = 0\n",
    "\n",
    "    # Sum wall-clock time and MD steps within each restart segment.\n",
    "    for block in np.unique(restarts):\n",
    "        in_block = restarts == block\n",
    "        block_time = times[in_block][-1] - times[in_block][0]\n",
    "        total_time_seconds += block_time.total_seconds()\n",
    "        total_steps += steps[in_block][-1] - steps[in_block][0]\n",
    "\n",
    "    target_steps = traj[0].info[\"target_steps\"]\n",
    "    natoms = len(traj[0])\n",
    "\n",
    "    return {\n",
    "        \"natoms\": natoms,\n",
    "        \"total_time_seconds\": total_time_seconds,\n",
    "        \"total_steps\": total_steps,\n",
    "        \"steps_per_second\": total_steps / total_time_seconds\n",
    "        if total_time_seconds != 0\n",
    "        else 0,\n",
    "        \"seconds_per_step\": total_time_seconds / total_steps\n",
    "        if total_steps != 0\n",
    "        else float(\"inf\"),\n",
    "        \"seconds_per_step_per_atom\": total_time_seconds / total_steps / natoms\n",
    "        if total_steps != 0\n",
    "        else float(\"inf\"),\n",
    "        \"energies\": Es,\n",
    "        \"kinetic_energies\": KEs,\n",
    "        \"temperatures\": Ts,\n",
    "        \"pressures\": Ps,\n",
    "        \"target_steps\": target_steps,\n",
    "        \"final_step\": steps[-1] if len(steps) != 0 else 0,\n",
    "        # MD step indices of the valid frames (not a time-step size).\n",
    "        \"timestep\": steps,\n",
    "        \"com_drifts\": com_drifts,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.colors as pcolors\n",
    "\n",
    "# Models that list \"stability\" among their GPU tasks, in REGISTRY order.\n",
    "mlip_methods = [\n",
    "    name\n",
    "    for name, info in REGISTRY.items()\n",
    "    if \"stability\" in info.get(\"gpu-tasks\", [])\n",
    "]\n",
    "\n",
    "# Every list-valued attribute of plotly.colors.qualitative is a named palette.\n",
    "all_attributes = dir(pcolors.qualitative)\n",
    "color_palettes = {\n",
    "    key: getattr(pcolors.qualitative, key)\n",
    "    for key in all_attributes\n",
    "    if isinstance(getattr(pcolors.qualitative, key), list)\n",
    "}\n",
    "color_palettes.pop(\"__all__\", None)\n",
    "\n",
    "palette_names, palette_colors = list(color_palettes), list(color_palettes.values())\n",
    "palette_name = \"T10\"  # \"Plotly\"\n",
    "color_sequence = color_palettes[palette_name]  # type: ignore\n",
    "\n",
    "# Cycle through the palette so every model gets a colour, even when there are\n",
    "# more models than colours.\n",
    "method_color_mapping = {\n",
    "    name: color_sequence[idx % len(color_sequence)]\n",
    "    for idx, name in enumerate(mlip_methods)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
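    "# NPT: stability and MD throughput\n",
    "\n",
    "Load every `*npt.traj` run of each model, then plot steps per second against system size and the share of runs that stay valid over the simulation."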
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# from huggingface_hub import HfApi\n",
    "import seaborn as sns\n",
    "from ase import units\n",
    "from ase.io import read\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# One DataFrame per NPT trajectory, concatenated once at the end\n",
    "# (calling pd.concat inside the loop is quadratic in the number of runs).\n",
    "frames = []\n",
    "\n",
    "for model in mlip_methods:\n",
    "    # Sorted so that row order, and therefore which duplicate formula\n",
    "    # drop_duplicates keeps later, does not depend on filesystem order.\n",
    "    files = sorted(\n",
    "        glob.glob(str(RUN_DIR / REGISTRY[model][\"family\"] / f\"{model}_*npt.traj\"))\n",
    "    )\n",
    "\n",
    "    for file in files:\n",
    "        try:\n",
    "            traj = read(file, index=\":\")\n",
    "        except Exception as e:\n",
    "            print(f\"Error reading {file}: {e}\")\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            stats = get_runtime_stats(traj, atoms0=traj[0])\n",
    "            # Built inside the try so one malformed run (e.g. per-frame series of\n",
    "            # unequal length) is reported and skipped instead of aborting the loop.\n",
    "            frame = pd.DataFrame(\n",
    "                {\n",
    "                    \"model\": model,\n",
    "                    \"formula\": traj[0].get_chemical_formula(),\n",
    "                    \"normalized_timestep\": stats[\"timestep\"]\n",
    "                    / stats[\"target_steps\"],\n",
    "                    \"normalized_final_step\": stats[\"final_step\"]\n",
    "                    / stats[\"target_steps\"],\n",
    "                    \"pressure\": np.array(stats[\"pressures\"]) / units.GPa,\n",
    "                }\n",
    "                | stats\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(f\"Error processing {file}: {e}\")\n",
    "            continue\n",
    "\n",
    "        frames.append(frame)\n",
    "\n",
    "df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# import scipy.optimize as opt\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.optimize import curve_fit\n",
    "\n",
    "\n",
    "def power_law(x, a, n):\n",
    "    \"\"\"Power-law model ``a * x**n`` fitted to steps/second vs. number of atoms.\"\"\"\n",
    "    scaled = np.power(x, n)\n",
    "    return a * scaled\n",
    "\n",
    "\n",
    "# Display-friendly column names; they become the hue/size legend titles below.\n",
    "df = df.rename(columns={\"final_step\": \"Total steps\", \"model\": \"Model\"})\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "    fig, axes = plt.subplot_mosaic(\n",
    "        \"\"\"\n",
    "        ao\n",
    "        \"\"\",\n",
    "        constrained_layout=True,\n",
    "        figsize=(6, 3),\n",
    "        width_ratios=[1, 3],\n",
    "    )\n",
    "\n",
    "    # Right panel (\"o\"): MD throughput vs. system size, one colour per model.\n",
    "    ax = axes.pop(\"o\")\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df,\n",
    "        x=\"natoms\",\n",
    "        y=\"steps_per_second\",\n",
    "        size=\"Total steps\",\n",
    "        hue=\"Model\",\n",
    "        ax=ax,\n",
    "        palette=method_color_mapping,\n",
    "        sizes=(1, 50),\n",
    "    )\n",
    "\n",
    "    # Overlay a per-model power-law fit steps_per_second = a * natoms**n.\n",
    "    for model, data in df.groupby(\"Model\"):\n",
    "        # Reassign instead of dropna(inplace=True): groupby yields slices of df,\n",
    "        # and mutating them in place raises SettingWithCopyWarning.\n",
    "        data = data.dropna(subset=[\"steps_per_second\"])\n",
    "        if len(data) < 2:\n",
    "            # curve_fit needs at least as many points as parameters (a, n).\n",
    "            continue\n",
    "        try:\n",
    "            popt, _ = curve_fit(power_law, data[\"natoms\"], data[\"steps_per_second\"])\n",
    "        except RuntimeError as err:\n",
    "            print(f\"Power-law fit did not converge for {model}: {err}\")\n",
    "            continue\n",
    "\n",
    "        x = np.linspace(data[\"natoms\"].min(), data[\"natoms\"].max(), 100)\n",
    "        ax.plot(x, power_law(x, *popt), c=method_color_mapping[model], linestyle=\"-\")\n",
    "\n",
    "    ax.set(\n",
    "        xlabel=\"Number of atoms\",\n",
    "        xscale=\"log\",\n",
    "        ylabel=\"Steps per second\",\n",
    "        yscale=\"log\",\n",
    "    )\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "    ax.grid(alpha=0.25)\n",
    "    ax.legend(\n",
    "        loc=\"upper left\", bbox_to_anchor=(1.0, 1.0), fontsize=\"x-small\", frameon=False\n",
    "    )\n",
    "\n",
    "    # Left panel (\"a\"): percentage of runs still valid as the simulation advances.\n",
    "    # Only the first n_first benchmark compositions are considered and the\n",
    "    # denominator is n_first, so a composition a model never ran counts as failed.\n",
    "    n_first = 80\n",
    "    ax_main = axes[\"a\"]\n",
    "    # 50 edges -> 49 equal bins over the normalized run length [0, 1];\n",
    "    # np.histogram ignores values outside that range.\n",
    "    bins = np.linspace(0, 1, 50)\n",
    "\n",
    "    for k, df_model in df.groupby(\"Model\"):\n",
    "        df_model = df_model.drop_duplicates([\"formula\"])\n",
    "        df_model = df_model[df_model[\"formula\"].isin(compositions[:n_first])]\n",
    "        print(k, len(df_model))\n",
    "\n",
    "        hist, bin_edges = np.histogram(\n",
    "            df_model[\"normalized_final_step\"], bins=bins, density=False\n",
    "        )\n",
    "        cumulative_population = np.cumsum(hist)\n",
    "        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
    "\n",
    "        # For each bin: runs whose final step lies in a later bin, as % of n_first.\n",
    "        sns.lineplot(\n",
    "            x=bin_centers[:-1],\n",
    "            y=(cumulative_population[-1] - cumulative_population[:-1]) / n_first * 100,\n",
    "            ax=ax_main,\n",
    "            color=method_color_mapping[k],\n",
    "        )\n",
    "\n",
    "    ax_main.spines[\"right\"].set_visible(False)\n",
    "    ax_temp = ax_main.twiny()\n",
    "    ax_pressure = ax_main.twiny()\n",
    "\n",
    "    ax_main.set_xlim(0, 1)\n",
    "    ax_main.set_ylim(0, 100)\n",
    "    ax_main.set_ylabel(\"valid runs (%)\")\n",
    "\n",
    "    # Top x-axis: time; normalized steps 0 and 1 are labelled 0 and 10 ps.\n",
    "    ax_main.set_xticks([0, 1])\n",
    "    ax_main.set_xticklabels([0, 10])\n",
    "    ax_main.set_xlabel(\"Time (ps)\")\n",
    "    ax_main.xaxis.set_label_position(\"top\")\n",
    "    ax_main.xaxis.tick_top()\n",
    "    ax_main.spines[\"top\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # First bottom x-axis: temperature labels at the start and end of the run.\n",
    "    ax_temp.set_xlim(ax_main.get_xlim())\n",
    "    ax_temp.set_xticks([0, 1])\n",
    "    ax_temp.set_xticklabels([\"300 K\", \"3000 K\"])\n",
    "    ax_temp.xaxis.set_ticks_position(\"bottom\")\n",
    "    ax_temp.xaxis.set_label_position(\"bottom\")\n",
    "    ax_temp.spines[\"right\"].set_visible(False)\n",
    "    ax_temp.spines[\"top\"].set_visible(False)\n",
    "    ax_temp.spines[\"bottom\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # Second bottom x-axis, pushed further down: pressure labels.\n",
    "    ax_pressure.set_xlim(ax_main.get_xlim())\n",
    "    ax_pressure.set_xticks([0, 1])\n",
    "    ax_pressure.set_xticklabels([\"0 GPa\", \"500 GPa\"])\n",
    "    ax_pressure.xaxis.set_ticks_position(\"bottom\")\n",
    "    ax_pressure.xaxis.set_label_position(\"bottom\")\n",
    "    ax_pressure.spines[\"right\"].set_visible(False)\n",
    "    ax_pressure.spines[\"top\"].set_visible(False)\n",
    "    ax_pressure.spines[\"bottom\"].set_position((\"outward\", 25))\n",
    "\n",
    "    # Make sure the left panel carries no legend.\n",
    "    ax_main.legend_ = None\n",
    "\n",
    "    plt.savefig(\"stability-and-speed-npt-loglog.pdf\", bbox_inches=\"tight\")\n",
    "    plt.savefig(\"stability-and-speed-npt-loglog.png\", bbox_inches=\"tight\", dpi=330)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
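    "# NVT: stability and MD throughput\n",
    "\n",
    "Load every `*nvt.traj` run of each model, then plot steps per second against system size and the share of runs that stay valid over the simulation."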
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# from huggingface_hub import HfApi\n",
    "import seaborn as sns\n",
    "from ase import units\n",
    "from ase.io import read\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# One DataFrame per NVT trajectory, concatenated once at the end\n",
    "# (calling pd.concat inside the loop is quadratic in the number of runs).\n",
    "frames = []\n",
    "\n",
    "for model in mlip_methods:\n",
    "    # Sorted so that row order, and therefore which duplicate formula\n",
    "    # drop_duplicates keeps later, does not depend on filesystem order.\n",
    "    files = sorted(\n",
    "        glob.glob(str(RUN_DIR / REGISTRY[model][\"family\"] / f\"{model}_*nvt.traj\"))\n",
    "    )\n",
    "\n",
    "    for file in files:\n",
    "        try:\n",
    "            traj = read(file, index=\":\")\n",
    "        except Exception as e:\n",
    "            print(f\"Error reading {file}: {e}\")\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            stats = get_runtime_stats(traj, atoms0=traj[0])\n",
    "            # Built inside the try so one malformed run (e.g. per-frame series of\n",
    "            # unequal length) is reported and skipped instead of aborting the loop.\n",
    "            frame = pd.DataFrame(\n",
    "                {\n",
    "                    \"model\": model,\n",
    "                    \"formula\": traj[0].get_chemical_formula(),\n",
    "                    \"normalized_timestep\": stats[\"timestep\"]\n",
    "                    / stats[\"target_steps\"],\n",
    "                    \"normalized_final_step\": stats[\"final_step\"]\n",
    "                    / stats[\"target_steps\"],\n",
    "                    \"pressure\": np.array(stats[\"pressures\"]) / units.GPa,\n",
    "                }\n",
    "                | stats\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(f\"Error processing {file}: {e}\")\n",
    "            continue\n",
    "\n",
    "        frames.append(frame)\n",
    "\n",
    "df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# import scipy.optimize as opt\n",
    "import seaborn as sns\n",
    "from scipy.optimize import curve_fit\n",
    "\n",
    "\n",
    "# Define the power-law fitting function\n",
    "def power_law(x, a, n):\n",
    "    return a * np.power(x, n)\n",
    "\n",
    "\n",
    "# Display-friendly column names; they become the hue/size legend titles below.\n",
    "df = df.rename(columns={\"final_step\": \"Total steps\", \"model\": \"Model\"})\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "    fig, axes = plt.subplot_mosaic(\n",
    "        \"\"\"\n",
    "        ao\n",
    "        \"\"\",\n",
    "        constrained_layout=True,\n",
    "        figsize=(6, 3),\n",
    "        width_ratios=[1, 3],\n",
    "    )\n",
    "\n",
    "    # Right panel (\"o\"): MD throughput vs. system size, one colour per model.\n",
    "    ax = axes.pop(\"o\")\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df,\n",
    "        x=\"natoms\",\n",
    "        y=\"steps_per_second\",\n",
    "        size=\"Total steps\",\n",
    "        hue=\"Model\",\n",
    "        ax=ax,\n",
    "        palette=method_color_mapping,\n",
    "        sizes=(1, 50),\n",
    "    )\n",
    "\n",
    "    # Overlay a per-model power-law fit steps_per_second = a * natoms**n.\n",
    "    for model, data in df.groupby(\"Model\"):\n",
    "        # Reassign instead of dropna(inplace=True): groupby yields slices of df,\n",
    "        # and mutating them in place raises SettingWithCopyWarning.\n",
    "        data = data.dropna(subset=[\"steps_per_second\"])\n",
    "        if len(data) < 2:\n",
    "            # curve_fit needs at least as many points as parameters (a, n).\n",
    "            continue\n",
    "        try:\n",
    "            popt, _ = curve_fit(power_law, data[\"natoms\"], data[\"steps_per_second\"])\n",
    "        except RuntimeError as err:\n",
    "            print(f\"Power-law fit did not converge for {model}: {err}\")\n",
    "            continue\n",
    "\n",
    "        x = np.linspace(data[\"natoms\"].min(), data[\"natoms\"].max(), 100)\n",
    "        ax.plot(x, power_law(x, *popt), c=method_color_mapping[model], linestyle=\"-\")\n",
    "\n",
    "    ax.set(\n",
    "        xlabel=\"Number of atoms\",\n",
    "        xscale=\"log\",\n",
    "        ylabel=\"Steps per second\",\n",
    "        yscale=\"log\",\n",
    "    )\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "    ax.grid(alpha=0.25)\n",
    "    ax.legend(\n",
    "        loc=\"upper left\", bbox_to_anchor=(1.0, 1.0), fontsize=\"x-small\", frameon=False\n",
    "    )\n",
    "\n",
    "    # Left panel (\"a\"): percentage of runs still valid as the simulation advances.\n",
    "    # Only the first n_first benchmark compositions are considered and the\n",
    "    # denominator is n_first, so a composition a model never ran counts as failed.\n",
    "    n_first = 120\n",
    "    ax_main = axes[\"a\"]\n",
    "    # 50 edges -> 49 equal bins over the normalized run length [0, 1];\n",
    "    # np.histogram ignores values outside that range.\n",
    "    bins = np.linspace(0, 1, 50)\n",
    "\n",
    "    for k, df_model in df.groupby(\"Model\"):\n",
    "        df_model = df_model.drop_duplicates([\"formula\"])\n",
    "        df_model = df_model[df_model[\"formula\"].isin(compositions[:n_first])]\n",
    "\n",
    "        hist, bin_edges = np.histogram(\n",
    "            df_model[\"normalized_final_step\"], bins=bins, density=False\n",
    "        )\n",
    "        cumulative_population = np.cumsum(hist)\n",
    "        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
    "\n",
    "        # For each bin: runs whose final step lies in a later bin, as % of n_first.\n",
    "        sns.lineplot(\n",
    "            x=bin_centers[:-1],\n",
    "            y=(cumulative_population[-1] - cumulative_population[:-1]) / n_first * 100,\n",
    "            ax=ax_main,\n",
    "            color=method_color_mapping[k],\n",
    "        )\n",
    "\n",
    "    ax_main.spines[\"right\"].set_visible(False)\n",
    "    ax_temp = ax_main.twiny()\n",
    "\n",
    "    ax_main.set_xlim(0, 1)\n",
    "    ax_main.set_ylabel(\"valid runs (%)\")\n",
    "\n",
    "    # Top x-axis: time; normalized steps 0 and 1 are labelled 0 and 10 ps.\n",
    "    ax_main.set_xticks([0, 1])\n",
    "    ax_main.set_xticklabels([0, 10])\n",
    "    ax_main.set_xlabel(\"Time (ps)\")\n",
    "    ax_main.xaxis.set_label_position(\"top\")\n",
    "    ax_main.xaxis.tick_top()\n",
    "    ax_main.spines[\"top\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # Bottom x-axis: temperature labels at the start and end of the run.\n",
    "    ax_temp.set_xlim(ax_main.get_xlim())\n",
    "    ax_temp.set_xticks([0, 1])\n",
    "    ax_temp.set_xticklabels([\"300 K\", \"3000 K\"])\n",
    "    ax_temp.xaxis.set_ticks_position(\"bottom\")\n",
    "    ax_temp.xaxis.set_label_position(\"bottom\")\n",
    "    ax_temp.spines[\"right\"].set_visible(False)\n",
    "    ax_temp.spines[\"top\"].set_visible(False)\n",
    "    ax_temp.spines[\"bottom\"].set_position((\"outward\", 5))\n",
    "\n",
    "    # Make sure the left panel carries no legend.\n",
    "    ax_main.legend_ = None\n",
    "\n",
    "    plt.savefig(\"stability-and-speed-nvt-loglog.pdf\", bbox_inches=\"tight\")\n",
    "    plt.savefig(\"stability-and-speed-nvt-loglog.png\", bbox_inches=\"tight\", dpi=330)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "06905b5dd49e47fb9ca98d2e3a9babb8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_b3a1e313f7334fa78392cec0476b2a30",
       "style": "IPY_MODEL_9e078e2ba27e449e86ecc1fe59f681ec"
      }
     },
     "0ef76231108146649bcbdceba016aac5": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "51ec40d026074e34a1168f5240228ca8": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "7f2b420195284e4b972e6762dfb960eb": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "width": "20px"
      }
     },
     "9e078e2ba27e449e86ecc1fe59f681ec": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "b3a1e313f7334fa78392cec0476b2a30": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "ce30697246e6491baaa7b1fa21a20f8a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "cf29764478a34059a68c87b6c46e2972": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_7f2b420195284e4b972e6762dfb960eb",
       "max": 1,
       "style": "IPY_MODEL_0ef76231108146649bcbdceba016aac5",
       "value": 1
      }
     },
     "e02fd4d9b9d04c87887a3903274c794a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_ce30697246e6491baaa7b1fa21a20f8a",
       "style": "IPY_MODEL_f40f4df44b3f4b658fa4f7204624f9cf",
       "value": " 1764/? [00:01<00:00, 1759.99it/s]"
      }
     },
     "f088c4da133d406694657239bcefbbe0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_06905b5dd49e47fb9ca98d2e3a9babb8",
        "IPY_MODEL_cf29764478a34059a68c87b6c46e2972",
        "IPY_MODEL_e02fd4d9b9d04c87887a3903274c794a"
       ],
       "layout": "IPY_MODEL_51ec40d026074e34a1168f5240228ca8"
      }
     },
     "f40f4df44b3f4b658fa4f7204624f9cf": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}