| """ | |
| Train a diffusion model on images. | |
| """ | |
import argparse

import cv2
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.utils import make_grid

from glide_text2im.glide_util import sample
from glide_text2im.image_datasets_sketch import get_tensor
from glide_text2im.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
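

# Sampling is a two-stage cascade: a base diffusion model generates 64x64
# images conditioned on the sketch or mask, and a second diffusion model
# upsamples them to 256x256.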
def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
    parser, parser_up = create_argparser()
    args = parser.parse_args()
    args_up = parser_up.parse_args()
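
    # Select the checkpoints for the chosen input type: the sketch model is
    # trained on COCO edge maps, the mask model on COCO semantic maps.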
    if mode == 'sketch':
        args.mode = 'coco-edge'
        args_up.mode = 'coco-edge'
        args.model_path = hf_hub_download(repo_id="tfwang/PITI", filename="base.pt")
        args.sr_model_path = hf_hub_download(repo_id="tfwang/PITI", filename="upsample.pt")
    elif mode == 'mask':
        args.mode = 'coco'
        args_up.mode = 'coco'
        args.model_path = hf_hub_download(repo_id="tfwang/PITI", filename="base_mask.pt")
        args.sr_model_path = hf_hub_download(repo_id="tfwang/PITI", filename="upsample_mask.pt")

    args.val_data_dir = image
    args.sample_c = sample_c
    args.num_samples = int(num_samples)  # slider values may arrive as floats
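
    # Build the base (64x64) and upsampler (256x256) model/diffusion pairs
    # from their respective default option sets.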
    options = args_to_dict(args, model_and_diffusion_defaults(0.).keys())
    model, diffusion = create_model_and_diffusion(**options)

    options_up = args_to_dict(args_up, model_and_diffusion_defaults(True).keys())
    model_up, diffusion_up = create_model_and_diffusion(**options_up)

    if args.model_path:
        print('loading model')
        model_ckpt = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(model_ckpt, strict=True)

    if args.sr_model_path:
        print('loading sr model')
        model_ckpt2 = torch.load(args.sr_model_path, map_location="cpu")
        model_up.load_state_dict(model_ckpt2, strict=True)

    model.cuda()
    model_up.cuda()
    model.eval()
    model_up.eval()

    # Build the conditioning tensor from the uploaded input image.
    if args.mode == 'coco':
        pil_image = image
        label_pil = pil_image.convert("RGB").resize((256, 256), Image.NEAREST)
        label_tensor = get_tensor()(label_pil)
        data_dict = {"ref": label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}
    elif args.mode == 'coco-edge':
        pil_image = image
        label_pil = pil_image.convert("L").resize((256, 256), Image.NEAREST)
        im_dist = cv2.distanceTransform(255 - np.array(label_pil), cv2.DIST_L1, 3)
        im_dist = np.clip(im_dist, 0, 255).astype(np.uint8)
        im_dist = Image.fromarray(im_dist).convert("RGB")
        label_tensor = get_tensor()(im_dist)[:1]
        data_dict = {"ref": label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}
| print("sampling...") | |
| sampled_imgs = [] | |
| grid_imgs = [] | |
| img_id = 0 | |
| while (True): | |
| if img_id >= args.num_samples: | |
| break | |
| model_kwargs = data_dict | |
| with th.no_grad(): | |
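            # Stage 1: sample 64x64 images with classifier-free guidance.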
            samples_lr = sample(
                glide_model=model,
                glide_options=options,
                side_x=64,
                side_y=64,
                prompt=model_kwargs,
                batch_size=args.num_samples,
                guidance_scale=args.sample_c,
                device=torch.device('cuda'),
                prediction_respacing=str(int(sample_step)),  # respacing must be a string of an integer
                upsample_enabled=False,
                upsample_temp=0.997,
                mode=args.mode,
            )
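
            # Round-trip through 8-bit quantization before feeding the batch
            # to the upsampler as its low_res conditioning.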
            samples_lr = samples_lr.clamp(-1, 1)
            tmp = (127.5 * (samples_lr + 1.0)).int()
            model_kwargs['low_res'] = tmp / 127.5 - 1.
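
            # Stage 2: upsample to 256x256 with a fast 27-step schedule and
            # no extra guidance.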
            samples_hr = sample(
                glide_model=model_up,
                glide_options=options_up,
                side_x=256,
                side_y=256,
                prompt=model_kwargs,
                batch_size=args.num_samples,
                guidance_scale=1,
                device=torch.device('cuda'),
                prediction_respacing="fast27",
                upsample_enabled=True,
                upsample_temp=0.997,
                mode=args.mode,
            )
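
        # Convert each sample from a [-1, 1] CHW tensor to a PIL image.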
        for hr in samples_hr:
            hr = 255. * rearrange((hr.cpu().numpy() + 1.0) * 0.5, 'c h w -> h w c')
            sample_img = Image.fromarray(hr.astype(np.uint8))
            sampled_imgs.append(sample_img)
            img_id += 1
        grid_imgs.append(samples_hr)
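
    # Tile all generated samples into a single grid image for display.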
    grid = torch.stack(grid_imgs, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
    grid = make_grid(grid, nrow=2)
    grid = 255. * rearrange((grid + 1.0) * 0.5, 'c h w -> h w c').cpu().numpy()
    return Image.fromarray(grid.astype(np.uint8))


def create_argparser():
    defaults = dict(
        data_dir="",
        val_data_dir="",
        model_path="./base_edge.pt",
        sr_model_path="./upsample_edge.pt",
        encoder_path="",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=2,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=100,
        save_interval=20000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
        sample_c=1.,
        sample_respacing="100",
        uncond_p=0.2,
        num_samples=3,
        finetune_decoder=False,
        mode='',
    )
    # Copy (not alias) so updating the base defaults below does not also
    # mutate the upsampler defaults.
    defaults_up = defaults.copy()
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)

    defaults_up.update(model_and_diffusion_defaults(True))
    parser_up = argparse.ArgumentParser()
    add_dict_to_argparser(parser_up, defaults_up)
    return parser, parser_up
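

# Gradio UI: sketch/mask upload, input-mode selector, and sampling controls.
# Uses the current Gradio component API (gr.Image / gr.Radio / gr.Slider).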
output_image = gr.Image(type="pil", label="Sampled results")

css = ".output-image{height: 528px !important} .output-carousel .output-image{height:272px !important} a{text-decoration: underline}"

demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Image(type="pil", label="Input Sketch"),
        gr.Radio(choices=["mask", "sketch"], value="sketch", label="Input Mode - The type of your input"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.4, label="sample_c - The strength of classifier-free guidance"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of samples - How many samples you wish to generate"),
        gr.Slider(minimum=50, maximum=1000, step=10, value=100, label="Number of Steps - How many diffusion steps to use"),
    ],
    outputs=output_image,
    css=css,
    title="Generate images from sketches with PITI",
    description="<div>Upload a sketch map or a semantic map and press Submit to generate images from your input.</div>",
)

demo.queue().launch()