{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ddce00dd-9598-4a33-90fd-88cc22b85de4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
"```\n",
"pip install accelerate\n",
"```\n",
".\n",
"vae/diffusion_pytorch_model.safetensors not found\n",
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
"```\n",
"pip install accelerate\n",
"```\n",
".\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
"/environment/miniconda3/lib/python3.7/site-packages/transformers/models/clip/feature_extraction_clip.py:31: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n",
" FutureWarning,\n",
"/environment/miniconda3/lib/python3.7/site-packages/gradio/components.py:4424: UserWarning: The 'grid' parameter will be deprecated. Please use 'columns' instead.\n",
" \"The 'grid' parameter will be deprecated. Please use 'columns' instead.\",\n"
]
}
],
"source": [
"'''\n",
"from diffusers import utils\n",
"from diffusers.utils import deprecation_utils\n",
"from diffusers.models import cross_attention\n",
"utils.deprecate = lambda *arg, **kwargs: None\n",
"deprecation_utils.deprecate = lambda *arg, **kwargs: None\n",
"cross_attention.deprecate = lambda *arg, **kwargs: None\n",
"'''\n",
"\n",
"import os\n",
"import sys\n",
"'''\n",
"MAIN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))\n",
"sys.path.insert(0, MAIN_DIR)\n",
"os.chdir(MAIN_DIR)\n",
"'''\n",
"\n",
"import gradio as gr\n",
"import numpy as np\n",
"import torch\n",
"import random\n",
"\n",
"from annotator.util import resize_image, HWC3\n",
"from annotator.canny import CannyDetector\n",
"from diffusers.models.unet_2d_condition import UNet2DConditionModel\n",
"from diffusers.pipelines import DiffusionPipeline\n",
"from diffusers.schedulers import DPMSolverMultistepScheduler\n",
"#from models import ControlLoRA, ControlLoRACrossAttnProcessor\n",
"\n",
"apply_canny = CannyDetector()\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"'''\n",
"pipeline = DiffusionPipeline.from_pretrained(\n",
" 'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1', safety_checker=None\n",
")\n",
"pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n",
"pipeline = pipeline.to(device)\n",
"unet: UNet2DConditionModel = pipeline.unet\n",
"\n",
"#ckpt_path = \"ckpts/sd-diffusiondb-canny-model-control-lora-zh\"\n",
"ckpt_path = \"svjack/canny-control-lora-zh\"\n",
"control_lora = ControlLoRA.from_pretrained(ckpt_path)\n",
"control_lora = control_lora.to(device)\n",
"\n",
"# load control lora attention processors\n",
"lora_attn_procs = {}\n",
"lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])\n",
"n_ch = len(unet.config.block_out_channels)\n",
"control_ids = [i for i in range(n_ch)]\n",
"for name in pipeline.unet.attn_processors.keys():\n",
" cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n",
" if name.startswith(\"mid_block\"):\n",
" control_id = control_ids[-1]\n",
" elif name.startswith(\"up_blocks\"):\n",
" block_id = int(name[len(\"up_blocks.\")])\n",
" control_id = list(reversed(control_ids))[block_id]\n",
" elif name.startswith(\"down_blocks\"):\n",
" block_id = int(name[len(\"down_blocks.\")])\n",
" control_id = control_ids[block_id]\n",
"\n",
" lora_layers = lora_layers_list[control_id]\n",
" if len(lora_layers) != 0:\n",
" lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)\n",
" lora_attn_procs[name] = lora_layer\n",
"\n",
"unet.set_attn_processor(lora_attn_procs)\n",
"'''\n",
"\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" ControlNetModel,\n",
" DDPMScheduler,\n",
" StableDiffusionControlNetPipeline,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
")\n",
"import torch\n",
"from diffusers.utils import load_image\n",
"\n",
"controlnet_model_name_or_path = \"svjack/ControlNet-Canny-Zh\"\n",
"controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)\n",
"\n",
"base_model_path = \"IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1\"\n",
"pipe = StableDiffusionControlNetPipeline.from_pretrained(\n",
" base_model_path, controlnet=controlnet,\n",
" #torch_dtype=torch.float16\n",
")\n",
"\n",
"# speed up diffusion process with faster scheduler and memory optimization\n",
"pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n",
"#pipe.enable_model_cpu_offload()\n",
"if device == \"cuda\":\n",
" pipe = pipe.to(\"cuda\")\n",
"\n",
"\n",
"def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):\n",
" from PIL import Image\n",
" with torch.no_grad():\n",
" img = resize_image(HWC3(input_image), image_resolution)\n",
" H, W, C = img.shape\n",
"\n",
" detected_map = apply_canny(img, low_threshold, high_threshold)\n",
" detected_map = HWC3(detected_map)\n",
" '''\n",
" print(type(detected_map))\n",
" return [detected_map]\n",
"\n",
" control = torch.from_numpy(detected_map[...,::-1].copy().transpose([2,0,1])).float().to(device)[None] / 127.5 - 1\n",
" _ = control_lora(control).control_states\n",
"\n",
" if seed == -1:\n",
" seed = random.randint(0, 65535)\n",
" '''\n",
" if seed == -1:\n",
" seed = random.randint(0, 65535)\n",
" control_image = Image.fromarray(detected_map)\n",
"\n",
" # run inference\n",
" generator = torch.Generator(device=device).manual_seed(seed)\n",
" images = []\n",
" for i in range(num_samples):\n",
" '''\n",
" _ = control_lora(control).control_states\n",
" image = pipeline(\n",
" prompt + ', ' + a_prompt, negative_prompt=n_prompt,\n",
" num_inference_steps=sample_steps, guidance_scale=scale, eta=eta,\n",
" generator=generator, height=H, width=W).images[0]\n",
" '''\n",
" image = pipe(\n",
" prompt + ', ' + a_prompt, negative_prompt=n_prompt,\n",
" num_inference_steps=sample_steps, guidance_scale=scale, eta=eta,\n",
" image = control_image,\n",
" generator=generator, height=H, width=W).images[0]\n",
" images.append(np.asarray(image))\n",
"\n",
" results = images\n",
" return [255 - detected_map] + results\n",
"\n",
"\n",
"block = gr.Blocks().queue()\n",
"with block:\n",
" with gr.Row():\n",
" gr.Markdown(\"## Control Stable Diffusion with Canny Edge Maps\")\n",
" #gr.Markdown(\"This _example_ was **drive** from [https://github.com/svjack/ControlLoRA-Chinese](https://github.com/svjack/ControlLoRA-Chinese)
\\n\")\n",
" with gr.Row():\n",
" with gr.Column():\n",
" input_image = gr.Image(source='upload', type=\"numpy\", value = \"love_in_rose.png\")\n",
" prompt = gr.Textbox(label=\"Prompt\", value = \"海边的俊俏美男子\")\n",
" run_button = gr.Button(label=\"Run\")\n",
" with gr.Accordion(\"Advanced options\", open=False):\n",
" num_samples = gr.Slider(label=\"Images\", minimum=1, maximum=12, value=1, step=1)\n",
" image_resolution = gr.Slider(label=\"Image Resolution\", minimum=256, maximum=768, value=512, step=256)\n",
" low_threshold = gr.Slider(label=\"Canny low threshold\", minimum=1, maximum=255, value=100, step=1)\n",
" high_threshold = gr.Slider(label=\"Canny high threshold\", minimum=1, maximum=255, value=200, step=1)\n",
" sample_steps = gr.Slider(label=\"Steps\", minimum=1, maximum=100, value=20, step=1)\n",
" scale = gr.Slider(label=\"Guidance Scale\", minimum=0.1, maximum=30.0, value=9.0, step=0.1)\n",
" seed = gr.Slider(label=\"Seed\", minimum=-1, maximum=2147483647, step=1, randomize=True)\n",
" eta = gr.Number(label=\"eta\", value=0.0)\n",
" a_prompt = gr.Textbox(label=\"Added Prompt\", value='详细的模拟混合媒体拼贴画,帆布质地的当代艺术风格,朋克艺术,逼真主义,感性的身体,表现主义,极简主义。杰作,完美的组成,逼真的美丽的脸')\n",
" n_prompt = gr.Textbox(label=\"Negative Prompt\",\n",
" value='低质量,模糊,混乱')\n",
" with gr.Column():\n",
" result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=\"gallery\").style(grid=2, height='auto')\n",
" ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]\n",
" run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress = True)\n"
]
},
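 {
  "cell_type": "code",
  "execution_count": null,
  "id": "process-usage-sketch",
  "metadata": {},
  "outputs": [],
  "source": [
    "# A minimal sketch (not part of the original demo): call `process` directly,\n",
    "# bypassing the Gradio UI. It assumes the same example image used as the UI\n",
    "# default above, \"love_in_rose.png\", is present in the working directory.\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "input_image = np.asarray(Image.open(\"love_in_rose.png\").convert(\"RGB\"))\n",
    "outputs = process(\n",
    "    input_image,\n",
    "    prompt=\"海边的俊俏美男子\",  # a handsome young man by the seaside\n",
    "    a_prompt=\"杰作,完美的组成\",  # masterpiece, perfect composition\n",
    "    n_prompt=\"低质量,模糊,混乱\",  # low quality, blurry, chaotic\n",
    "    num_samples=1,\n",
    "    image_resolution=512,\n",
    "    sample_steps=20,\n",
    "    scale=9.0,\n",
    "    seed=42,\n",
    "    eta=0.0,\n",
    "    low_threshold=100,\n",
    "    high_threshold=200,\n",
    ")\n",
    "# outputs[0] is the inverted edge map; outputs[1:] are the generated samples\n",
    "Image.fromarray(outputs[1]).save(\"sample_0.png\")\n"
  ]
 },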
{
"cell_type": "code",
"execution_count": 2,
"id": "6cc15198-a7d0-4949-be5a-5aed25b1b2aa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://172.16.54.41:7860\n",
"Running on public URL: https://6d03803b072cb36140.gradio.live\n",
"\n",
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
]
},
{
"data": {
"text/html": [
"