{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concrete.ml\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training: \n",
    "    1. Gather dataset of pictures\n",
    "    2. Preprocess the data\n",
    "    3. Find pretrained model \n",
    "    4. Segment Pretrained model into client-model and encrypted-server-model \n",
    "    5. Retrain the server-side model on 8 bits\n",
    "    6. Take output of the client model and truncate the floats to 8 bits\n",
    "\n",
    "Production\n",
    "    1. Take a picture :)\n",
    "    2. Evaluate client model on photo (clear)\n",
    "    3. Truncate to 8 bits\n",
    "    4. Encrypt \n",
    "    5. Send encrypted data to server\n",
    "    6. Send back encrypted result\n",
    "    7. decrypt result\n"
   ]
  },
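  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before the model steps, a small hedged sketch of the preprocessing items above (take a picture, preprocess the data): the ImageNet-pretrained MobileNetV2 used below expects 224x224 RGB inputs normalized with the standard ImageNet mean and std. The `load_and_preprocess` helper and the file path are illustrative placeholders, not part of the original pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "# Standard ImageNet preprocessing expected by pretrained MobileNetV2 weights\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "def load_and_preprocess(image_path):\n",
    "    # Load an RGB image from disk and turn it into a (1, 3, 224, 224) batch\n",
    "    image = Image.open(image_path).convert('RGB')\n",
    "    return preprocess(image).unsqueeze(0)\n",
    "\n",
    "# Example usage (the path is a placeholder):\n",
    "# input_image = load_and_preprocess('my_photo.jpg')"
   ]
  },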
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 1: Load Pretrained MobileNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "\n",
    "# Load the pretrained MobileNet model\n",
    "mobilenet = models.mobilenet_v2(pretrained=True)\n",
    "\n",
    "# Set model to evaluation mode\n",
    "mobilenet.eval()\n"
   ]
  },
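  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before splitting the network, it helps to see what there is to split. The quick inspection below (added as a sketch, not part of the original flow) lists MobileNetV2's feature blocks so the cutoff index used in the next step can be chosen deliberately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MobileNetV2's backbone is a plain sequence of feature blocks; listing them\n",
    "# makes it easier to pick where to cut the network into client and server halves.\n",
    "print(f'Number of feature blocks: {len(mobilenet.features)}')\n",
    "for idx, block in enumerate(mobilenet.features):\n",
    "    print(idx, block.__class__.__name__)"
   ]
  },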
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 2: Segment the Pretrained Model into Client and Server Parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Client model - extracting up to the 10th layer (or any other cutoff)\n",
    "client_model = nn.Sequential(*list(mobilenet.features.children())[:10])\n",
    "\n",
    "# Server model - the remaining layers\n",
    "server_model = nn.Sequential(*list(mobilenet.features.children())[10:], mobilenet.classifier)\n",
    "\n",
    "# Freeze client model parameters (no need to retrain)\n",
    "for param in client_model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
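  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (an addition, not in the original notebook): feeding a dummy image through the client model and then the server model should reproduce a full forward pass and end in 1000 ImageNet logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the two halves should compose back into a complete forward pass\n",
    "dummy_input = torch.randn(1, 3, 224, 224)\n",
    "\n",
    "with torch.no_grad():\n",
    "    intermediate = client_model(dummy_input)\n",
    "    logits = server_model(intermediate)\n",
    "\n",
    "print('Client output shape:', intermediate.shape)  # intermediate feature map\n",
    "print('Server output shape:', logits.shape)        # (1, 1000) ImageNet logits"
   ]
  },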
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 3: Quantize the Server-Side Model to 8 Bits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.quantization import quantize_dynamic\n",
    "\n",
    "# Quantize the server model\n",
    "server_model_quantized = quantize_dynamic(\n",
    "    server_model,  # Model to be quantized\n",
    "    {nn.Linear},   # Layers to quantize (we quantize fully connected layers here)\n",
    "    dtype=torch.qint8  # Quantize to 8-bit\n",
    ")\n",
    "\n",
    "server_model_quantized.eval()"
   ]
  },
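  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that dynamic quantization only swaps out the `nn.Linear` layers, i.e. the classifier head; the convolutional feature blocks stay in float32. The quick check below (a sketch added here) prints the classifier to confirm the swap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last module of the server model is the classifier head; after dynamic\n",
    "# quantization its Linear layer should appear as a dynamically quantized Linear.\n",
    "print(server_model_quantized[-1])"
   ]
  },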
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 4: Truncate the Client Model Output to 8 Bits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def truncate_to_8_bits(tensor):\n",
    "    # Scale the tensor to the range [0, 255]\n",
    "    tensor = torch.clamp(tensor, min=0, max=1)\n",
    "    tensor = tensor * 255.0\n",
    "    tensor = tensor.to(torch.uint8)  # Convert to 8-bit integers\n",
    "    return tensor\n",
    "\n",
    "# Example input\n",
    "input_image = torch.randn(1, 3, 224, 224)  # A random image input\n",
    "\n",
    "# Client-side computation\n",
    "client_output = client_model(input_image)\n",
    "\n",
    "# Truncate the output to 8 bits\n",
    "client_output_8bit = truncate_to_8_bits(client_output)\n",
    "\n",
    "# The truncated output is now ready to be passed to the server\n"
   ]
  },
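  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The clamp-to-[0, 1] truncation above clips any activation outside that range, and the intermediate MobileNetV2 features are not confined to [0, 1]. As a hedged alternative (not the original approach), the sketch below uses per-tensor affine quantization: the scale and zero point are derived from the observed min/max, so the full activation range survives, at the cost of also sending those two constants to the server for dequantization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative sketch: per-tensor affine 8-bit quantization (illustrative only)\n",
    "def quantize_affine_8bit(tensor):\n",
    "    t_min = tensor.min()\n",
    "    t_max = tensor.max()\n",
    "    # Guard against a degenerate (constant) tensor\n",
    "    scale = torch.clamp((t_max - t_min) / 255.0, min=1e-8)\n",
    "    q = torch.round((tensor - t_min) / scale).clamp(0, 255).to(torch.uint8)\n",
    "    return q, scale, t_min\n",
    "\n",
    "def dequantize_affine_8bit(q, scale, t_min):\n",
    "    return q.float() * scale + t_min\n",
    "\n",
    "# How much information does 8-bit quantization lose on this activation tensor?\n",
    "q_out, q_scale, q_min = quantize_affine_8bit(client_output)\n",
    "recovered = dequantize_affine_8bit(q_out, q_scale, q_min)\n",
    "print('Max absolute quantization error:', (recovered - client_output).abs().max().item())"
   ]
  },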
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 5: Server Model Inference on Quantized Data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure client output is in float format before feeding into server\n",
    "client_output_8bit = client_output_8bit.float() / 255.0  # Rescale to [0, 1]\n",
    "\n",
    "# Run inference on the server-side model\n",
    "server_output = server_model_quantized(client_output_8bit)\n",
    "\n",
    "# Output from the server model (class probabilities, etc.)\n",
    "print(server_output)\n"
   ]
  }
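  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 6 (sketch): Encrypted Inference with Concrete ML\n",
    "\n",
    "The production steps 4-7 (encrypt, send, evaluate, decrypt) are not implemented above. The cell below is a hedged sketch of how they could look with Concrete ML's `compile_torch_model`, which quantizes a float PyTorch model to low-bit integers and compiles it into an FHE circuit. The exact keyword arguments and the `fhe=` modes depend on the installed Concrete ML version, and compiling the full MobileNetV2 tail may be far too large in practice; treat this as an illustration of the API shape, not a tuned deployment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from concrete.ml.torch.compile import compile_torch_model\n",
    "\n",
    "# Calibration data: representative clear-text client outputs, used by Concrete ML\n",
    "# to choose the 8-bit quantization parameters of the server-side circuit.\n",
    "calibration_data = client_model(torch.randn(8, 3, 224, 224)).detach().numpy()\n",
    "\n",
    "# Compile the float server model into an FHE-friendly 8-bit circuit.\n",
    "# NOTE: Concrete ML quantizes the inputs itself, so the compiled module consumes\n",
    "# the float client output directly (the manual truncation is not needed here).\n",
    "quantized_module = compile_torch_model(server_model, calibration_data, n_bits=8)\n",
    "\n",
    "# One client-side activation, as a NumPy array\n",
    "single_input = client_model(torch.randn(1, 3, 224, 224)).detach().numpy()\n",
    "\n",
    "# fhe='simulate' runs the quantized circuit in the clear (fast sanity check);\n",
    "# fhe='execute' performs the real encrypt -> homomorphic evaluation -> decrypt round trip.\n",
    "simulated_logits = quantized_module.forward(single_input, fhe='simulate')\n",
    "print(simulated_logits.shape)"
   ]
  }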
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}