# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# run DragDiffusion on the benchmark samples and save the editing results
import argparse
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image

from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace

from diffusers import DDIMScheduler, AutoencoderKL
from torchvision.utils import save_image
from pytorch_lightning import seed_everything

import sys
sys.path.insert(0, '../')
from drag_pipeline import DragPipeline
from utils.drag_utils import drag_diffusion_update
from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
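

# convert an HWC image array with values in [0, 255] into a 1 x C x H x W float tensor in [-1, 1]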
def preprocess_image(image,
                     device):
    image = torch.from_numpy(image).float() / 127.5 - 1  # [-1, 1]
    image = rearrange(image, "h w c -> 1 c h w")
    image = image.to(device)
    return image

# copy of the run_drag function
def run_drag(source_image,
             # image_with_clicks,
             mask,
             prompt,
             points,
             inversion_strength,
             lam,
             latent_lr,
             unet_feature_idx,
             n_pix_step,
             model_path,
             vae_path,
             lora_path,
             start_step,
             start_layer,
             # save_dir="./results"
             ):
    # initialize model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False, steps_offset=1)
    model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
    # call this function to override unet forward function,
    # so that intermediate features are returned after forward
    model.modify_unet_forward()

    # set vae
    if vae_path != "default":
        model.vae = AutoencoderKL.from_pretrained(
            vae_path
        ).to(model.vae.device, model.vae.dtype)

    # initialize parameters
    seed = 42  # fixed random seed for reproducibility
    seed_everything(seed)

    args = SimpleNamespace()
    args.prompt = prompt
    args.points = points
    args.n_inference_step = 50
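    # the number of denoising steps actually applied is a fraction of the
    # full schedule, controlled by inversion_strength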
    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
    args.guidance_scale = 1.0
    args.unet_feature_idx = [unet_feature_idx]
    args.r_m = 1
    args.r_p = 3
    args.lam = lam
    args.lr = latent_lr
    args.n_pix_step = n_pix_step
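    # the mask and the clicked points are supervised at half of the full image resolution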
    full_h, full_w = source_image.shape[:2]
    args.sup_res_h = int(0.5*full_h)
    args.sup_res_w = int(0.5*full_w)
    print(args)

    source_image = preprocess_image(source_image, device)
    # image_with_clicks = preprocess_image(image_with_clicks, device)

    # set lora
    if lora_path == "":
        print("applying default parameters")
        model.unet.set_default_attn_processor()
    else:
        print("applying lora: " + lora_path)
        model.unet.load_attn_procs(lora_path)

    # invert the source image
    # the latent code resolution is too small, only 64*64
    invert_code = model.invert(source_image,
                               prompt,
                               guidance_scale=args.guidance_scale,
                               num_inference_steps=args.n_inference_step,
                               num_actual_inference_steps=args.n_actual_inference_step)
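
    # binarize the editing mask and downsample it to the supervision resolution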
    mask = torch.from_numpy(mask).float() / 255.
    mask[mask > 0.0] = 1.0
    mask = rearrange(mask, "h w -> 1 1 h w").to(device)
    mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")

    handle_points = []
    target_points = []
    # the points are given in (x, y) coordinates; even indices are handle points,
    # odd indices are the corresponding target points
    for idx, point in enumerate(points):
        cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
        cur_point = torch.round(cur_point)
        if idx % 2 == 0:
            handle_points.append(cur_point)
        else:
            target_points.append(cur_point)
    print('handle points:', handle_points)
    print('target points:', target_points)

    init_code = invert_code
    init_code_orig = deepcopy(init_code)
    model.scheduler.set_timesteps(args.n_inference_step)
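    # timestep from which denoising (and the drag optimization) starts,
    # i.e. the point where the DDIM inversion stopped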
    t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]

    # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
    # update according to the given supervision
    updated_init_code = drag_diffusion_update(model, init_code,
        None, t, handle_points, target_points, mask, args)

    # hijack the attention module
    # inject the reference branch to guide the generation
    editor = MutualSelfAttentionControl(start_step=start_step,
                                        start_layer=start_layer,
                                        total_steps=args.n_inference_step,
                                        guidance_scale=args.guidance_scale)
    if lora_path == "":
        register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
    else:
        register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')

    # run inference to synthesize the edited image
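    # denoise a batch of two latents: the original inverted code (reference
    # branch for the mutual self-attention control) and the drag-updated code;
    # only the second image (index 1) is kept as the output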
    gen_image = model(
        prompt=args.prompt,
        batch_size=2,
        latents=torch.cat([init_code_orig, updated_init_code], dim=0),
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step
    )[1].unsqueeze(dim=0)

    # resize gen_image into the size of source_image
    # we do this because the shape of gen_image will be rounded to multiples of 8
    gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')

    # save the original image, user editing instructions, synthesized image
    # save_result = torch.cat([
    #     source_image * 0.5 + 0.5,
    #     torch.ones((1,3,full_h,25)).cuda(),
    #     image_with_clicks * 0.5 + 0.5,
    #     torch.ones((1,3,full_h,25)).cuda(),
    #     gen_image[0:1]
    # ], dim=-1)

    # if not os.path.isdir(save_dir):
    #     os.mkdir(save_dir)
    # save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    # save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))

    out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
    out_image = (out_image * 255).astype(np.uint8)
    return out_image

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="setting arguments")
    parser.add_argument('--lora_steps', type=int, help='number of lora fine-tuning steps')
    parser.add_argument('--inv_strength', type=float, help='inversion strength')
    parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
    parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
    args = parser.parse_args()
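
    # example invocation (the script name and argument values below are only illustrative):
    #   python run_drag_diffusion.py --lora_steps 80 --inv_strength 0.7 --latent_lr 0.01 --unet_feature_idx 3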

    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]

    # assume root_dir and lora_dir are valid directories
    root_dir = 'drag_bench_data'
    lora_dir = 'drag_bench_lora'
    result_dir = 'drag_diffusion_res' + \
        '_' + str(args.lora_steps) + \
        '_' + str(args.inv_strength) + \
        '_' + str(args.latent_lr) + \
        '_' + str(args.unet_feature_idx)
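    # the result directory name encodes the hyper-parameters of this run
    # (lora steps, inversion strength, latent lr, unet feature idx)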

    # create the result directory and per-category sub-directories if necessary
    for cat in all_category:
        os.makedirs(os.path.join(result_dir, cat), exist_ok=True)

    for cat in all_category:
        file_dir = os.path.join(root_dir, cat)
        for sample_name in os.listdir(file_dir):
            if sample_name == '.DS_Store':
                continue
            sample_path = os.path.join(file_dir, sample_name)

            # read image file
            source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
            source_image = np.array(source_image)

            # load meta data
            with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
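            # meta_data holds the text prompt, the editing mask, and the list of
            # clicked points (alternating handle/target pairs)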
            prompt = meta_data['prompt']
            mask = meta_data['mask']
            points = meta_data['points']

            # load lora
            lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps))
            print("applying lora: " + lora_path)

            out_image = run_drag(
                source_image,
                mask,
                prompt,
                points,
                inversion_strength=args.inv_strength,
                lam=0.1,
                latent_lr=args.latent_lr,
                unet_feature_idx=args.unet_feature_idx,
                n_pix_step=80,
                model_path="runwayml/stable-diffusion-v1-5",
                vae_path="default",
                lora_path=lora_path,
                start_step=0,
                start_layer=10,
            )
            save_dir = os.path.join(result_dir, cat, sample_name)
            if not os.path.isdir(save_dir):
                os.mkdir(save_dir)
            Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))