{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [], "metadata": { "id": "I9Z5guQ6CDt8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "PGeicEbzCDw9" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "pip install sentencepiece torch torchvision torchaudio pandas scikit-learn\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zQFdKxIICD0H", "outputId": "5d35d6a1-a876-4c7f-fee8-4f04888f3854" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (0.2.0)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.20.1+cu124)\n", "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (2.2.2)\n", "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.17.0)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.5)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2024.10.0)\n", 
"Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n", " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n", " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n", " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n", " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n", " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n", " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n", " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n", " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n", " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n", " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.0)\n", "Requirement already 
satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.1.0)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas) (2025.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas) (2025.1)\n", "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.13.1)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m110.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 
"\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m89.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m58.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", 
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m85.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", " Attempting uninstall: nvidia-nvjitlink-cu12\n", " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", " Attempting uninstall: nvidia-curand-cu12\n", " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", " Attempting uninstall: nvidia-cufft-cu12\n", " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", " Attempting uninstall: nvidia-cuda-runtime-cu12\n", " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", " Attempting uninstall: nvidia-cuda-cupti-cu12\n", " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", " Attempting uninstall: nvidia-cublas-cu12\n", " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", " Successfully uninstalled 
nvidia-cublas-cu12-12.5.3.2\n", " Attempting uninstall: nvidia-cusparse-cu12\n", " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", " Attempting uninstall: nvidia-cudnn-cu12\n", " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", " Attempting uninstall: nvidia-cusolver-cu12\n", " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n" ] } ] }, { "cell_type": "code", "source": [ "\n", "\n", "!pip install sentencepiece --quiet" ], "metadata": { "id": "6DbIAMlqNDRK" }, "execution_count": 2, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "\"\"\"\n", "Model File for Roman Urdu Poetry Generation\n", "\n", "This file contains the complete code for:\n", " - Data loading, cleaning, and tokenization using SentencePiece\n", " - Train/Test/Validation split creation\n", " - Dataset and DataLoader creation\n", " - Definition of a BiLSTM Language Model (with 3 layers, dropout, etc.)\n", " - Training, validation, and testing routines\n", " - Saving the trained model weights\n", " - A poetry generation function using nucleus (top-p) sampling with formatted output\n", "\n", "Run this file to train and test the model. 
The trained weights will be saved to a file and loaded on subsequent runs.\n", "\"\"\"" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 157 }, "id": "DjB6rAwz-D3Q", "outputId": "817edbf7-6063-4c8c-fb49-30b18dd386b5" }, "execution_count": 3, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'\\nModel File for Roman Urdu Poetry Generation\\n\\nThis file contains the complete code for:\\n - Data loading, cleaning, and tokenization using SentencePiece\\n - Train/Test/Validation split creation\\n - Dataset and DataLoader creation\\n - Definition of a BiLSTM Language Model (with 3 layers, dropout, etc.)\\n - Training, validation, and testing routines\\n - Saving the trained model weights\\n - A poetry generation function using nucleus (top-p) sampling with formatted output\\n\\nRun this file to train and test the model. The trained weights will be saved to a file and loaded on subsequent runs.\\n'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 3 } ] }, { "cell_type": "code", "source": [ "# -------------------------\n", "# 1. Import Libraries\n", "# -------------------------\n", "import os\n", "import random\n", "import numpy as np\n", "import pandas as pd\n", "import sentencepiece as spm\n", "import re\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import Dataset, DataLoader\n", "import torch.nn.functional as F\n", "import unicodedata\n", "from sklearn.model_selection import train_test_split" ], "metadata": { "id": "HoqaPLEq-Ega" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "# -------------------------\n", "# 2. 
Set Random Seeds and Device\n",
    "# -------------------------\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Using device:\", device)"
   ],
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "u4Xf1Ck6-H-M",
    "outputId": "f171c4a8-4e30-4873-ebf9-2782aa3e9bdc"
   },
   "execution_count": 5,
   "outputs": [
    { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n" ] }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 3. Load and Clean Dataset\n",
    "# -------------------------\n",
    "DATA_PATH = \"Roman-Urdu-Poetry.csv\"  # Make sure this file exists in your working directory\n",
    "df = pd.read_csv(DATA_PATH)\n",
    "\n",
    "def remove_diacritics(text: str) -> str:\n",
    "    \"\"\"\n",
    "    Removes Unicode combining marks (diacritics) from the text.\n",
    "\n",
    "    The text is NFD-normalized so accented letters split into a base character\n",
    "    plus combining marks, and the combining marks are then dropped.\n",
    "    \"\"\"\n",
    "    return ''.join(ch for ch in unicodedata.normalize('NFD', text)\n",
    "                   if not unicodedata.combining(ch))\n",
    "\n",
    "def clean_text(text: str) -> str:\n",
    "    \"\"\"\n",
    "    Cleans the input text by removing diacritics, unwanted punctuation, and extra spaces.\n",
    "\n",
    "    Unwanted symbols are stripped *before* whitespace is collapsed, so a symbol\n",
    "    that sat between two spaces (e.g. \"dil * jaan\") does not leave a double\n",
    "    space behind. All whitespace, including line breaks, becomes a single space.\n",
    "    \"\"\"\n",
    "    text = remove_diacritics(text)\n",
    "    text = re.sub(r\"[^\\w\\s\\.\\,\\;\\:\\'\\?\\!\\-]+\", \"\", text)\n",
    "    text = re.sub(r\"\\s+\", \" \", text)\n",
    "    return text.strip()\n",
    "\n",
    "# Drop missing poems first: astype(str) would otherwise turn NaN into the literal\n",
    "# text \"nan\". Poems that are empty after cleaning are discarded as well.\n",
    "cleaned_poetry = df[\"Poetry\"].dropna().astype(str).map(clean_text)\n",
    "texts = [poem for poem in cleaned_poetry if poem]\n",
    "assert texts, f\"No usable poems found in {DATA_PATH}\"\n",
    "print(f\"Total number of poetry lines: {len(texts)}\")"
   ],
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "MYJTunkz-LDb",
    "outputId": "82609d66-3e91-4795-eac5-251bf9bf8dd1"
   },
   "execution_count": 6,
   "outputs": [
    { "output_type": "stream", "name": "stdout", "text": [ "Total number of poetry lines: 1314\n" ] }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 4. 
Train/Test/Validation Split (80/10/10)\n",
    "# -------------------------\n",
    "# Hold out 10% for test, then 0.1111 (~1/9) of the remaining 90% for validation,\n",
    "# which is ~10% of the full corpus -> an 80/10/10 split.\n",
    "train_val_texts, test_texts = train_test_split(texts, test_size=0.1, random_state=SEED)\n",
    "train_texts, val_texts = train_test_split(train_val_texts, test_size=0.1111, random_state=SEED)\n",
    "\n",
    "for split_name, split_texts in ((\"Train\", train_texts), (\"Validation\", val_texts), (\"Test\", test_texts)):\n",
    "    print(f\"{split_name} samples: {len(split_texts)}\")"
   ],
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "_VvgUa3L-MAR",
    "outputId": "d045fd71-3f09-4d6c-eea9-34c3e444db59"
   },
   "execution_count": 7,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [ "Train samples: 1050\n", "Validation samples: 132\n", "Test samples: 132\n" ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 5. Train a SentencePiece BPE Tokenizer\n",
    "# -------------------------\n",
    "# NOTE(review): the tokenizer corpus is built from *all* poems, including the\n",
    "# validation and test splits, so their vocabulary leaks into the tokenizer.\n",
    "# Both caches below are keyed only on file existence: delete all_texts.txt and\n",
    "# urdu_sp.* after changing the data or vocab_size, or stale files are reused.\n",
    "all_texts_file = \"all_texts.txt\"\n",
    "if os.path.exists(all_texts_file):\n",
    "    print(f\"{all_texts_file} already exists; skipping file creation.\")\n",
    "else:\n",
    "    with open(all_texts_file, \"w\", encoding=\"utf-8\") as corpus:\n",
    "        corpus.writelines(poem.strip() + \"\\n\" for poem in texts)  # one poem per line\n",
    "\n",
    "sp_model_prefix = \"urdu_sp\"\n",
    "model_file = f\"{sp_model_prefix}.model\"\n",
    "vocab_file = f\"{sp_model_prefix}.vocab\"\n",
    "\n",
    "vocab_size = 12000  # Adjust as needed\n",
    "model_type = \"bpe\"\n",
    "\n",
    "if all(os.path.exists(path) for path in (model_file, vocab_file)):\n",
    "    print(\"SentencePiece model & vocab found; skipping training.\")\n",
    "else:\n",
    "    print(\"SentencePiece model or vocab not found. Training...\")\n",
    "    # Special ids are pinned so the rest of the notebook can rely on\n",
    "    # pad=0, unk=1, bos=2, eos=3.\n",
    "    trainer_flags = [\n",
    "        f\"--input={all_texts_file}\",\n",
    "        f\"--model_prefix={sp_model_prefix}\",\n",
    "        f\"--vocab_size={vocab_size}\",\n",
    "        f\"--model_type={model_type}\",\n",
    "        \"--character_coverage=1.0\",\n",
    "        \"--pad_id=0 --unk_id=1 --bos_id=2 --eos_id=3\",\n",
    "    ]\n",
    "    spm.SentencePieceTrainer.Train(\" \".join(trainer_flags))\n",
    "\n",
    "# Load the SentencePiece model\n",
    "sp = spm.SentencePieceProcessor()\n",
    "sp.load(model_file)\n",
    "print(\"Loaded SentencePiece model with vocab size:\", sp.get_piece_size())\n"
   ],
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "2L1JgC02-OBW",
    "outputId": "d6ea06cf-8f54-47d8-fada-a016ca1df4c9"
   },
   "execution_count": 8,
   "outputs": [
    { "output_type": "stream", "name": "stdout", "text": [ "Loaded SentencePiece model with vocab size: 12000\n" ] }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 6. Tokenize Data\n",
    "# -------------------------\n",
    "train_ids, val_ids, test_ids = (\n",
    "    [sp.encode_as_ids(poem) for poem in split]\n",
    "    for split in (train_texts, val_texts, test_texts)\n",
    ")"
   ],
   "metadata": { "id": "lq7lbUcu-RDU" },
   "execution_count": 9,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 7. 
