mahdideveloepr commited on
Commit
9451bf8
·
verified ·
1 Parent(s): d1960aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +433 -1
app.py CHANGED
@@ -1 +1,433 @@
1
- import os;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import safetensors.torch as sf
7
+ import db_examples
8
+
9
+ from PIL import Image
10
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
11
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
12
+ from diffusers.models.attention_processor import AttnProcessor2_0
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from briarmbg import BriaRMBG
15
+ from enum import Enum
16
+ from torch.hub import download_url_to_file
17
+
18
+
19
+ # 'stablediffusionapi/realistic-vision-v51'
20
+ # 'runwayml/stable-diffusion-v1-5'
21
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
22
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
23
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
24
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
25
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
26
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
27
+
28
+ # Change UNet
29
+
30
+ with torch.no_grad():
31
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
32
+ new_conv_in.weight.zero_()
33
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
34
+ new_conv_in.bias = unet.conv_in.bias
35
+ unet.conv_in = new_conv_in
36
+
37
+ unet_original_forward = unet.forward
38
+
39
+
40
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
41
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
42
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
43
+ new_sample = torch.cat([sample, c_concat], dim=1)
44
+ kwargs['cross_attention_kwargs'] = {}
45
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
46
+
47
+
48
+ unet.forward = hooked_unet_forward
49
+
50
+ # Load
51
+
52
+ model_path = './models/iclight_sd15_fc.safetensors'
53
+
54
+ if not os.path.exists(model_path):
55
+ download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
56
+
57
+ sd_offset = sf.load_file(model_path)
58
+ sd_origin = unet.state_dict()
59
+ keys = sd_origin.keys()
60
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
61
+ unet.load_state_dict(sd_merged, strict=True)
62
+ del sd_offset, sd_origin, sd_merged, keys
63
+
64
+ # Device
65
+
66
+ device = torch.device('cuda')
67
+ text_encoder = text_encoder.to(device=device, dtype=torch.float16)
68
+ vae = vae.to(device=device, dtype=torch.bfloat16)
69
+ unet = unet.to(device=device, dtype=torch.float16)
70
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
71
+
72
+ # SDP
73
+
74
+ unet.set_attn_processor(AttnProcessor2_0())
75
+ vae.set_attn_processor(AttnProcessor2_0())
76
+
77
+ # Samplers
78
+
79
+ ddim_scheduler = DDIMScheduler(
80
+ num_train_timesteps=1000,
81
+ beta_start=0.00085,
82
+ beta_end=0.012,
83
+ beta_schedule="scaled_linear",
84
+ clip_sample=False,
85
+ set_alpha_to_one=False,
86
+ steps_offset=1,
87
+ )
88
+
89
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
90
+ num_train_timesteps=1000,
91
+ beta_start=0.00085,
92
+ beta_end=0.012,
93
+ steps_offset=1
94
+ )
95
+
96
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
97
+ num_train_timesteps=1000,
98
+ beta_start=0.00085,
99
+ beta_end=0.012,
100
+ algorithm_type="sde-dpmsolver++",
101
+ use_karras_sigmas=True,
102
+ steps_offset=1
103
+ )
104
+
105
+ # Pipelines
106
+
107
+ t2i_pipe = StableDiffusionPipeline(
108
+ vae=vae,
109
+ text_encoder=text_encoder,
110
+ tokenizer=tokenizer,
111
+ unet=unet,
112
+ scheduler=dpmpp_2m_sde_karras_scheduler,
113
+ safety_checker=None,
114
+ requires_safety_checker=False,
115
+ feature_extractor=None,
116
+ image_encoder=None
117
+ )
118
+
119
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
120
+ vae=vae,
121
+ text_encoder=text_encoder,
122
+ tokenizer=tokenizer,
123
+ unet=unet,
124
+ scheduler=dpmpp_2m_sde_karras_scheduler,
125
+ safety_checker=None,
126
+ requires_safety_checker=False,
127
+ feature_extractor=None,
128
+ image_encoder=None
129
+ )
130
+
131
+
132
+ @torch.inference_mode()
133
+ def encode_prompt_inner(txt: str):
134
+ max_length = tokenizer.model_max_length
135
+ chunk_length = tokenizer.model_max_length - 2
136
+ id_start = tokenizer.bos_token_id
137
+ id_end = tokenizer.eos_token_id
138
+ id_pad = id_end
139
+
140
+ def pad(x, p, i):
141
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
142
+
143
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
144
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
145
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
146
+
147
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
148
+ conds = text_encoder(token_ids).last_hidden_state
149
+
150
+ return conds
151
+
152
+
153
+ @torch.inference_mode()
154
+ def encode_prompt_pair(positive_prompt, negative_prompt):
155
+ c = encode_prompt_inner(positive_prompt)
156
+ uc = encode_prompt_inner(negative_prompt)
157
+
158
+ c_len = float(len(c))
159
+ uc_len = float(len(uc))
160
+ max_count = max(c_len, uc_len)
161
+ c_repeat = int(math.ceil(max_count / c_len))
162
+ uc_repeat = int(math.ceil(max_count / uc_len))
163
+ max_chunk = max(len(c), len(uc))
164
+
165
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
166
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
167
+
168
+ c = torch.cat([p[None, ...] for p in c], dim=1)
169
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
170
+
171
+ return c, uc
172
+
173
+
174
+ @torch.inference_mode()
175
+ def pytorch2numpy(imgs, quant=True):
176
+ results = []
177
+ for x in imgs:
178
+ y = x.movedim(0, -1)
179
+
180
+ if quant:
181
+ y = y * 127.5 + 127.5
182
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
183
+ else:
184
+ y = y * 0.5 + 0.5
185
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
186
+
187
+ results.append(y)
188
+ return results
189
+
190
+
191
+ @torch.inference_mode()
192
+ def numpy2pytorch(imgs):
193
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
194
+ h = h.movedim(-1, 1)
195
+ return h
196
+
197
+
198
+ def resize_and_center_crop(image, target_width, target_height):
199
+ pil_image = Image.fromarray(image)
200
+ original_width, original_height = pil_image.size
201
+ scale_factor = max(target_width / original_width, target_height / original_height)
202
+ resized_width = int(round(original_width * scale_factor))
203
+ resized_height = int(round(original_height * scale_factor))
204
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
205
+ left = (resized_width - target_width) / 2
206
+ top = (resized_height - target_height) / 2
207
+ right = (resized_width + target_width) / 2
208
+ bottom = (resized_height + target_height) / 2
209
+ cropped_image = resized_image.crop((left, top, right, bottom))
210
+ return np.array(cropped_image)
211
+
212
+
213
+ def resize_without_crop(image, target_width, target_height):
214
+ pil_image = Image.fromarray(image)
215
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
216
+ return np.array(resized_image)
217
+
218
+
219
+ @torch.inference_mode()
220
+ def run_rmbg(img, sigma=0.0):
221
+ H, W, C = img.shape
222
+ assert C == 3
223
+ k = (256.0 / float(H * W)) ** 0.5
224
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
225
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
226
+ alpha = rmbg(feed)[0][0]
227
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
228
+ alpha = alpha.movedim(1, -1)[0]
229
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
230
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
231
+ return result.clip(0, 255).astype(np.uint8), alpha
232
+
233
+
234
+ @torch.inference_mode()
235
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
236
+ bg_source = BGSource(bg_source)
237
+ input_bg = None
238
+
239
+ if bg_source == BGSource.NONE:
240
+ pass
241
+ elif bg_source == BGSource.LEFT:
242
+ gradient = np.linspace(255, 0, image_width)
243
+ image = np.tile(gradient, (image_height, 1))
244
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
245
+ elif bg_source == BGSource.RIGHT:
246
+ gradient = np.linspace(0, 255, image_width)
247
+ image = np.tile(gradient, (image_height, 1))
248
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
249
+ elif bg_source == BGSource.TOP:
250
+ gradient = np.linspace(255, 0, image_height)[:, None]
251
+ image = np.tile(gradient, (1, image_width))
252
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
253
+ elif bg_source == BGSource.BOTTOM:
254
+ gradient = np.linspace(0, 255, image_height)[:, None]
255
+ image = np.tile(gradient, (1, image_width))
256
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
257
+ else:
258
+ raise 'Wrong initial latent!'
259
+
260
+ rng = torch.Generator(device=device).manual_seed(int(seed))
261
+
262
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
263
+
264
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
265
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
266
+
267
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
268
+
269
+ if input_bg is None:
270
+ latents = t2i_pipe(
271
+ prompt_embeds=conds,
272
+ negative_prompt_embeds=unconds,
273
+ width=image_width,
274
+ height=image_height,
275
+ num_inference_steps=steps,
276
+ num_images_per_prompt=num_samples,
277
+ generator=rng,
278
+ output_type='latent',
279
+ guidance_scale=cfg,
280
+ cross_attention_kwargs={'concat_conds': concat_conds},
281
+ ).images.to(vae.dtype) / vae.config.scaling_factor
282
+ else:
283
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
284
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
285
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
286
+ latents = i2i_pipe(
287
+ image=bg_latent,
288
+ strength=lowres_denoise,
289
+ prompt_embeds=conds,
290
+ negative_prompt_embeds=unconds,
291
+ width=image_width,
292
+ height=image_height,
293
+ num_inference_steps=int(round(steps / lowres_denoise)),
294
+ num_images_per_prompt=num_samples,
295
+ generator=rng,
296
+ output_type='latent',
297
+ guidance_scale=cfg,
298
+ cross_attention_kwargs={'concat_conds': concat_conds},
299
+ ).images.to(vae.dtype) / vae.config.scaling_factor
300
+
301
+ pixels = vae.decode(latents).sample
302
+ pixels = pytorch2numpy(pixels)
303
+ pixels = [resize_without_crop(
304
+ image=p,
305
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
306
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
307
+ for p in pixels]
308
+
309
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
310
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
311
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
312
+
313
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
314
+
315
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
316
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
317
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
318
+
319
+ latents = i2i_pipe(
320
+ image=latents,
321
+ strength=highres_denoise,
322
+ prompt_embeds=conds,
323
+ negative_prompt_embeds=unconds,
324
+ width=image_width,
325
+ height=image_height,
326
+ num_inference_steps=int(round(steps / highres_denoise)),
327
+ num_images_per_prompt=num_samples,
328
+ generator=rng,
329
+ output_type='latent',
330
+ guidance_scale=cfg,
331
+ cross_attention_kwargs={'concat_conds': concat_conds},
332
+ ).images.to(vae.dtype) / vae.config.scaling_factor
333
+
334
+ pixels = vae.decode(latents).sample
335
+
336
+ return pytorch2numpy(pixels)
337
+
338
+
339
+ @torch.inference_mode()
340
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
341
+ input_fg, matting = run_rmbg(input_fg)
342
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
343
+ return input_fg, results
344
+
345
+
346
+ quick_prompts = [
347
+ 'sunshine from window',
348
+ 'neon light, city',
349
+ 'sunset over sea',
350
+ 'golden time',
351
+ 'sci-fi RGB glowing, cyberpunk',
352
+ 'natural lighting',
353
+ 'warm atmosphere, at home, bedroom',
354
+ 'magic lit',
355
+ 'evil, gothic, Yharnam',
356
+ 'light and shadow',
357
+ 'shadow from window',
358
+ 'soft studio lighting',
359
+ 'home atmosphere, cozy bedroom illumination',
360
+ 'neon, Wong Kar-wai, warm'
361
+ ]
362
+ quick_prompts = [[x] for x in quick_prompts]
363
+
364
+
365
+ quick_subjects = [
366
+ 'beautiful woman, detailed face',
367
+ 'handsome man, detailed face',
368
+ ]
369
+ quick_subjects = [[x] for x in quick_subjects]
370
+
371
+
372
+ class BGSource(Enum):
373
+ NONE = "None"
374
+ LEFT = "Left Light"
375
+ RIGHT = "Right Light"
376
+ TOP = "Top Light"
377
+ BOTTOM = "Bottom Light"
378
+
379
+
380
+ block = gr.Blocks().queue()
381
+ with block:
382
+ with gr.Row():
383
+ gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
384
+ with gr.Row():
385
+ with gr.Column():
386
+ with gr.Row():
387
+ input_fg = gr.Image(source='upload', type="numpy", label="Image", height=480)
388
+ output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
389
+ prompt = gr.Textbox(label="Prompt")
390
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
391
+ value=BGSource.NONE.value,
392
+ label="Lighting Preference (Initial Latent)", type='value')
393
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
394
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
395
+ relight_button = gr.Button(value="Relight")
396
+
397
+ with gr.Group():
398
+ with gr.Row():
399
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
400
+ seed = gr.Number(label="Seed", value=12345, precision=0)
401
+
402
+ with gr.Row():
403
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
404
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
405
+
406
+ with gr.Accordion("Advanced options", open=False):
407
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
408
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
409
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
410
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
411
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
412
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
413
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
414
+ with gr.Column():
415
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
416
+ with gr.Row():
417
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
418
+ gr.Examples(
419
+ fn=lambda *args: ([args[-1]], None),
420
+ examples=db_examples.foreground_conditioned_examples,
421
+ inputs=[
422
+ input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
423
+ ],
424
+ outputs=[result_gallery, output_bg],
425
+ run_on_click=True, examples_per_page=1024
426
+ )
427
+ ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
428
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
429
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
430
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
431
+
432
+
433
+ block.launch(server_name='0.0.0.0')