{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SFT / DPO training-log analysis\n",
    "\n",
    "Plots the learning-rate and loss curves recorded in the SFT and DPO training logs, ",
    "then reports the parameter count and allocated GPU memory of the model loaded by `ChatBot`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training logs to analyse (paths are relative to the working directory).\n",
    "sft_log_file = '../logs/sft_train_log_20231211-2250.csv'\n",
    "dpo_log_file = '../logs/dpo_train_log_20231213-0214.csv'\n",
    "\n",
    "# Directory the loss figures are written to; created up front so that\n",
    "# savefig does not fail when the directory is missing.\n",
    "IMG_DIR = '../img'\n",
    "os.makedirs(IMG_DIR, exist_ok=True)\n",
    "\n",
    "# Only the DPO checkpoint up to this step was used; later steps showed signs\n",
    "# of overfitting, so the DPO loss curve is truncated here.\n",
    "DPO_LAST_STEP = 6000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>epoch</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>loss</th>\n",
       "      <th>step</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.400000e-08</td>\n",
       "      <td>2.5986</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.400000e-06</td>\n",
       "      <td>2.6353</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>2.800000e-06</td>\n",
       "      <td>2.4905</td>\n",
       "      <td>200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.01</td>\n",
       "      <td>4.200000e-06</td>\n",
       "      <td>2.3610</td>\n",
       "      <td>300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.01</td>\n",
       "      <td>5.600000e-06</td>\n",
       "      <td>2.2837</td>\n",
       "      <td>400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  epoch  learning_rate    loss  step\n",
       "0           0   0.00   1.400000e-08  2.5986     1\n",
       "1           1   0.00   1.400000e-06  2.6353   100\n",
       "2           2   0.01   2.800000e-06  2.4905   200\n",
       "3           3   0.01   4.200000e-06  2.3610   300\n",
       "4           4   0.01   5.600000e-06  2.2837   400"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sft_df = pd.read_csv(sft_log_file)\n",
    "dpo_df = pd.read_csv(dpo_log_file)\n",
    "sft_df.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SFT: learning rate and loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit figure/axes so the title and the plot are guaranteed to share one figure.\n",
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(\n",
    "    x=\"step\",\n",
    "    y=\"learning_rate\",\n",
    "    data=sft_df,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title('sft learning_rate')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(\n",
    "    x=\"step\",\n",
    "    y=\"loss\",\n",
    "    color='dodgerblue',\n",
    "    data=sft_df,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title('sft loss')\n",
    "fig.savefig(f'{IMG_DIR}/sft_loss.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DPO: loss\n",
    "\n",
    "The curve is cut at `DPO_LAST_STEP`: only the checkpoint up to that step was used, ",
    "because later steps showed signs of overfitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>epoch</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>logits/chosen</th>\n",
       "      <th>logits/rejected</th>\n",
       "      <th>logps/chosen</th>\n",
       "      <th>logps/rejected</th>\n",
       "      <th>loss</th>\n",
       "      <th>rewards/accuracies</th>\n",
       "      <th>rewards/chosen</th>\n",
       "      <th>rewards/margins</th>\n",
       "      <th>rewards/rejected</th>\n",
       "      <th>step</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>-3.525447</td>\n",
       "      <td>-3.550683</td>\n",
       "      <td>-256.702698</td>\n",
       "      <td>-143.308243</td>\n",
       "      <td>0.7689</td>\n",
       "      <td>0.437500</td>\n",
       "      <td>-0.044875</td>\n",
       "      <td>-0.072844</td>\n",
       "      <td>0.027969</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.01</td>\n",
       "      <td>2.000000e-07</td>\n",
       "      <td>-3.509013</td>\n",
       "      <td>-3.557282</td>\n",
       "      <td>-270.281708</td>\n",
       "      <td>-150.850433</td>\n",
       "      <td>0.7438</td>\n",
       "      <td>0.486842</td>\n",
       "      <td>0.002034</td>\n",
       "      <td>-0.020194</td>\n",
       "      <td>0.022228</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.01</td>\n",
       "      <td>4.000000e-07</td>\n",
       "      <td>-3.509622</td>\n",
       "      <td>-3.544898</td>\n",
       "      <td>-286.783966</td>\n",
       "      <td>-162.946915</td>\n",
       "      <td>0.7038</td>\n",
       "      <td>0.529688</td>\n",
       "      <td>0.024229</td>\n",
       "      <td>0.046643</td>\n",
       "      <td>-0.022414</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.02</td>\n",
       "      <td>6.000000e-07</td>\n",
       "      <td>-3.521220</td>\n",
       "      <td>-3.554179</td>\n",
       "      <td>-267.424896</td>\n",
       "      <td>-151.984573</td>\n",
       "      <td>0.7218</td>\n",
       "      <td>0.507812</td>\n",
       "      <td>0.004973</td>\n",
       "      <td>0.008775</td>\n",
       "      <td>-0.003803</td>\n",
       "      <td>60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.03</td>\n",
       "      <td>8.000000e-07</td>\n",
       "      <td>-3.513215</td>\n",
       "      <td>-3.551011</td>\n",
       "      <td>-281.538208</td>\n",
       "      <td>-157.784546</td>\n",
       "      <td>0.6995</td>\n",
       "      <td>0.548437</td>\n",
       "      <td>0.057179</td>\n",
       "      <td>0.069537</td>\n",
       "      <td>-0.012358</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  epoch  learning_rate  logits/chosen  logits/rejected  \\\n",
       "0           0   0.00   1.000000e-08      -3.525447        -3.550683   \n",
       "1           1   0.01   2.000000e-07      -3.509013        -3.557282   \n",
       "2           2   0.01   4.000000e-07      -3.509622        -3.544898   \n",
       "3           3   0.02   6.000000e-07      -3.521220        -3.554179   \n",
       "4           4   0.03   8.000000e-07      -3.513215        -3.551011   \n",
       "\n",
       "   logps/chosen  logps/rejected    loss  rewards/accuracies  rewards/chosen  \\\n",
       "0   -256.702698     -143.308243  0.7689            0.437500       -0.044875   \n",
       "1   -270.281708     -150.850433  0.7438            0.486842        0.002034   \n",
       "2   -286.783966     -162.946915  0.7038            0.529688        0.024229   \n",
       "3   -267.424896     -151.984573  0.7218            0.507812        0.004973   \n",
       "4   -281.538208     -157.784546  0.6995            0.548437        0.057179   \n",
       "\n",
       "   rewards/margins  rewards/rejected  step  \n",
       "0        -0.072844          0.027969     1  \n",
       "1        -0.020194          0.022228    20  \n",
       "2         0.046643         -0.022414    40  \n",
       "3         0.008775         -0.003803    60  \n",
       "4         0.069537         -0.012358    80  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dpo_df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select rows by step value instead of by row position. In the rows shown above\n",
    "# the log starts at step 1 and then records steps 20, 40, ..., so a positional\n",
    "# slice of 6000 // 20 rows would stop one logged point short of DPO_LAST_STEP.\n",
    "dpo_used = dpo_df[dpo_df['step'] <= DPO_LAST_STEP]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(\n",
    "    x=\"step\",\n",
    "    y=\"loss\",\n",
    "    color='orange',\n",
    "    data=dpo_used,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title('dpo loss')\n",
    "fig.savefig(f'{IMG_DIR}/dpo_loss.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The project-local imports are placed here rather than in the first cell because\n",
    "# they only resolve once the parent of the working directory is on sys.path.\n",
    "root = os.path.dirname(os.path.realpath('.'))\n",
    "if root not in sys.path:  # re-running this cell must not append duplicates\n",
    "    sys.path.append(root)\n",
    "\n",
    "from model.infer import ChatBot\n",
    "from config import InferConfig\n",
    "\n",
    "bot = ChatBot(InferConfig())\n",
    "model = bot.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model parameters size: 210.19 M = 0.21B\n",
      "GPU memory used: 0.40GB\n"
     ]
    }
   ],
   "source": [
    "# Total number of parameter elements, expressed in millions (and billions below).\n",
    "param_size = sum(p.numel() for p in model.parameters()) / 1000 / 1000\n",
    "print('model parameters size: {:.2f} M = {:.2f}B'.format( param_size , param_size / 1000))\n",
    "\n",
    "# Bytes reported by torch.cuda.memory_allocated(), converted to GiB.\n",
    "print('GPU memory used: {:.2f}GB'.format(torch.cuda.memory_allocated() / (1024 ** 3)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}