{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1JryaVhtBHij",
    "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
   },
   "outputs": [],
   "source": [
    "# Fetch the project sources into ./enhance-me.\n",
    "!git clone https://github.com/soumik12345/enhance-me\n",
    "# Use the %pip magic so wandb and streamlit are installed into the environment\n",
    "# of the running kernel rather than whichever pip the shell happens to find.\n",
    "%pip install -qqq wandb streamlit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G_c4VtXWHR5l"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Local checkout: the notebook sits in `notebooks/`, so the package is one level up.\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "# Colab / fresh clone: the clone cell above creates `./enhance-me`, which holds the\n",
    "# `enhance_me` package. Add it only when present, and only once across re-runs.\n",
    "cloned_repo_dir = \"enhance-me\"\n",
    "if os.path.isdir(os.path.join(cloned_repo_dir, \"enhance_me\")):\n",
    "    if cloned_repo_dir not in sys.path:\n",
    "        sys.path.append(cloned_repo_dir)\n",
    "\n",
    "from PIL import Image\n",
    "from enhance_me import commons\n",
    "from enhance_me.mirnet import MIRNet\n",
    "from enhance_me.zero_dce import ZeroDCE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZpBHbYaMIqP_"
   },
   "outputs": [],
   "source": [
    "# @title MIRNet Train Configs\n",
    "\n",
    "# Experiment tracking.\n",
    "# NOTE(review): the name says 256 while image_size below is 128 - confirm which is intended.\n",
    "experiment_name = \"lol_dataset_256\"  # @param {type:\"string\"}\n",
    "wandb_api_key = \"\"  # @param {type:\"string\"}\n",
    "\n",
    "# Dataset and batching.\n",
    "dataset_label = \"lol\"  # @param [\"lol\"]\n",
    "image_size = 128  # @param {type:\"integer\"}\n",
    "val_split = 0.1  # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
    "batch_size = 4  # @param {type:\"integer\"}\n",
    "\n",
    "# Augmentation switches.\n",
    "apply_random_horizontal_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_vertical_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_rotation = True  # @param {type:\"boolean\"}\n",
    "\n",
    "# Model.\n",
    "use_mixed_precision = True  # @param {type:\"boolean\"}\n",
    "num_recursive_residual_groups = 3  # @param {type:\"slider\", min:1, max:5, step:1}\n",
    "num_multi_scale_residual_blocks = 2  # @param {type:\"slider\", min:1, max:5, step:1}\n",
    "\n",
    "# Optimisation.\n",
    "learning_rate = 1e-4  # @param {type:\"number\"}\n",
    "epsilon = 1e-3  # @param {type:\"number\"}\n",
    "epochs = 50  # @param {type:\"slider\", min:10, max:100, step:5}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "id": "IVRoedqBIMuH",
    "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
   },
   "outputs": [],
   "source": [
    "# An empty form field means no key was entered, so forward None instead of \"\".\n",
    "mirnet = MIRNet(\n",
    "    experiment_name=experiment_name,\n",
    "    wandb_api_key=wandb_api_key if wandb_api_key != \"\" else None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "O66Iwzx8IsGh",
    "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
   },
   "outputs": [],
   "source": [
    "# Gather the dataset options chosen in the MIRNet config cell, then hand them over in one call.\n",
    "mirnet_dataset_options = dict(\n",
    "    image_size=image_size,\n",
    "    dataset_label=dataset_label,\n",
    "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
    "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
    "    apply_random_rotation=apply_random_rotation,\n",
    "    val_split=val_split,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "mirnet.build_datasets(**mirnet_dataset_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tsfKrBCsL_Bb"
   },
   "outputs": [],
   "source": [
    "# Gather the model and optimiser options from the MIRNet config cell, then build the model.\n",
    "mirnet_model_options = dict(\n",
    "    use_mixed_precision=use_mixed_precision,\n",
    "    num_recursive_residual_groups=num_recursive_residual_groups,\n",
    "    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
    "    learning_rate=learning_rate,\n",
    "    epsilon=epsilon,\n",
    ")\n",
    "mirnet.build_model(**mirnet_model_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "y3L9wlpkNziL",
    "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
   },
   "outputs": [],
   "source": [
    "# Train MIRNet for the number of epochs set in the config cell; keep whatever train() returns.\n",
    "history = mirnet.train(\n",
    "    epochs=epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The weights file is looked up as <experiment name>/weights.h5, relative to the working directory.\n",
    "# NOTE(review): assumes training left a weights.h5 at that location - confirm.\n",
    "mirnet_weights_path = os.path.join(mirnet.experiment_name, \"weights.h5\")\n",
    "mirnet.load_weights(mirnet_weights_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "daFKbgBkiyzc"
   },
   "outputs": [],
   "source": [
    "# For every test input, plot the input, its ground truth and the MIRNet output together.\n",
    "for idx, low_path in enumerate(mirnet.test_low_images):\n",
    "    low_light = Image.open(low_path)\n",
    "    restored = mirnet.infer(low_light)\n",
    "    # The ground-truth list is indexed with the same position as the low-light list.\n",
    "    reference = Image.open(mirnet.test_enhanced_images[idx])\n",
    "    figure_images = [low_light, reference, restored]\n",
    "    figure_titles = [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"]\n",
    "    commons.plot_results(figure_images, figure_titles, (18, 18))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dO-IbNQHkB3R"
   },
   "outputs": [],
   "source": [
    "# @title Zero-DCE Train Configs\n",
    "\n",
    "# Experiment tracking.\n",
    "experiment_name = \"unpaired_low_light_256_resize\"  # @param {type:\"string\"}\n",
    "wandb_api_key = \"\"  # @param {type:\"string\"}\n",
    "\n",
    "# Dataset and batching.\n",
    "dataset_label = \"unpaired\"  # @param [\"lol\", \"unpaired\"]\n",
    "image_size = 256  # @param {type:\"integer\"}\n",
    "apply_resize = True  # @param {type:\"boolean\"}\n",
    "val_split = 0.1  # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
    "batch_size = 16  # @param {type:\"integer\"}\n",
    "\n",
    "# Augmentation switches.\n",
    "apply_random_horizontal_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_vertical_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_rotation = True  # @param {type:\"boolean\"}\n",
    "\n",
    "# Model.\n",
    "use_mixed_precision = False  # @param {type:\"boolean\"}\n",
    "\n",
    "# Optimisation.\n",
    "# NOTE(review): epsilon is not passed to any Zero-DCE call in this notebook - confirm it is needed.\n",
    "learning_rate = 1e-4  # @param {type:\"number\"}\n",
    "epsilon = 1e-3  # @param {type:\"number\"}\n",
    "epochs = 100  # @param {type:\"slider\", min:10, max:100, step:5}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An empty form field means no key was entered, so forward None instead of \"\".\n",
    "zero_dce = ZeroDCE(\n",
    "    experiment_name=experiment_name,\n",
    "    wandb_api_key=wandb_api_key if wandb_api_key != \"\" else None,\n",
    "    use_mixed_precision=use_mixed_precision,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the dataset options chosen in the Zero-DCE config cell, then hand them over in one call.\n",
    "zero_dce_dataset_options = dict(\n",
    "    image_size=image_size,\n",
    "    dataset_label=dataset_label,\n",
    "    apply_resize=apply_resize,\n",
    "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
    "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
    "    apply_random_rotation=apply_random_rotation,\n",
    "    val_split=val_split,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "zero_dce.build_datasets(**zero_dce_dataset_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Create the output directory before training starts, so a missing directory\n",
    "# is caught up front rather than when the weights are written at the very end.\n",
    "os.makedirs(experiment_name, exist_ok=True)\n",
    "\n",
    "zero_dce.compile(learning_rate=learning_rate)\n",
    "history = zero_dce.train(epochs=epochs)\n",
    "\n",
    "# Persist the trained weights as <experiment name>/weights.h5.\n",
    "zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# For every test input, plot the input next to the Zero-DCE output.\n",
    "for idx, low_path in enumerate(zero_dce.test_low_images):\n",
    "    low_light = Image.open(low_path)\n",
    "    restored = zero_dce.infer(low_light)\n",
    "    figure_images = [low_light, restored]\n",
    "    figure_titles = [\"Original Image\", \"Enhanced Image\"]\n",
    "    commons.plot_results(figure_images, figure_titles, (18, 18))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
   "collapsed_sections": [],
   "include_colab_link": true,
   "machine_shape": "hm",
   "name": "enhance-me-train.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}