Create Dataset and DataLoader\n",
    "# -------------------------\n",
    "# Special-token ids pinned when the SentencePiece model was trained.\n",
    "PAD_ID, BOS_ID, EOS_ID = 0, 2, 3\n",
    "\n",
    "class PoetryDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Next-token-prediction pairs built from tokenized poems.\n",
    "\n",
    "    Item ``idx`` is ``(input_ids, target_ids)``: ``input_ids`` is BOS followed by\n",
    "    the poem and ``target_ids`` is the same sequence shifted left by one, so\n",
    "    ``target_ids[t]`` is the token that follows ``input_ids[t]``. Both are 1-D\n",
    "    LongTensors of equal length, at most ``max_length + 1``.\n",
    "\n",
    "    EOS is appended only when the whole poem fits. A poem longer than\n",
    "    ``max_length`` tokens is cut and its last target is the real next token,\n",
    "    so the model is not taught that poems end at an arbitrary cut point.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, token_ids_list, max_length=250):\n",
    "        self.data = token_ids_list\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # BOS + poem + EOS, clipped so that input/target are at most max_length + 1 long.\n",
    "        sequence = ([BOS_ID] + list(self.data[idx]) + [EOS_ID])[: self.max_length + 2]\n",
    "        input_ids = sequence[:-1]\n",
    "        target_ids = sequence[1:]\n",
    "        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Right-pad a batch of (input_ids, target_ids) pairs with PAD_ID to the longest item.\"\"\"\n",
    "    inputs, targets = zip(*batch)\n",
    "    padded_inputs = nn.utils.rnn.pad_sequence(list(inputs), batch_first=True, padding_value=PAD_ID)\n",
    "    padded_targets = nn.utils.rnn.pad_sequence(list(targets), batch_first=True, padding_value=PAD_ID)\n",
    "    return padded_inputs, padded_targets"
   ],
   "metadata": { "id": "OZ9_kG0M-TOF" },
   "execution_count": 10,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "train_dataset = PoetryDataset(train_ids, max_length=250)\n",
    "val_dataset = PoetryDataset(val_ids, max_length=250)\n",
    "test_dataset = PoetryDataset(test_ids, max_length=250)\n",
    "\n",
    "batch_size = 64\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)\n",
    "# Evaluation must see every example: drop_last=True would silently skip the\n",
    "# final partial batch (4 of the 132 validation / test poems at batch_size=64).\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)"
   ],
   "metadata": { "id": "z1aGUj-w-Xh9" },
   "execution_count": 11,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 8. 
