{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns \n",
    "from matplotlib import pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CSV training logs of the SFT and DPO runs.\n",
    "sft_log_file, dpo_log_file = (\n",
    "    '../logs/sft_train_log_20231211-2250.csv',\n",
    "    '../logs/dpo_train_log_20231213-0214.csv',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>epoch</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>loss</th>\n",
       "      <th>step</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.400000e-08</td>\n",
       "      <td>2.5986</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.400000e-06</td>\n",
       "      <td>2.6353</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>2.800000e-06</td>\n",
       "      <td>2.4905</td>\n",
       "      <td>200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.01</td>\n",
       "      <td>4.200000e-06</td>\n",
       "      <td>2.3610</td>\n",
       "      <td>300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.01</td>\n",
       "      <td>5.600000e-06</td>\n",
       "      <td>2.2837</td>\n",
       "      <td>400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  epoch  learning_rate    loss  step\n",
       "0           0   0.00   1.400000e-08  2.5986     1\n",
       "1           1   0.00   1.400000e-06  2.6353   100\n",
       "2           2   0.01   2.800000e-06  2.4905   200\n",
       "3           3   0.01   4.200000e-06  2.3610   300\n",
       "4           4   0.01   5.600000e-06  2.2837   400"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The first CSV column is an unnamed row index saved alongside the log (it surfaced as\n",
    "# an 'Unnamed: 0' column); index_col=0 restores it as the DataFrame index instead.\n",
    "sft_df = pd.read_csv(sft_log_file, index_col=0)\n",
    "dpo_df = pd.read_csv(dpo_log_file, index_col=0)\n",
    "sft_df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning-rate schedule of the SFT run, as recorded in the training log.\n",
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(\n",
    "    x=\"step\",\n",
    "    y=\"learning_rate\",\n",
    "    data=sft_df,\n",
    "    ax=ax,\n",
    ")\n",
    "# Trailing ';' suppresses the Text repr so only the figure is shown.\n",
    "ax.set_title('sft learning_rate');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SFT training loss against the logged step; the figure is also saved to ../img/sft_loss.png.\n",
    "sns.lineplot(\n",
    "    data=sft_df,\n",
    "    x=\"step\",\n",
    "    y=\"loss\",\n",
    "    color='dodgerblue',\n",
    ").set_title('sft loss')\n",
    "plt.savefig('../img/sft_loss.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>epoch</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>logits/chosen</th>\n",
       "      <th>logits/rejected</th>\n",
       "      <th>logps/chosen</th>\n",
       "      <th>logps/rejected</th>\n",
       "      <th>loss</th>\n",
       "      <th>rewards/accuracies</th>\n",
       "      <th>rewards/chosen</th>\n",
       "      <th>rewards/margins</th>\n",
       "      <th>rewards/rejected</th>\n",
       "      <th>step</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>-3.525447</td>\n",
       "      <td>-3.550683</td>\n",
       "      <td>-256.702698</td>\n",
       "      <td>-143.308243</td>\n",
       "      <td>0.7689</td>\n",
       "      <td>0.437500</td>\n",
       "      <td>-0.044875</td>\n",
       "      <td>-0.072844</td>\n",
       "      <td>0.027969</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.01</td>\n",
       "      <td>2.000000e-07</td>\n",
       "      <td>-3.509013</td>\n",
       "      <td>-3.557282</td>\n",
       "      <td>-270.281708</td>\n",
       "      <td>-150.850433</td>\n",
       "      <td>0.7438</td>\n",
       "      <td>0.486842</td>\n",
       "      <td>0.002034</td>\n",
       "      <td>-0.020194</td>\n",
       "      <td>0.022228</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>4.000000e-07</td>\n",
       "      <td>-3.509622</td>\n",
       "      <td>-3.544898</td>\n",
       "      <td>-286.783966</td>\n",
       "      <td>-162.946915</td>\n",
       "      <td>0.7038</td>\n",
       "      <td>0.529688</td>\n",
       "      <td>0.024229</td>\n",
       "      <td>0.046643</td>\n",
       "      <td>-0.022414</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.02</td>\n",
       "      <td>6.000000e-07</td>\n",
       "      <td>-3.521220</td>\n",
       "      <td>-3.554179</td>\n",
       "      <td>-267.424896</td>\n",
       "      <td>-151.984573</td>\n",
       "      <td>0.7218</td>\n",
       "      <td>0.507812</td>\n",
       "      <td>0.004973</td>\n",
       "      <td>0.008775</td>\n",
       "      <td>-0.003803</td>\n",
       "      <td>60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.03</td>\n",
       "      <td>8.000000e-07</td>\n",
       "      <td>-3.513215</td>\n",
       "      <td>-3.551011</td>\n",
       "      <td>-281.538208</td>\n",
       "      <td>-157.784546</td>\n",
       "      <td>0.6995</td>\n",
       "      <td>0.548437</td>\n",
       "      <td>0.057179</td>\n",
       "      <td>0.069537</td>\n",
       "      <td>-0.012358</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  epoch  learning_rate  logits/chosen  logits/rejected  \\\n",
       "0           0   0.00   1.000000e-08      -3.525447        -3.550683   \n",
       "1           1   0.01   2.000000e-07      -3.509013        -3.557282   \n",
       "2           2   0.01   4.000000e-07      -3.509622        -3.544898   \n",
       "3           3   0.02   6.000000e-07      -3.521220        -3.554179   \n",
       "4           4   0.03   8.000000e-07      -3.513215        -3.551011   \n",
       "\n",
       "   logps/chosen  logps/rejected    loss  rewards/accuracies  rewards/chosen  \\\n",
       "0   -256.702698     -143.308243  0.7689            0.437500       -0.044875   \n",
       "1   -270.281708     -150.850433  0.7438            0.486842        0.002034   \n",
       "2   -286.783966     -162.946915  0.7038            0.529688        0.024229   \n",
       "3   -267.424896     -151.984573  0.7218            0.507812        0.004973   \n",
       "4   -281.538208     -157.784546  0.6995            0.548437        0.057179   \n",
       "\n",
       "   rewards/margins  rewards/rejected  step  \n",
       "0        -0.072844          0.027969     1  \n",
       "1        -0.020194          0.022228    20  \n",
       "2         0.046643         -0.022414    40  \n",
       "3         0.008775         -0.003803    60  \n",
       "4         0.069537         -0.012358    80  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First five rows of the DPO training log (loss plus logits/*, logps/* and rewards/* metrics).\n",
    "dpo_df.iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the checkpoint up to step 6000 was used; later steps show signs of overfitting.\n",
    "# Filter on the logged step rather than a row position: the log's steps are 1, 20, 40, ...\n",
    "# so the former positional slice dpo_df[0: 6000 // 20] ended at step 5980 and silently\n",
    "# depended on the logging interval.\n",
    "DPO_MAX_STEP = 6000\n",
    "dpo_used_df = dpo_df[dpo_df['step'] <= DPO_MAX_STEP]\n",
    "\n",
    "plt.title('dpo loss')\n",
    "sns.lineplot(\n",
    "    x=\"step\",\n",
    "    y=\"loss\",\n",
    "    color='orange',\n",
    "    data=dpo_used_df,\n",
    ")\n",
    "plt.savefig('../img/dpo_loss.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "\n",
    "# Make the parent of the notebook's working directory (assumed to be the project root)\n",
    "# importable so that `model.infer` and `config` resolve. The membership check keeps\n",
    "# re-runs of this cell from appending duplicate sys.path entries.\n",
    "root = os.path.dirname(os.path.realpath('.'))\n",
    "if root not in sys.path:\n",
    "    sys.path.append(root)\n",
    "\n",
    "from model.infer import ChatBot\n",
    "from config import InferConfig\n",
    "\n",
    "bot = ChatBot(InferConfig())\n",
    "model = bot.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model parameters size: 210.19 M = 0.21B\n",
      "GPU memory used: 0.40GB\n"
     ]
    }
   ],
   "source": [
    "# Total parameter count in millions (and billions), then the bytes reported by\n",
    "# torch.cuda.memory_allocated() divided by 1024**3.\n",
    "param_size = sum(p.numel() for p in model.parameters()) / 1000 / 1000\n",
    "print('model parameters size: {:.2f} M = {:.2f}B'.format(param_size, param_size / 1000))\n",
    "\n",
    "print('GPU memory used: {:.2f}GB'.format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}