{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "06eef311", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:34.377065Z", "iopub.status.busy": "2024-04-01T01:04:34.376156Z", "iopub.status.idle": "2024-04-01T01:04:41.159287Z", "shell.execute_reply": "2024-04-01T01:04:41.158448Z" }, "papermill": { "duration": 6.79219, "end_time": "2024-04-01T01:04:41.161646", "exception": false, "start_time": "2024-04-01T01:04:34.369456", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "import torch\n",
  "import torch.nn as nn\n",
  "import torch.optim as optim\n",
  "from torchvision import transforms\n",
  "from torch.utils.data import DataLoader\n",
  "from tqdm.auto import tqdm\n",
  "\n",
  "# Project-local modules: the colorizer network and dataset / shape helpers.\n",
  "from model import MangaColorizer\n",
  "from utils import ImageDataset, adjust_output_shape"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "5a1d7c3e", "metadata": {}, "outputs": [], "source": [
  "# Configuration: values a reader may want to tune, kept in one place.\n",
  "SEED = 42\n",
  "DATA_DIR = \"/kaggle/input/manga-panels-colored/data\"\n",
  "\n",
  "# Seed torch (CPU and CUDA) and numpy so weight initialisation and DataLoader\n",
  "# shuffle order repeat across runs; cuDNN kernels may still be nondeterministic.\n",
  "torch.manual_seed(SEED)\n",
  "np.random.seed(SEED)"
 ] },
 { "cell_type": "markdown", "id": "5e7ff784", "metadata": { "papermill": { "duration": 0.004403, "end_time": "2024-04-01T01:04:41.171084", "exception": false, "start_time": "2024-04-01T01:04:41.166681", "status": "completed" }, "tags": [] }, "source": [ "## Model architecture" ] },
 { "cell_type": "code", "execution_count": 3, "id": "87d03ce6", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:41.182184Z", "iopub.status.busy": "2024-04-01T01:04:41.181258Z", "iopub.status.idle": "2024-04-01T01:04:41.190651Z", "shell.execute_reply": "2024-04-01T01:04:41.189724Z" }, "papermill": { "duration": 0.017191, "end_time": "2024-04-01T01:04:41.192743", "exception": false, "start_time": "2024-04-01T01:04:41.175552", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MangaColorizer(\n", " (encoder): Sequential(\n", " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (5): ReLU(inplace=True)\n", " )\n", " (decoder): Sequential(\n", " (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): ConvTranspose2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (5): Tanh()\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "model = MangaColorizer()\n",
  "model"
 ] },
 { "cell_type": "markdown", "id": "c4b5ff4a", "metadata": { "papermill": { "duration": 0.004206, "end_time": "2024-04-01T01:04:41.201565", "exception": false, "start_time": "2024-04-01T01:04:41.197359", "status": "completed" }, "tags": [] }, "source": [ "## Loading the data" ] },
 { "cell_type": "code", "execution_count": 5, "id": "42198e39", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:41.247525Z", "iopub.status.busy": "2024-04-01T01:04:41.247244Z", "iopub.status.idle": "2024-04-01T01:04:41.627306Z", "shell.execute_reply": "2024-04-01T01:04:41.626292Z" }, "papermill": { "duration": 0.387773, "end_time": "2024-04-01T01:04:41.629778", "exception": false, "start_time": "2024-04-01T01:04:41.242005", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
  "transform = transforms.Compose([\n",
  "    transforms.ToTensor()\n",
  "])\n",
  "\n",
  "# NOTE(review): ToTensor scales PIL/uint8 images to [0, 1], while the model ends in Tanh ([-1, 1]);\n",
  "# confirm ImageDataset maps targets into the range the model can actually produce.\n",
  "train_dataset = ImageDataset(dir=f\"{DATA_DIR}/train\", transform=transform)\n",
  "validation_dataset = ImageDataset(dir=f\"{DATA_DIR}/validation\", transform=transform)\n",
  "test_dataset = ImageDataset(dir=f\"{DATA_DIR}/test\", transform=transform)"
 ] },
 { "cell_type": "code", "execution_count": 6, "id": "8e5ea6dd", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:41.640890Z",
"iopub.status.busy": "2024-04-01T01:04:41.640098Z", "iopub.status.idle": "2024-04-01T01:04:41.645385Z", "shell.execute_reply": "2024-04-01T01:04:41.644485Z" }, "papermill": { "duration": 0.012881, "end_time": "2024-04-01T01:04:41.647392", "exception": false, "start_time": "2024-04-01T01:04:41.634511", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
  "# batch_size=1: presumably because panels differ in size and cannot be stacked into a batch (TODO confirm).\n",
  "train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
  "# Evaluation averages the loss over every sample, so order is irrelevant: no need to shuffle.\n",
  "validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False)\n",
  "test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)"
 ] },
 { "cell_type": "markdown", "id": "fd0bbc4c", "metadata": { "papermill": { "duration": 0.004236, "end_time": "2024-04-01T01:04:41.656333", "exception": false, "start_time": "2024-04-01T01:04:41.652097", "status": "completed" }, "tags": [] }, "source": [ "## Training the model" ] },
 { "cell_type": "code", "execution_count": 8, "id": "6bb853cd", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:41.683769Z", "iopub.status.busy": "2024-04-01T01:04:41.683460Z", "iopub.status.idle": "2024-04-01T01:04:41.721713Z", "shell.execute_reply": "2024-04-01T01:04:41.720922Z" }, "papermill": { "duration": 0.04614, "end_time": "2024-04-01T01:04:41.724036", "exception": false, "start_time": "2024-04-01T01:04:41.677896", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
  "LEARNING_RATE = 1e-4\n",
  "\n",
  "# Mean squared error between the model output and the target image.\n",
  "criterion = nn.MSELoss()\n",
  "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "7b70952d", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:41.735281Z", "iopub.status.busy": "2024-04-01T01:04:41.734495Z", "iopub.status.idle": "2024-04-01T01:04:41.955794Z", "shell.execute_reply": "2024-04-01T01:04:41.954865Z" }, "papermill": { "duration": 0.229072, "end_time": "2024-04-01T01:04:41.957864", "exception": false, "start_time": "2024-04-01T01:04:41.728792", "status": "completed" }, "tags": [] }, "outputs": [],
