Ashoka74 committed on
Commit 6af1a30 · verified · 1 Parent(s): abc4fd3

Update gradio_demo.py

Files changed (1): gradio_demo.py (+192 -32)
gradio_demo.py CHANGED

@@ -1,4 +1,3 @@
-import spaces
 import os
 import math
 import gradio as gr
@@ -16,20 +15,32 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerA
 from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from briarmbg import BriaRMBG
+import dds_cloudapi_sdk
+from dds_cloudapi_sdk import Config, Client, TextPrompt
+from dds_cloudapi_sdk.tasks.dinox import DinoxTask
 from enum import Enum
 from torch.hub import download_url_to_file
-from torch.hub import download_url_to_file
+import tempfile
 
+from sam2.build_sam import build_sam2
+
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 import cv2
 
 from typing import Optional
 
 from Depth.depth_anything_v2.dpt import DepthAnythingV2
 
+import httpx
+
+client = httpx.Client(timeout=httpx.Timeout(10.0))  # Set timeout to 10 seconds
+
+
 
 
-# from FLORENCE
+# from FLORENCE
+import spaces
 import supervision as sv
 import torch
 from PIL import Image
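[Note] The new module-level httpx client uses a bare `httpx.Timeout(10.0)`, which applies the same 10 s to connect, read, write, and pool acquisition. httpx also accepts per-phase limits, useful when uploads are slow but dead hosts should fail fast; a minimal sketch (the limits and URL below are illustrative, not from this commit):

import httpx

timeout = httpx.Timeout(10.0, connect=5.0)    # 5 s to connect, 10 s for everything else
client = httpx.Client(timeout=timeout)
response = client.get("https://example.com")  # placeholder URL, usage only
response.raise_for_status()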
@@ -74,7 +85,7 @@ model.eval()
 # Change UNet
 
 with torch.no_grad():
-    new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
+    new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
     new_conv_in.weight.zero_()
     new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
     new_conv_in.bias = unet.conv_in.bias
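[Note] The 12-to-8 change tracks the checkpoint swap further down: IC-Light's `fc` (foreground-only) UNet expects 8 input channels (4 noisy latent + 4 foreground latent), while `fbc` (foreground + background) expects 12. A self-contained sketch of the same surgery on a stand-in layer:

import torch

conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)  # stand-in for unet.conv_in

with torch.no_grad():
    new_conv_in = torch.nn.Conv2d(8, conv_in.out_channels,
                                  conv_in.kernel_size, conv_in.stride, conv_in.padding)
    new_conv_in.weight.zero_()                              # extra channels start inert
    new_conv_in.weight[:, :4, :, :].copy_(conv_in.weight)   # reuse pretrained weights
    new_conv_in.bias = conv_in.bias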
@@ -95,15 +106,15 @@ def enable_efficient_attention():
             print(f"Xformers error: {e}")
             print("Falling back to sliced attention")
             # Use sliced attention for RTX 2070
-            # unet.set_attention_slice_size(4)
-            # vae.set_attention_slice_size(4)
+            unet.set_attention_slice_size(4)
+            vae.set_attention_slice_size(4)
             unet.set_attn_processor(AttnProcessor2_0())
             vae.set_attn_processor(AttnProcessor2_0())
     else:
         # Fallback for when xformers is not available
         print("Using sliced attention")
-        # unet.set_attention_slice_size(4)
-        # vae.set_attention_slice_size(4)
+        unet.set_attention_slice_size(4)
+        vae.set_attention_slice_size(4)
         unet.set_attn_processor(AttnProcessor2_0())
         vae.set_attn_processor(AttnProcessor2_0())
 
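[Note] `set_attention_slice_size` is not a method diffusers ships, so the lines un-commented above will raise AttributeError as soon as either branch runs. A hedged sketch of the memory-saving hooks diffusers actually exposes, assuming a recent release; the slice size mirrors the 4 used in this commit:

unet.set_attention_slice(4)   # chunked attention on UNet2DConditionModel
vae.enable_slicing()          # decode latents one image at a time
# Or at pipeline level, once the pipelines below are constructed:
# i2i_pipe.enable_attention_slicing("auto")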
@@ -129,12 +140,12 @@ unet.forward = hooked_unet_forward
 
 # Load
 
-#model_path = './models/iclight_sd15_fc.safetensors'
-model_path = './models/iclight_sd15_fbc.safetensors'
+model_path = './models/iclight_sd15_fc.safetensors'
+#model_path = './models/iclight_sd15_fbc.safetensors'
 
 
-if not os.path.exists(model_path):
-    download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors', dst=model_path)
+# if not os.path.exists(model_path):
+#     download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
 
 sd_offset = sf.load_file(model_path)
 sd_origin = unet.state_dict()
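[Note] For readers following along: the `fc`/`fbc` files hold weight offsets, not full checkpoints. The lines immediately after this hunk (unchanged by this commit, so not shown) fold them onto the base SD 1.5 UNet roughly like this sketch:

# Key names must already match, which is why conv_in was widened to
# 8 channels before this point.
sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
unet.load_state_dict(sd_merged, strict=True)
del sd_offset, sd_origin, sd_merged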
@@ -223,7 +234,7 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
     image_encoder=None
 )
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def encode_prompt_inner(txt: str):
     max_length = tokenizer.model_max_length
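[Note] `encode_prompt_inner` exists to work around CLIP's 77-token context window. A self-contained sketch of the chunking idea it implements (the tokenizer checkpoint here is illustrative, not taken from this file):

from transformers import CLIPTokenizer

# Split an arbitrarily long prompt into windows the text encoder accepts.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenizer("a long product photo prompt " * 20,
                truncation=False, add_special_tokens=False).input_ids
chunk = tokenizer.model_max_length - 2   # reserve room for BOS/EOS per window
chunks = [ids[i:i + chunk] for i in range(0, len(ids), chunk)]
print(len(ids), [len(c) for c in chunks])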
@@ -244,7 +255,7 @@ def encode_prompt_inner(txt: str):
 
     return conds
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def encode_prompt_pair(positive_prompt, negative_prompt):
     c = encode_prompt_inner(positive_prompt)
@@ -265,7 +276,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
 
     return c, uc
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def pytorch2numpy(imgs, quant=True):
     results = []
@@ -282,7 +293,7 @@ def pytorch2numpy(imgs, quant=True):
         results.append(y)
     return results
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 must be strictly 0.0
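[Note] On the inline comment: dividing by 127.0 instead of 127.5 makes the uint8 midpoint land exactly on 0.0, at the cost of a slightly asymmetric range. A quick numeric check:

import numpy as np
import torch

x = np.array([[0, 127, 255]], dtype=np.uint8)
h = torch.from_numpy(x).float() / 127.0 - 1.0
print(h)  # tensor([[-1.0000,  0.0000,  1.0079]]); 255 maps to 128/127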
@@ -310,7 +321,7 @@ def resize_without_crop(image, target_width, target_height):
     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
     return np.array(resized_image)
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def run_rmbg(img, sigma=0.0):
     # Convert RGBA to RGB if needed
@@ -454,7 +465,6 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
 
     return pixels
 
-@spaces.GPU(duration=60)
 @torch.inference_mode()
 def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
     clear_memory()
@@ -548,7 +558,7 @@ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_sample
     clear_memory()
     return pixels, [fg, bg]
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
     input_fg, matting = run_rmbg(input_fg)
@@ -556,7 +566,7 @@ def process_relight(input_fg, prompt, image_width, image_height, num_samples, se
     return input_fg, results
 
 
-@spaces.GPU(duration=60)
+
 @torch.inference_mode()
 def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
     bg_source = BGSource(bg_source)
@@ -760,17 +770,154 @@ def compress_image(image):
     compressed_img = np.array(Image.open("compressed_image.jpg"))
     return compressed_img
 
+@spaces.GPU(duration=60)
+@torch.inference_mode()
+def process_image(input_image, input_text):
+    """Main processing function for the Gradio interface"""
+
+    # Initialize configs
+    API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
+    SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
+    SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    # Initialize DDS client
+    config = Config(API_TOKEN)
+    client = Client(config)
+
+    # Process classes from text prompt
+    classes = [x.strip().lower() for x in input_text.split('.') if x]
+    class_name_to_id = {name: id for id, name in enumerate(classes)}
+    class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+    # Save input image to temp file and get URL
+    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
+        cv2.imwrite(tmpfile.name, input_image)
+        image_url = client.upload_file(tmpfile.name)
+        os.remove(tmpfile.name)
+
+    # Run DINO-X detection
+    task = DinoxTask(
+        image_url=image_url,
+        prompts=[TextPrompt(text=input_text)]
+    )
+    client.run_task(task)
+    result = task.result
+    objects = result.objects
+
+    # Process detection results
+    input_boxes = []
+    confidences = []
+    class_names = []
+    class_ids = []
+
+    for obj in objects:
+        input_boxes.append(obj.bbox)
+        confidences.append(obj.score)
+        cls_name = obj.category.lower().strip()
+        class_names.append(cls_name)
+        class_ids.append(class_name_to_id[cls_name])
+
+    input_boxes = np.array(input_boxes)
+    class_ids = np.array(class_ids)
+
+    # Initialize SAM2
+    torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+    if torch.cuda.get_device_properties(0).major >= 8:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+    sam2_predictor = SAM2ImagePredictor(sam2_model)
+    sam2_predictor.set_image(input_image)
+
+    # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
+
+    # Get masks from SAM2
+    masks, scores, logits = sam2_predictor.predict(
+        point_coords=None,
+        point_labels=None,
+        box=input_boxes,
+        multimask_output=False,
+    )
+    if masks.ndim == 4:
+        masks = masks.squeeze(1)
+
+    # Create visualization
+    labels = [f"{class_name} {confidence:.2f}"
+              for class_name, confidence in zip(class_names, confidences)]
+
+    detections = sv.Detections(
+        xyxy=input_boxes,
+        mask=masks.astype(bool),
+        class_id=class_ids
+    )
+
+    box_annotator = sv.BoxAnnotator()
+    label_annotator = sv.LabelAnnotator()
+    mask_annotator = sv.MaskAnnotator()
+
+    annotated_frame = input_image.copy()
+    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+    # Create transparent mask for first detected object
+    if len(detections) > 0:
+        # Get first mask
+        first_mask = detections.mask[0]
+
+        # Get original RGB image
+        img = input_image.copy()
+        H, W, C = img.shape
+
+        # Create RGBA image
+        alpha = np.zeros((H, W, 1), dtype=np.uint8)
+        alpha[first_mask] = 255
+        rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+        # Crop to mask bounds to minimize image size
+        y_indices, x_indices = np.where(first_mask)
+        y_min, y_max = y_indices.min(), y_indices.max()
+        x_min, x_max = x_indices.min(), x_indices.max()
+
+        # Crop the RGBA image
+        cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+        # Set extracted foreground for mask mover
+        mask_mover.set_extracted_fg(cropped_rgba)
+
+        return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
+
+    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+
 
 block = gr.Blocks().queue()
 with block:
-    with gr.Tab("Text", visible=False):
+    with gr.Tab("Text"):
         with gr.Row():
             gr.Markdown("## Product Placement from Text")
         with gr.Row():
             with gr.Column():
                 with gr.Row():
                     input_fg = gr.Image(type="numpy", label="Image", height=480)
-                    output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
+                with gr.Row():
+                    with gr.Group():
+                        find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
+                        text_prompt = gr.Textbox(
+                            label="Text Prompt",
+                            placeholder="Enter object classes separated by periods (e.g. 'car . person .')",
+                            value="couch . table ."
+                        )
+                        extract_button = gr.Button(value="(Option 2) Remove Background")
+                with gr.Row():
+                    extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+                    extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+
+                # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
                 with gr.Group():
                     prompt = gr.Textbox(label="Prompt")
                     bg_source = gr.Radio(choices=[e.value for e in BGSource],
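[Note] This hunk publishes a live-looking DDS Cloud API token in source; it should be rotated and moved out of the file. A minimal sketch of the usual fix (DDS_API_TOKEN is an illustrative environment variable name, not something this commit defines):

import os

API_TOKEN = os.environ.get("DDS_API_TOKEN")  # hypothetical env var
if not API_TOKEN:
    raise RuntimeError("Set DDS_API_TOKEN before launching the demo")

Also worth flagging: `process_image` calls `mask_mover.set_extracted_fg(...)`, but `mask_mover` is only created later inside the Background tab, so the Text-tab button works only because Python resolves the global at call time, after the UI has been built.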
@@ -811,14 +958,27 @@ with block:
             # run_on_click=True, examples_per_page=1024
             # )
         ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
-        relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
+        relight_button.click(fn=process_relight, inputs=ips, outputs=[extracted_fg, result_gallery])
         example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
         example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
-
-    with gr.Tab("Background", visible=True):
-        mask_mover = MaskMover()
+        find_objects_button.click(
+            fn=process_image,
+            inputs=[input_fg, text_prompt],
+            outputs=[extracted_objects, extracted_fg]
+        )
 
+    with gr.Tab("Background", visible=False):
+        # empty cache
+
+        mask_mover = MaskMover()
 
+        # with torch.no_grad():
+        #     # Update the input channels to 12
+        #     new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)  # Changed from 8 to 12
+        #     new_conv_in.weight.zero_()
+        #     new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        #     new_conv_in.bias = unet.conv_in.bias
+        #     unet.conv_in = new_conv_in
         with gr.Row():
             gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
             gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
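[Note] The `.click(fn=..., inputs=..., outputs=...)` wiring used throughout this block maps one Python callable onto Gradio components. A minimal, self-contained sketch of the same pattern:

import gradio as gr

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Echo")
    btn = gr.Button("Run")
    # fn receives the current value of each input component and returns
    # one value per output component, in order.
    btn.click(fn=lambda s: s.upper(), inputs=inp, outputs=out)

# demo.launch()  # uncomment to serve locally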
@@ -937,11 +1097,11 @@ with block:
             outputs=[extracted_fg, x_slider, y_slider]
         )
 
-        # find_objects_button.click(
-        #     fn=find_objects,
-        #     inputs=[input_image],
-        #     outputs=[extracted_fg]
-        # )
+        find_objects_button.click(
+            fn=process_image,
+            inputs=[input_image, text_prompt],
+            outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
+        )
 
         get_depth_button.click(
             fn=get_depth,
@@ -1101,5 +1261,5 @@ with block:
     )
 
 
-
 block.launch(server_name='0.0.0.0', share=False)
+