{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pytorch RoBERTa to ONNX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook documents how to export the PyTorch NLP model into ONNX format and then use it to make predictions using the ONNX runtime.\n",
    "\n",
    "The model uses the `simpletransformers` library which is a Python wrappers around the `transformers` library which contains PyTorch NLP transformer architectures and weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from simpletransformers.model import TransformerModel\n",
    "from transformers import RobertaForSequenceClassification, RobertaTokenizer\n",
    "import onnx\n",
    "import onnxruntime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Load pretrained PyTorch model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the model weights from https://storage.googleapis.com/seldon-models/pytorch/moviesentiment_roberta/pytorch_model.bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TransformerModel('roberta', 'roberta-base', args=({'fp16': False}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.model.load_state_dict(torch.load('pytorch_model.bin'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Export as ONNX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch supports exporting to ONNX, you just need to specify a valid input tensor for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n",
    "input_ids = torch.tensor(tokenizer.encode(\"This film is so bad\", add_special_tokens=True)).unsqueeze(0)  # Batch size 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   0,  713,  822,   16,   98, 1099,    2]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Export as ONNX, we specify dynamic axes for batch dimension and sequence length as sentences come in various lengths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/janis/.conda/envs/py37/lib/python3.7/site-packages/transformers/modeling_roberta.py:172: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if input_ids[:, 0].sum().item() != 0:\n"
     ]
    }
   ],
   "source": [
    "torch.onnx.export(model.model,\n",
    "                  (input_ids),\n",
    "                  \"roberta.onnx\",\n",
    "                  input_names=['input'],\n",
    "                  output_names=['output'],\n",
    "                  dynamic_axes={'input' :{0 : 'batch_size',\n",
    "                                          1: 'sentence_length'},\n",
    "                                'output': {0: 'batch_size'}})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Test predictions are the same using ONNX runtime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = onnx.load(\"roberta.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checks the exported model, may crash ipython kernel if run together with the PyTorch model in memory\n",
    "# onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime\n",
    "\n",
    "ort_session = onnxruntime.InferenceSession(\"roberta.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_numpy(tensor):\n",
    "    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = torch.tensor(tokenizer.encode(\"This film is so bad\", add_special_tokens=True)).unsqueeze(0)  # Batch size 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute ONNX Runtime output prediction\n",
    "ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_ids)}\n",
    "ort_out = ort_session.run(None, ort_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model.model(input_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[ 2.3067, -2.6440]], grad_fn=<AddmmBackward>),),\n",
       " [array([[ 2.3066945, -2.6439788]], dtype=float32)])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out, ort_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(to_numpy(out[0]), ort_out[0], rtol=1e-03, atol=1e-05)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}