{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'c:\\\\mlops projects\\\\text-summarization'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "@dataclass(frozen=True)\n",
    "class ModelTrainerConfig:\n",
    "    root_dir : Path\n",
    "    data_path : Path\n",
    "    model_ckpt  : Path\n",
    "    num_train_epochs : int\n",
    "    warmup_steps : int\n",
    "    per_device_train_batch_size : int\n",
    "    weight_decay : float\n",
    "    logging_steps : int\n",
    "    evaluation_strategy: str\n",
    "    eval_steps: int\n",
    "    save_steps: float\n",
    "    gradient_accumulation_steps: int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textsummarizer.constants import *\n",
    "from textsummarizer.utils.common import read_yaml,create_directories\n",
    "\n",
    "class ConfigurationManager:\n",
    "    \"\"\"Loads the project config and params files and builds stage configs from them.\"\"\"\n",
    "\n",
    "    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):\n",
    "        \"\"\"Read both files with ``read_yaml`` and hand the artifacts root to ``create_directories``.\n",
    "\n",
    "        Args:\n",
    "            config_filepath: Location of the config file; defaults to ``CONFIG_FILE_PATH``.\n",
    "            params_filepath: Location of the params file; defaults to ``PARAMS_FILE_PATH``.\n",
    "        \"\"\"\n",
    "        self.config = read_yaml(config_filepath)\n",
    "        self.params = read_yaml(params_filepath)\n",
    "        create_directories([self.config.artifacts_root])\n",
    "\n",
    "    def get_model_trainer_config(self) -> ModelTrainerConfig:\n",
    "        \"\"\"Assemble the ``ModelTrainerConfig`` for the training stage.\n",
    "\n",
    "        Paths and the checkpoint come from the ``model_trainer`` section of the\n",
    "        config file; the remaining values come from the ``TrainingArguments``\n",
    "        section of the params file. The stage's ``root_dir`` is passed to\n",
    "        ``create_directories`` before the config object is built.\n",
    "\n",
    "        Returns:\n",
    "            ModelTrainerConfig: The populated, frozen settings object.\n",
    "        \"\"\"\n",
    "        trainer_section = self.config.model_trainer\n",
    "        training_params = self.params.TrainingArguments\n",
    "\n",
    "        create_directories([trainer_section.root_dir])\n",
    "\n",
    "        return ModelTrainerConfig(\n",
    "            root_dir=trainer_section.root_dir,\n",
    "            data_path=trainer_section.data_path,\n",
    "            model_ckpt=trainer_section.model_ckpt,\n",
    "            num_train_epochs=training_params.num_train_epochs,\n",
    "            warmup_steps=training_params.warmup_steps,\n",
    "            per_device_train_batch_size=training_params.per_device_train_batch_size,\n",
    "            weight_decay=training_params.weight_decay,\n",
    "            logging_steps=training_params.logging_steps,\n",
    "            evaluation_strategy=training_params.evaluation_strategy,\n",
    "            eval_steps=training_params.eval_steps,\n",
    "            save_steps=training_params.save_steps,\n",
    "            gradient_accumulation_steps=training_params.gradient_accumulation_steps,\n",
    "        )\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "from transformers import DataCollatorForSeq2Seq\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from datasets import load_dataset, load_from_disk\n",
    "import torch\n",
    "import  os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelTrainer:\n",
    "    \"\"\"Fine-tunes the configured seq2seq checkpoint and saves the model and tokenizer.\"\"\"\n",
    "\n",
    "    def __init__(self, config: ModelTrainerConfig):\n",
    "        \"\"\"Keep the config and set ``WANDB_DISABLED`` in the process environment.\n",
    "\n",
    "        Args:\n",
    "            config: Settings for this training run.\n",
    "        \"\"\"\n",
    "        self.config = config\n",
    "        # Process-wide side effect: the variable stays set after this object is gone.\n",
    "        os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Run fine-tuning and write the resulting model and tokenizer under ``root_dir``.\n",
    "\n",
    "        Steps:\n",
    "            1. Load tokenizer and model from ``config.model_ckpt``; move the model\n",
    "               to CUDA when available, otherwise CPU.\n",
    "            2. Load the dataset from ``config.data_path`` with ``load_from_disk``.\n",
    "            3. Train on the ``train`` split, evaluating on ``validation``.\n",
    "            4. Save the model to ``<root_dir>/pegasus-samsum-model`` and the\n",
    "               tokenizer to ``<root_dir>/tokenizer``.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "        tokenizer = AutoTokenizer.from_pretrained(self.config.model_ckpt)\n",
    "        model_pegasus = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_ckpt).to(device)\n",
    "        seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model_pegasus)\n",
    "\n",
    "        # Loading data\n",
    "        dataset_samsum_pt = load_from_disk(self.config.data_path)\n",
    "\n",
    "        trainer_args = TrainingArguments(\n",
    "            output_dir=self.config.root_dir,\n",
    "            num_train_epochs=self.config.num_train_epochs,\n",
    "            warmup_steps=self.config.warmup_steps,\n",
    "            per_device_train_batch_size=self.config.per_device_train_batch_size,\n",
    "            # The config has no separate eval batch size, so the train value is reused.\n",
    "            per_device_eval_batch_size=self.config.per_device_train_batch_size,\n",
    "            weight_decay=self.config.weight_decay,\n",
    "            logging_steps=self.config.logging_steps,\n",
    "            evaluation_strategy=self.config.evaluation_strategy,\n",
    "            eval_steps=self.config.eval_steps,\n",
    "            # Honour the configured value instead of a hard-coded 1e6. float() matches the\n",
    "            # declared field type and also covers a loader that delivers e.g. '1e6' as text.\n",
    "            save_steps=float(self.config.save_steps),\n",
    "            gradient_accumulation_steps=self.config.gradient_accumulation_steps,\n",
    "            report_to=\"none\",\n",
    "        )\n",
    "\n",
    "        # Train on the training split. Fitting on \"test\" would leak the held-out\n",
    "        # data into the model and make any later test-set score meaningless.\n",
    "        # NOTE(review): assumes the saved dataset contains a \"train\" split - confirm.\n",
    "        trainer = Trainer(\n",
    "            model=model_pegasus,\n",
    "            args=trainer_args,\n",
    "            tokenizer=tokenizer,\n",
    "            data_collator=seq2seq_data_collator,\n",
    "            train_dataset=dataset_samsum_pt[\"train\"],\n",
    "            eval_dataset=dataset_samsum_pt[\"validation\"],\n",
    "        )\n",
    "\n",
    "        trainer.train()\n",
    "\n",
    "        # Save model\n",
    "        model_pegasus.save_pretrained(os.path.join(self.config.root_dir, \"pegasus-samsum-model\"))\n",
    "        # Save tokenizer\n",
    "        tokenizer.save_pretrained(os.path.join(self.config.root_dir, \"tokenizer\"))\n",
    "        \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the stage config, then run training. Exceptions propagate unchanged;\n",
    "# the previous try/except only re-raised them.\n",
    "config_manager = ConfigurationManager()\n",
    "model_trainer_config = config_manager.get_model_trainer_config()\n",
    "\n",
    "# Separate name for the trainer so `model_trainer_config` keeps holding the config.\n",
    "model_trainer = ModelTrainer(config=model_trainer_config)\n",
    "model_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}