Define the BiLSTM Language Model\n",
    "# -------------------------\n",
    "class BiLSTMLanguageModel(nn.Module):\n",
    "    \"\"\"\n",
    "    Embedding -> stacked LSTM -> linear projection to vocabulary logits.\n",
    "\n",
    "    ``forward(x, hidden)`` takes token ids ``x`` of shape (batch, seq_len)\n",
    "    (``batch_first=True``) and returns ``(logits, hidden)``: logits of shape\n",
    "    (batch, seq_len, vocab_size) and the LSTM's ``(h_n, c_n)`` state, which can\n",
    "    be passed back in to continue a sequence.\n",
    "\n",
    "    WARNING: with ``bidirectional=True`` (the default, kept for compatibility\n",
    "    with existing weights) the backward direction's output at position t\n",
    "    summarizes input tokens t..end, which includes input t+1 -- exactly the\n",
    "    next-token target at position t. Training/validation loss is therefore\n",
    "    computed with the answer visible, and token-by-token generation cannot\n",
    "    supply that future context. Pass ``bidirectional=False`` for a proper\n",
    "    left-to-right (causal) language model.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embed_dim=512, hidden_dim=768, num_layers=3, dropout=0.2,\n",
    "                 bidirectional=True):\n",
    "        super().__init__()\n",
    "        # Id 0 is the padding token; its embedding starts at zero and receives no gradient.\n",
    "        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)\n",
    "        # Stacked (Bi-)LSTM layers. nn.LSTM applies dropout only between layers and\n",
    "        # warns when dropout > 0 with a single layer, so disable it in that case.\n",
    "        self.lstm = nn.LSTM(\n",
    "            input_size=embed_dim,\n",
    "            hidden_size=hidden_dim,\n",
    "            num_layers=num_layers,\n",
    "            batch_first=True,\n",
    "            bidirectional=bidirectional,\n",
    "            dropout=dropout if num_layers > 1 else 0.0,\n",
    "        )\n",
    "        # Per-step LSTM output concatenates both directions when bidirectional.\n",
    "        num_directions = 2 if bidirectional else 1\n",
    "        self.fc = nn.Linear(hidden_dim * num_directions, vocab_size)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        emb = self.embed(x)\n",
    "        out, hidden = self.lstm(emb, hidden)\n",
    "        logits = self.fc(out)\n",
    "        return logits, hidden"
   ],
   "metadata": { "id": "YD8F_0WM-apV" },
   "execution_count": 12,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "vocab_size = sp.get_piece_size()\n",
    "model = BiLSTMLanguageModel(vocab_size, embed_dim=512, hidden_dim=768, num_layers=3, dropout=0.2)\n",
    "model = model.to(device)"
   ],
   "metadata": { "id": "aKWTogmN-gaq" },
   "execution_count": 13,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 9. 
