{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this line in Colab to install the package if it is\n",
    "# not already installed. The `%pip` magic (rather than `!pip`) is used so\n",
    "# the package is installed into the environment of the running kernel.\n",
    "%pip install git+https://github.com/openai/glide-text2im"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "\n",
    "from glide_text2im.clip.model_creation import create_clip_model\n",
    "from glide_text2im.download import load_checkpoint\n",
    "from glide_text2im.model_creation import (\n",
    "    create_model_and_diffusion,\n",
    "    model_and_diffusion_defaults,\n",
    "    model_and_diffusion_defaults_upsampler,\n",
    ")\n",
    "from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook supports both CPU and GPU.\n",
    "# On CPU, generating one sample may take on the order of 20 minutes.\n",
    "# On a GPU, it should be under a minute.\n",
    "\n",
    "has_cuda = th.cuda.is_available()\n",
    "device = th.device('cpu' if not has_cuda else 'cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the base model and its diffusion process from the default options.\n",
    "options = model_and_diffusion_defaults()\n",
    "options['use_fp16'] = has_cuda  # fp16 is only enabled when CUDA is available\n",
    "options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling\n",
    "model, diffusion = create_model_and_diffusion(**options)\n",
    "\n",
    "# Put the model in eval mode, convert to fp16 if enabled, then move it to the device.\n",
    "model.eval()\n",
    "if options['use_fp16']:\n",
    "    model.convert_to_fp16()\n",
    "model.to(device)\n",
    "\n",
    "# Load the 'base' checkpoint and report the parameter count.\n",
    "model.load_state_dict(load_checkpoint('base', device))\n",
    "base_param_count = sum(param.numel() for param in model.parameters())\n",
    "print('total base parameters', base_param_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the upsampler model and its diffusion process from the upsampler defaults.\n",
    "options_up = model_and_diffusion_defaults_upsampler()\n",
    "options_up['use_fp16'] = has_cuda  # fp16 is only enabled when CUDA is available\n",
    "options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling\n",
    "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
    "\n",
    "# Put the model in eval mode, convert to fp16 if enabled, then move it to the device.\n",
    "model_up.eval()\n",
    "if options_up['use_fp16']:\n",
    "    model_up.convert_to_fp16()\n",
    "model_up.to(device)\n",
    "\n",
    "# Load the 'upsample' checkpoint and report the parameter count.\n",
    "model_up.load_state_dict(load_checkpoint('upsample', device))\n",
    "upsampler_param_count = sum(param.numel() for param in model_up.parameters())\n",
    "print('total upsampler parameters', upsampler_param_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the CLIP model on the selected device, then load the checkpoints\n",
    "# for its image encoder and its text encoder.\n",
    "clip_model = create_clip_model(device=device)\n",
    "clip_model.image_encoder.load_state_dict(\n",
    "    load_checkpoint('clip/image-enc', device)\n",
    ")\n",
    "clip_model.text_encoder.load_state_dict(\n",
    "    load_checkpoint('clip/text-enc', device)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(batch: th.Tensor):\n",
    "    \"\"\"Display a batch of images inline as a single horizontal strip.\n",
    "\n",
    "    `batch` is indexed as (image, channel, row, column); the reshape below\n",
    "    requires exactly 3 channels. Values are mapped from a nominal [-1, 1]\n",
    "    range to 0..255 and clamped before conversion to uint8.\n",
    "    \"\"\"\n",
    "    pixels = (batch + 1) * 127.5\n",
    "    pixels = pixels.round().clamp(0, 255).to(th.uint8).cpu()\n",
    "    height = batch.shape[2]\n",
    "    # Rows first and channels last, then merge the image and column axes so\n",
    "    # the images sit next to each other.\n",
    "    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])\n",
    "    display(Image.fromarray(strip.numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling parameters\n",
    "prompt = \"an oil painting of a corgi\"\n",
    "batch_size = 1\n",
    "guidance_scale = 3.0\n",
    "\n",
    "# Tune this parameter to control the sharpness of 256x256 images.\n",
    "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
    "upsample_temp = 0.997"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Sample from the base model #\n",
    "##############################\n",
    "\n",
    "# Tokenize the prompt and pad the tokens to the model's text context length.\n",
    "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    model.tokenizer.encode(prompt), options['text_ctx']\n",
    ")\n",
    "\n",
    "# Model kwargs: the same tokens and mask repeated for every batch element.\n",
    "model_kwargs = {\n",
    "    'tokens': th.tensor([tokens] * batch_size, device=device),\n",
    "    'mask': th.tensor([mask] * batch_size, dtype=th.bool, device=device),\n",
    "}\n",
    "\n",
    "# Guidance function from the CLIP model for this prompt.\n",
    "cond_fn = clip_model.cond_fn([prompt] * batch_size, guidance_scale)\n",
    "\n",
    "# Run the base sampling loop, calling del_cache() before and after it.\n",
    "base_shape = (batch_size, 3, options['image_size'], options['image_size'])\n",
    "model.del_cache()\n",
    "samples = diffusion.p_sample_loop(\n",
    "    model,\n",
    "    base_shape,\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=cond_fn,\n",
    ")\n",
    "model.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Upsample the 64x64 samples #\n",
    "##############################\n",
    "\n",
    "# Tokenize the prompt with the upsampler's tokenizer and pad the tokens\n",
    "# to the upsampler's text context length.\n",
    "tokens = model_up.tokenizer.encode(prompt)\n",
    "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
    "    tokens, options_up['text_ctx']\n",
    ")\n",
    "\n",
    "# Create the model conditioning dict.\n",
    "model_kwargs = dict(\n",
    "    # Low-res image to upsample, rounded to steps of 1/127.5\n",
    "    # (8-bit pixel levels) and mapped back to the original scale.\n",
    "    low_res=((samples+1)*127.5).round()/127.5 - 1,\n",
    "\n",
    "    # Text tokens\n",
    "    tokens=th.tensor(\n",
    "        [tokens] * batch_size, device=device\n",
    "    ),\n",
    "    mask=th.tensor(\n",
    "        [mask] * batch_size,\n",
    "        dtype=th.bool,\n",
    "        device=device,\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Sample from the upsampler model.\n",
    "model_up.del_cache()\n",
    "up_shape = (batch_size, 3, options_up['image_size'], options_up['image_size'])\n",
    "up_samples = diffusion_up.ddim_sample_loop(\n",
    "    model_up,\n",
    "    up_shape,\n",
    "    # Initial noise, scaled by upsample_temp.\n",
    "    noise=th.randn(up_shape, device=device) * upsample_temp,\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,  # no CLIP guidance function is passed for the upsampling pass\n",
    ")[:batch_size]\n",
    "model_up.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(up_samples)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 2
}