Ashoka74 commited on
Commit
9f4d9ca
Β·
verified Β·
1 Parent(s): 5069f83

Update gradio_demo.py

Browse files
Files changed (1) hide show
  1. gradio_demo.py +11 -6
gradio_demo.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import math
3
  import gradio as gr
@@ -255,7 +256,7 @@ def encode_prompt_inner(txt: str):
255
 
256
  return conds
257
 
258
-
259
  @torch.inference_mode()
260
  def encode_prompt_pair(positive_prompt, negative_prompt):
261
  c = encode_prompt_inner(positive_prompt)
@@ -276,7 +277,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
276
 
277
  return c, uc
278
 
279
-
280
  @torch.inference_mode()
281
  def pytorch2numpy(imgs, quant=True):
282
  results = []
@@ -293,7 +294,7 @@ def pytorch2numpy(imgs, quant=True):
293
  results.append(y)
294
  return results
295
 
296
-
297
  @torch.inference_mode()
298
  def numpy2pytorch(imgs):
299
  h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
@@ -321,7 +322,7 @@ def resize_without_crop(image, target_width, target_height):
321
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
322
  return np.array(resized_image)
323
 
324
-
325
  @torch.inference_mode()
326
  def run_rmbg(img, sigma=0.0):
327
  # Convert RGBA to RGB if needed
@@ -346,6 +347,8 @@ def run_rmbg(img, sigma=0.0):
346
  rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
347
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
348
  return result.clip(0, 255).astype(np.uint8), rgba
 
 
349
  @torch.inference_mode()
350
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
351
  clear_memory()
@@ -465,6 +468,8 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
465
 
466
  return pixels
467
 
 
 
468
  @torch.inference_mode()
469
  def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
470
  clear_memory()
@@ -558,7 +563,7 @@ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_sample
558
  clear_memory()
559
  return pixels, [fg, bg]
560
 
561
-
562
  @torch.inference_mode()
563
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
564
  input_fg, matting = run_rmbg(input_fg)
@@ -566,7 +571,7 @@ def process_relight(input_fg, prompt, image_width, image_height, num_samples, se
566
  return input_fg, results
567
 
568
 
569
-
570
  @torch.inference_mode()
571
  def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
572
  bg_source = BGSource(bg_source)
 
1
+ import spaces
2
  import os
3
  import math
4
  import gradio as gr
 
256
 
257
  return conds
258
 
259
+ @spaces.GPU(duration=60)
260
  @torch.inference_mode()
261
  def encode_prompt_pair(positive_prompt, negative_prompt):
262
  c = encode_prompt_inner(positive_prompt)
 
277
 
278
  return c, uc
279
 
280
+ @spaces.GPU(duration=60)
281
  @torch.inference_mode()
282
  def pytorch2numpy(imgs, quant=True):
283
  results = []
 
294
  results.append(y)
295
  return results
296
 
297
+ @spaces.GPU(duration=60)
298
  @torch.inference_mode()
299
  def numpy2pytorch(imgs):
300
  h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
 
322
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
323
  return np.array(resized_image)
324
 
325
+ @spaces.GPU(duration=60)
326
  @torch.inference_mode()
327
  def run_rmbg(img, sigma=0.0):
328
  # Convert RGBA to RGB if needed
 
347
  rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
348
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
349
  return result.clip(0, 255).astype(np.uint8), rgba
350
+
351
+ @spaces.GPU(duration=60)
352
  @torch.inference_mode()
353
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
354
  clear_memory()
 
468
 
469
  return pixels
470
 
471
+
472
+ @spaces.GPU(duration=60)
473
  @torch.inference_mode()
474
  def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
475
  clear_memory()
 
563
  clear_memory()
564
  return pixels, [fg, bg]
565
 
566
+ @spaces.GPU(duration=60)
567
  @torch.inference_mode()
568
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
569
  input_fg, matting = run_rmbg(input_fg)
 
571
  return input_fg, results
572
 
573
 
574
+ @spaces.GPU(duration=60)
575
  @torch.inference_mode()
576
  def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
577
  bg_source = BGSource(bg_source)