Training Setup (Loss, Optimizer, Scheduler)\n",
    "# -------------------------\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=0)  # id 0 = padding, excluded from the loss\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)  # halve the LR every 2 epochs\n",
    "\n",
    "def evaluate(model, data_loader):\n",
    "    \"\"\"\n",
    "    Return the average per-token cross-entropy of ``model`` over ``data_loader``.\n",
    "\n",
    "    ``criterion`` ignores padding (id 0) and averages over the remaining tokens,\n",
    "    so each batch loss is re-weighted by its non-padding token count before the\n",
    "    final division; batches with more padding do not get extra weight.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the loader yields no non-padding target tokens\n",
    "            (e.g. an empty split), instead of an opaque ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    total_loss, total_tokens = 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in data_loader:\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits, _ = model(inputs)\n",
    "            # Flatten to (batch * seq_len, vocab) and (batch * seq_len,) for the loss.\n",
    "            logits = logits.reshape(-1, logits.size(-1))\n",
    "            targets = targets.reshape(-1)\n",
    "            n_tokens = (targets != 0).sum().item()\n",
    "            if n_tokens == 0:\n",
    "                continue  # all padding: the mean loss would be NaN and poison the total\n",
    "            loss = criterion(logits, targets)\n",
    "            total_loss += loss.item() * n_tokens\n",
    "            total_tokens += n_tokens\n",
    "    if total_tokens == 0:\n",
    "        raise ValueError(\"evaluate() got no non-padding target tokens\")\n",
    "    return total_loss / total_tokens"
   ],
   "metadata": { "id": "9W5USllq-i83" },
   "execution_count": 14,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 10. Training Loop with Testing Code and Weight Saving\n",
    "# -------------------------\n",
    "num_epochs = 10\n",
    "weights_path = \"model_weights.pth\"\n",
    "\n",
    "if not os.path.exists(weights_path):\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        total_loss, total_tokens = 0.0, 0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            logits, _ = model(inputs)\n",
    "            logits = logits.reshape(-1, logits.size(-1))\n",
    "            targets = targets.reshape(-1)\n",
    "            loss = criterion(logits, targets)\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
    "            optimizer.step()\n",
    "            # criterion averages over non-padding tokens; re-weight by their count\n",
    "            # so the epoch loss is a true per-token average.\n",
    "            n_tokens = (targets != 0).sum().item()\n",
    "            total_loss += loss.item() * n_tokens\n",
    "            total_tokens += n_tokens\n",
    "        train_loss = total_loss / total_tokens\n",
    "        val_loss = evaluate(model, val_loader)\n",
    "        scheduler.step()\n",
    "        print(f\"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}\")\n",
    "    test_loss = evaluate(model, test_loader)\n",
    "    print(f\"Test Loss: {test_loss:.4f}\")\n",
    "    torch.save(model.state_dict(), weights_path)\n",
    "else:\n",
    "    print(\"Loading pre-trained model weights...\")\n",
    "    # weights_only=True restricts unpickling to tensors and plain containers,\n",
    "    # so a tampered checkpoint file cannot execute arbitrary code.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))"
   ],
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "B0nDauKT-nQC",
    "outputId": "c082b8a8-70fb-4375-8b89-6deb72b31f6f"
   },
   "execution_count": 15,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Epoch [1/10], Train Loss: 7.1034, Val Loss: 6.2269\n",
      "Epoch [2/10], Train Loss: 5.7528, Val Loss: 5.4652\n",
      "Epoch [3/10], Train Loss: 5.0948, Val Loss: 4.9459\n",
      "Epoch [4/10], Train Loss: 4.4997, Val Loss: 4.2981\n",
      "Epoch [5/10], Train Loss: 3.9654, Val Loss: 3.9398\n",
      "Epoch [6/10], Train Loss: 3.6264, Val Loss: 3.6214\n",
      "Epoch [7/10], Train Loss: 3.3671, Val Loss: 3.4665\n",
      "Epoch [8/10], Train Loss: 3.2082, Val Loss: 3.3188\n",
      "Epoch [9/10], Train Loss: 3.0880, Val Loss: 3.2478\n",
      "Epoch [10/10], Train Loss: 3.0126, Val Loss: 3.1772\n",
      "Test Loss: 3.1696\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# -------------------------\n",
    "# 11. Poetry Generation (Nucleus / Top-p Sampling)\n",
    "# -------------------------\n",
    "def generate_poetry_nucleus(model, sp, start_word, num_words=12, temperature=1.2, top_p=0.85,\n",
    "                            words_per_line=6):\n",
    "    \"\"\"\n",
    "    Generate a poetry sequence using nucleus (top-p) sampling.\n",
    "\n",
    "    ``start_word`` is encoded with SentencePiece and prefixed with BOS (id 2);\n",
    "    the model then samples up to ``num_words - 1`` further tokens. Tokens are\n",
    "    BPE pieces, not whole words, so one word may span several tokens.\n",
    "    Sampling stops early if EOS (id 3) is drawn.\n",
    "\n",
    "    At every step the last-position logits are divided by ``temperature`` and\n",
    "    only the smallest set of most-likely tokens whose cumulative probability\n",
    "    reaches ``top_p`` is kept (always at least the most likely token); the\n",
    "    next token is sampled from that renormalized set.\n",
    "\n",
    "    Args:\n",
    "        model: language model whose forward returns ``(logits, hidden)``.\n",
    "        sp: loaded SentencePieceProcessor used for encoding and decoding.\n",
    "        start_word: prompt text that starts the poem.\n",
    "        num_words: 1 starting word + (num_words - 1) generated tokens.\n",
    "        temperature: softmax temperature; must be > 0.\n",
    "        top_p: probability mass kept by nucleus filtering.\n",
    "        words_per_line: whitespace-separated words per output line.\n",
    "\n",
    "    Returns:\n",
    "        The decoded text (without BOS), ``words_per_line`` words per line.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``temperature`` is not positive.\n",
    "    \"\"\"\n",
    "    if temperature <= 0:\n",
    "        raise ValueError(\"temperature must be > 0\")\n",
    "    eos_id = 3\n",
    "    model.eval()\n",
    "    input_ids = [2] + sp.encode_as_ids(start_word)  # Insert BOS (token 2)\n",
    "    generated_ids = input_ids[:]  # Copy initial tokens\n",
    "\n",
    "    # Every forward pass runs under no_grad: the loop feeds the hidden state back\n",
    "    # in, and outside no_grad that would grow one autograd graph across all steps.\n",
    "    with torch.no_grad():\n",
    "        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)\n",
    "        logits, hidden = model(input_tensor, None)\n",
    "\n",
    "        for _ in range(num_words - 1):\n",
    "            # Logits for the token following the last position, shape (vocab_size,).\n",
    "            scaled_logits = logits[0, -1, :] / temperature\n",
    "\n",
    "            sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)\n",
    "            sorted_probs = F.softmax(sorted_logits, dim=-1)\n",
    "            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)\n",
    "\n",
    "            # Keep a token while the mass *before* it is still below top_p: this\n",
    "            # includes the token that crosses the threshold, and position 0 is\n",
    "            # forced on so at least the most likely token always survives.\n",
    "            keep = (cumulative_probs - sorted_probs) < top_p\n",
    "            keep[0] = True\n",
    "            kept_logits = sorted_logits[keep]\n",
    "            kept_indices = sorted_indices[keep]\n",
    "\n",
    "            choice = torch.multinomial(F.softmax(kept_logits, dim=-1), 1).item()\n",
    "            next_token_id = kept_indices[choice].item()\n",
    "            if next_token_id == eos_id:\n",
    "                break  # the model ended the poem\n",
    "            generated_ids.append(next_token_id)\n",
    "\n",
    "            next_input = torch.tensor([[next_token_id]], dtype=torch.long, device=device)\n",
    "            logits, hidden = model(next_input, hidden)\n",
    "\n",
    "    # Decode generated tokens (skip BOS) and lay them out words_per_line per line.\n",
    "    words = sp.decode_ids(generated_ids[1:]).split()\n",
    "    return \"\\n\".join(\" \".join(words[i:i + words_per_line]) for i in range(0, len(words), words_per_line))\n"
   ],
   "metadata": { "id": "kmsILzIh_0um" },
   "execution_count": 16,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "\n",
    "\n",
    "# -------------------------\n",
    "# 12. 
Example Usage for Testing (Optional)\n",
    "# -------------------------\n",
    "# Runs when executed as a notebook/script (__name__ == \"__main__\"),\n",
    "# but not when this code is imported as a module.\n",
    "if __name__ == \"__main__\":\n",
    "    prompt = \"ishq\"\n",
    "    poem = generate_poetry_nucleus(model, sp, prompt, num_words=12, temperature=1.2, top_p=0.85)\n",
    "    print(\"Generated Poetry:\\n\", poem)\n"
   ],
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "a3WKAKtJ_8YU",
    "outputId": "9571d2a7-97a4-4b1d-d106-3b7ccd0da43f"
   },
   "execution_count": 18,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Generated Poetry:\n",
      " ishq nishan tum phir kar phir\n",
      "ik baat aur phir ye phir\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [],
   "metadata": { "id": "hK3-OgKI98Ia" },
   "execution_count": 17,
   "outputs": []
  }
 ]
}