"source": [
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
  "# Module.to moves the parameters in place, so the optimizer created earlier keeps valid references.\n",
  "model.to(device)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "4d252b1e", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T01:04:41.969269Z", "iopub.status.busy": "2024-04-01T01:04:41.968485Z", "iopub.status.idle": "2024-04-01T10:06:12.760575Z", "shell.execute_reply": "2024-04-01T10:06:12.759664Z" }, "papermill": { "duration": 32490.811819, "end_time": "2024-04-01T10:06:12.774567", "exception": false, "start_time": "2024-04-01T01:04:41.962748", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
  "num_epochs = 100\n",
  "patience = 10  # epochs without validation improvement before stopping early\n",
  "num_training_steps = num_epochs * len(train_loader)\n",
  "progress_bar = tqdm(range(num_training_steps))\n",
  "\n",
  "train_losses = []\n",
  "valid_losses = []\n",
  "\n",
  "best_valid_loss = float(\"inf\")\n",
  "epochs_no_improve = 0\n",
  "best_model = None\n",
  "\n",
  "# The encoder's two stride-2 convs and the decoder's two stride-2 transposed convs only\n",
  "# round-trip spatial sizes divisible by 4, so outputs may not match the targets' shape.\n",
  "# Resize explicitly instead of catching RuntimeError from the loss: MSELoss silently\n",
  "# broadcasts some mismatches, and a broad except would also hide errors such as CUDA OOM.\n",
  "\n",
  "for epoch in range(num_epochs):\n",
  "    model.train()\n",
  "    train_loss = 0.0\n",
  "    for images, targets in train_loader:\n",
  "        images = images.to(device)\n",
  "        targets = targets.to(device)\n",
  "        outputs = model(images)\n",
  "        if outputs.shape != targets.shape:\n",
  "            outputs = adjust_output_shape(outputs, targets)\n",
  "        loss = criterion(outputs, targets)\n",
  "\n",
  "        optimizer.zero_grad()\n",
  "        loss.backward()\n",
  "        optimizer.step()\n",
  "        progress_bar.update(1)\n",
  "\n",
  "        train_loss += loss.item()\n",
  "\n",
  "    train_losses.append(train_loss / len(train_loader))\n",
  "\n",
  "    model.eval()\n",
  "    valid_loss = 0.0\n",
  "    with torch.no_grad():\n",
  "        for images, targets in validation_loader:\n",
  "            images = images.to(device)\n",
  "            targets = targets.to(device)\n",
  "            outputs = model(images)\n",
  "            if outputs.shape != targets.shape:\n",
  "                outputs = adjust_output_shape(outputs, targets)\n",
  "            valid_loss += criterion(outputs, targets).item()\n",
  "\n",
  "    valid_loss /= len(validation_loader)\n",
  "    valid_losses.append(valid_loss)\n",
  "\n",
  "    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Valid Loss: {valid_loss:.4f}')\n",
  "    torch.save(model.state_dict(), \"last_checkpoint.pth\")\n",
  "\n",
  "    if valid_loss < best_valid_loss:\n",
  "        best_valid_loss = valid_loss\n",
  "        epochs_no_improve = 0\n",
  "        # state_dict() returns references to the live parameter tensors; without a copy,\n",
  "        # later optimizer steps overwrite this snapshot in place and the final\n",
  "        # load_state_dict would restore the *last* weights instead of the best ones.\n",
  "        best_model = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}\n",
  "        torch.save(best_model, \"best_model_checkpoint.pth\")\n",
  "    else:\n",
  "        epochs_no_improve += 1\n",
  "        if epochs_no_improve >= patience:\n",
  "            print(f\"Early stopping after {epoch+1} epochs with no improvement.\")\n",
  "            break\n",
  "\n",
  "# The bar's total assumes every epoch runs; close it cleanly when stopping early.\n",
  "progress_bar.close()\n",
  "\n",
  "# best_model stays None if validation never improved (e.g. every loss was NaN).\n",
  "if best_model is not None:\n",
  "    model.load_state_dict(best_model)"
 ] },
 { "cell_type": "code", "execution_count": 11, "id": "e3d447f1", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T10:06:12.800929Z", "iopub.status.busy": "2024-04-01T10:06:12.800630Z", "iopub.status.idle": "2024-04-01T10:06:13.054441Z", "shell.execute_reply": "2024-04-01T10:06:13.053544Z" }, "papermill": { "duration": 0.269281, "end_time": "2024-04-01T10:06:13.056431", "exception": false, "start_time": "2024-04-01T10:06:12.787150", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "epochs = np.arange(1, len(train_losses) + 1)\n",
  "fig, ax = plt.subplots(figsize=(10, 5))\n",
  "\n",
  "ax.plot(epochs, train_losses, label='Train Loss')\n",
  "ax.plot(epochs, valid_losses, label='Valid Loss')\n",
  "ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')\n",
  "ax.legend()\n",
  "\n",
  "plt.show()"
 ] },
 { "cell_type": "code", "execution_count": 12, "id": "c3b04bf7", "metadata": { "execution": { "iopub.execute_input": "2024-04-01T10:06:13.084256Z", "iopub.status.busy": "2024-04-01T10:06:13.083932Z", "iopub.status.idle": "2024-04-01T10:06:38.988309Z", "shell.execute_reply": "2024-04-01T10:06:38.987316Z" }, "papermill": { "duration": 25.920465, "end_time": "2024-04-01T10:06:38.990405", "exception": false, "start_time": "2024-04-01T10:06:13.069940", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The loss on the test set is 0.008598181701702099\n" ] } ], "source": [
  "model.eval()\n",
  "test_loss = 0.0\n",
  "with torch.no_grad():\n",
  "    for images, targets in test_loader:\n",
  "        images = images.to(device)\n",
  "        targets = targets.to(device)\n",
  "        outputs = model(images)\n",
  "        # Resize only on an actual shape mismatch rather than catching RuntimeError,\n",
  "        # which would also swallow unrelated failures (e.g. CUDA OOM).\n",
  "        if outputs.shape != targets.shape:\n",
  "            outputs = adjust_output_shape(outputs, targets)\n",
  "        test_loss += criterion(outputs, targets).item()\n",
  "\n",
  "# Average once, after the loop, over all test samples.\n",
  "test_loss /= len(test_loader)\n",
  "print(f\"The loss on the test set is {test_loss}\")"
 ] } ], "metadata": { "kaggle": { "accelerator": "gpu", "dataSources": [ { "datasetId": 4705836, "sourceId": 7993213, "sourceType": "datasetVersion" } ], "dockerImageVersionId": 30674, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3
}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" }, "papermill": { "default_parameters": {}, "duration": 32529.008488, "end_time": "2024-04-01T10:06:40.496191", "environment_variables": {}, "exception": null, "input_path": "__notebook__.ipynb", "output_path": "__notebook__.ipynb", "parameters": {}, "start_time": "2024-04-01T01:04:31.487703", "version": "2.5.0" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": { "0d705fdb53a0460cb06e86e2212618f5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1a8db5a1fed14afca913a6edb7794b17": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "1e977edae4d54b5fbd5b7018ffb9858f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": 
null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "38d95f1d4d65453895e3b9f5ea41723c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "50e7ac5f3f4b4fe58e2e110fea403bbc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_66e2c724e5e54c878bb614d51a452b26", "IPY_MODEL_6901f2bab73b4155b0f53c30467d46c3", "IPY_MODEL_771460c15f794c71af4f6eeb0ea7b1ad" ], "layout": "IPY_MODEL_1e977edae4d54b5fbd5b7018ffb9858f" } }, "66e2c724e5e54c878bb614d51a452b26": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_38d95f1d4d65453895e3b9f5ea41723c", "placeholder": "​", "style": "IPY_MODEL_0d705fdb53a0460cb06e86e2212618f5", "value": "100%" } }, "6901f2bab73b4155b0f53c30467d46c3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_dd05d612550a4db28ebf2c7cfc2312fe", "max": 75500, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_1a8db5a1fed14afca913a6edb7794b17", "value": 75500 } }, "771460c15f794c71af4f6eeb0ea7b1ad": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", 
"_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_d08ee627a3cc414489e61161d8a917ae", "placeholder": "​", "style": "IPY_MODEL_b2818288e9b9459fb75a1ea2e6f35117", "value": " 75500/75500 [9:00:45<00:00, 2.32it/s]" } }, "b2818288e9b9459fb75a1ea2e6f35117": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "d08ee627a3cc414489e61161d8a917ae": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "dd05d612550a4db28ebf2c7cfc2312fe": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": 
"LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } } }, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }