{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## Setup\n",
    "\n",
    "Execute the cell below to set up the notebook to run Custom Diffusion."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!bash setup.sh"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Training\n",
    "\n",
    "## Diffusers method"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "## launch training script (2 GPUs recommended, increase --max_train_steps to 500 if 1 GPU)\n",
    "## <new1> is the modifier token for the new concept; it must appear in the\n",
    "## instance prompt, in --modifier_token and in the sampling prompt\n",
    "\n",
    "!accelerate launch src/diffuser_training.py \\\n",
    "    --pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4 \\\n",
    "    --instance_data_dir=./data/cat \\\n",
    "    --class_data_dir=./real_reg/samples_cat/ \\\n",
    "    --output_dir=./logs/cat \\\n",
    "    --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n",
    "    --instance_prompt=\"photo of a <new1> cat\" \\\n",
    "    --class_prompt=\"cat\" \\\n",
    "    --resolution=512 \\\n",
    "    --train_batch_size=2 \\\n",
    "    --learning_rate=1e-5 \\\n",
    "    --lr_warmup_steps=0 \\\n",
    "    --max_train_steps=250 \\\n",
    "    --num_class_images=200 \\\n",
    "    --scale_lr \\\n",
    "    --modifier_token \"<new1>\"\n",
    "\n",
    "## sample (the leading ! is required so the line runs as a shell command)\n",
    "!python src/sample_diffuser.py --delta_ckpt logs/cat/delta.bin --ckpt \"CompVis/stable-diffusion-v1-4\" --prompt \"<new1> cat playing with a ball\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "## launch training script (2 GPUs recommended, increase --max_train_steps to 1000 if 1 GPU)\n",
    "## provide some json file with the info about each concept\n",
    "## one modifier token per concept, joined with '+'\n",
    "## NOTE(review): CUDA_VISIBLE_DEVICES=2,3 assumes a machine with at least 4 GPUs;\n",
    "## adjust the ids to the GPUs available on your machine\n",
    "!CUDA_VISIBLE_DEVICES=2,3 accelerate launch src/diffuser_training.py \\\n",
    "    --pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4 \\\n",
    "    --output_dir=./logs/cat_wooden_pot \\\n",
    "    --concepts_list=./assets/concept_list.json \\\n",
    "    --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n",
    "    --resolution=512 \\\n",
    "    --train_batch_size=2 \\\n",
    "    --learning_rate=1e-5 \\\n",
    "    --lr_warmup_steps=0 \\\n",
    "    --max_train_steps=500 \\\n",
    "    --num_class_images=200 \\\n",
    "    --scale_lr --hflip \\\n",
    "    --modifier_token \"<new1>+<new2>\"\n",
    "\n",
    "## sample (the leading ! is required so the line runs as a shell command)\n",
    "!python src/sample_diffuser.py --delta_ckpt logs/cat_wooden_pot/delta.bin --ckpt \"CompVis/stable-diffusion-v1-4\" --prompt \"<new1> cat sitting inside a <new2> wooden pot and looking up\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Gradio application"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "%cd ~/../custom-diffusion\n",
    "!python app.py --share"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Sample from the newly trained model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import autocast\n",
    "from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler\n",
    "\n",
    "%cd ~/../notebooks\n",
    "\n",
    "## Set the args\n",
    "device = 'cuda'\n",
    "# NOTE(review): this points at a full pipeline checkpoint directory, while the\n",
    "# training cells above write delta.bin under ./logs/... - confirm the path.\n",
    "model_ = 'custom-diffusion/results/checkpoint-500'\n",
    "size_ = 512          # output height and width in pixels\n",
    "precision = 'fp16'   # label for the log line; the pipeline is loaded in torch.float16\n",
    "sample_num = 5       # number of pipeline calls (one image each)\n",
    "\n",
    "print(f'Generating samples from Stable Diffusion {model_} checkpoint ({precision})')\n",
    "\n",
    "\n",
    "## Instantiate the model pipe with StableDiffusionPipeline\n",
    "pipe = StableDiffusionPipeline.from_pretrained(model_, torch_dtype=torch.float16, revision=\"fp16\")\n",
    "pipe = pipe.to(device)\n",
    "\n",
    "for j in range(sample_num):\n",
    "    images = pipe(prompt='a photo of a cat',\n",
    "                  negative_prompt=None,\n",
    "                  guidance_scale=7.5,\n",
    "                  height=size_,\n",
    "                  width=size_,\n",
    "                  num_images_per_prompt=1,\n",
    "                  num_inference_steps=50)['images']\n",
    "    for image in images:\n",
    "        display(image)\n",
    "        # uncomment the next line to save results (the outputs/ directory must exist)\n",
    "        # image.save(f'outputs/gen-image-{j}.png')\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Compress the model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "## compress the fine-tuned delta weights; replace both paths with those of your own run\n",
    "## NOTE(review): values below mirror the sampling commands above - confirm against src/compress.py\n",
    "!python src/compress.py --delta_ckpt logs/cat/delta.bin --ckpt \"CompVis/stable-diffusion-v1-4\""
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}