Spaces: Runtime error
Update gradio_demo.py
gradio_demo.py CHANGED (+192 -32)
@@ -1,4 +1,3 @@
-import spaces
 import os
 import math
 import gradio as gr
@@ -16,20 +15,32 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerA
 from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from briarmbg import BriaRMBG
+import dds_cloudapi_sdk
+from dds_cloudapi_sdk import Config, Client, TextPrompt
+from dds_cloudapi_sdk.tasks.dinox import DinoxTask
 from enum import Enum
 from torch.hub import download_url_to_file
+import tempfile
 
-from
+from sam2.build_sam import build_sam2
+
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 import cv2
 
 from typing import Optional
 
 from Depth.depth_anything_v2.dpt import DepthAnythingV2
 
+import httpx
+
+client = httpx.Client(timeout=httpx.Timeout(10.0))  # Set timeout to 10 seconds
+
+
 
 
-# from FLORENCE
 
+# from FLORENCE
+import spaces
 import supervision as sv
 import torch
 from PIL import Image
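The new module-level httpx client caps every request at 10 seconds. A minimal sketch of the same pattern, assuming only the public httpx API (the URL is purely illustrative), in case finer-grained limits are wanted:

import httpx

# A single float (as in the diff) applies to all phases; Timeout also accepts
# separate connect/read/write/pool budgets.
timeout = httpx.Timeout(10.0, connect=5.0)
with httpx.Client(timeout=timeout) as http_client:
    response = http_client.get("https://example.com/health")  # illustrative endpoint only
    response.raise_for_status()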
@@ -74,7 +85,7 @@ model.eval()
 # Change UNet
 
 with torch.no_grad():
-    new_conv_in = torch.nn.Conv2d(
+    new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
     new_conv_in.weight.zero_()
     new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
     new_conv_in.bias = unet.conv_in.bias
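The hunk above swaps the UNet's 4-channel conv_in for an 8-channel one so concatenated latents can be fed in, while keeping the pretrained weights. A self-contained sketch of that widening pattern, using a stand-in Conv2d rather than the real unet.conv_in:

import torch

old_conv = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)  # stand-in for unet.conv_in
with torch.no_grad():
    new_conv = torch.nn.Conv2d(8, old_conv.out_channels, old_conv.kernel_size,
                               old_conv.stride, old_conv.padding)
    new_conv.weight.zero_()                              # extra channels start as a no-op
    new_conv.weight[:, :4, :, :].copy_(old_conv.weight)  # keep the pretrained kernels
    new_conv.bias.copy_(old_conv.bias)
# old_conv would then be replaced by new_conv (unet.conv_in = new_conv_in in the diff).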
@@ -95,15 +106,15 @@ def enable_efficient_attention():
             print(f"Xformers error: {e}")
             print("Falling back to sliced attention")
             # Use sliced attention for RTX 2070
-
-
+            unet.set_attention_slice_size(4)
+            vae.set_attention_slice_size(4)
             unet.set_attn_processor(AttnProcessor2_0())
             vae.set_attn_processor(AttnProcessor2_0())
     else:
         # Fallback for when xformers is not available
         print("Using sliced attention")
-
-
+        unet.set_attention_slice_size(4)
+        vae.set_attention_slice_size(4)
         unet.set_attn_processor(AttnProcessor2_0())
         vae.set_attn_processor(AttnProcessor2_0())
 
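set_attention_slice_size does not appear to be a public method on diffusers' UNet2DConditionModel or AutoencoderKL; the stock APIs are set_attention_slice on the UNet and enable_slicing on the VAE (or enable_attention_slicing at the pipeline level). A hedged alternative, written as a helper so it is self-contained:

from diffusers import AutoencoderKL, UNet2DConditionModel

def enable_sliced_attention(unet: UNet2DConditionModel, vae: AutoencoderKL) -> None:
    # Both calls exist in current diffusers releases.
    unet.set_attention_slice("auto")  # or an integer slice size such as 4
    vae.enable_slicing()              # decode latents one sample at a time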
@@ -129,12 +140,12 @@ unet.forward = hooked_unet_forward
 
 # Load
 
-
-model_path = './models/iclight_sd15_fbc.safetensors'
+model_path = './models/iclight_sd15_fc.safetensors'
+#model_path = './models/iclight_sd15_fbc.safetensors'
 
 
-if not os.path.exists(model_path):
-
+# if not os.path.exists(model_path):
+#     download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
 
 sd_offset = sf.load_file(model_path)
 sd_origin = unet.state_dict()
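The lines that actually consume sd_offset and sd_origin fall outside this hunk; in IC-Light the offset weights are merged into the pretrained UNet state dict by element-wise addition before loading. A sketch of that merge, assuming the variable names used above:

# sd_offset: per-parameter deltas loaded from iclight_sd15_fc.safetensors
# sd_origin: the pretrained SD 1.5 UNet weights
sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
unet.load_state_dict(sd_merged, strict=True)
del sd_offset, sd_origin, sd_merged  # free the temporary copies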
@@ -223,7 +234,7 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
     image_encoder=None
 )
 
-
+
 @torch.inference_mode()
 def encode_prompt_inner(txt: str):
     max_length = tokenizer.model_max_length
@@ -244,7 +255,7 @@ def encode_prompt_inner(txt: str):
 
     return conds
 
-
+
 @torch.inference_mode()
 def encode_prompt_pair(positive_prompt, negative_prompt):
     c = encode_prompt_inner(positive_prompt)
@@ -265,7 +276,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
 
     return c, uc
 
-
+
 @torch.inference_mode()
 def pytorch2numpy(imgs, quant=True):
     results = []
@@ -282,7 +293,7 @@ def pytorch2numpy(imgs, quant=True):
         results.append(y)
     return results
 
-
+
 @torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 must be strictly 0.0
@@ -310,7 +321,7 @@ def resize_without_crop(image, target_width, target_height):
     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
     return np.array(resized_image)
 
-
+
 @torch.inference_mode()
 def run_rmbg(img, sigma=0.0):
     # Convert RGBA to RGB if needed
@@ -454,7 +465,6 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
 
     return pixels
 
-@spaces.GPU(duration=60)
 @torch.inference_mode()
 def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
     clear_memory()
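The removed decorator comes from the Hugging Face spaces package (ZeroGPU): wrapping a function in spaces.GPU requests a GPU for at most `duration` seconds per call, which is why the demo still imports spaces further down in the import block. A minimal, hedged sketch of the pattern with a placeholder body:

import spaces
import torch

@spaces.GPU(duration=60)                 # allocate a ZeroGPU device for up to 60 s per call
def gpu_worker(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0                       # placeholder for the real GPU-bound work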
@@ -548,7 +558,7 @@ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_sample
     clear_memory()
     return pixels, [fg, bg]
 
-
+
 @torch.inference_mode()
 def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
     input_fg, matting = run_rmbg(input_fg)
@@ -556,7 +566,7 @@ def process_relight(input_fg, prompt, image_width, image_height, num_samples, se
     return input_fg, results
 
 
-
+
 @torch.inference_mode()
 def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
     bg_source = BGSource(bg_source)
@@ -760,17 +770,154 @@ def compress_image(image):
     compressed_img = np.array(Image.open("compressed_image.jpg"))
     return compressed_img
 
+@spaces.GPU(duration=60)
+@torch.inference_mode()
+def process_image(input_image, input_text):
+    """Main processing function for the Gradio interface"""
+
+    # Initialize configs
+    API_TOKEN = "9c8c865e10ec1821bea79d9fa9dc8720"
+    SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
+    SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    # Initialize DDS client
+    config = Config(API_TOKEN)
+    client = Client(config)
+
+    # Process classes from text prompt
+    classes = [x.strip().lower() for x in input_text.split('.') if x]
+    class_name_to_id = {name: id for id, name in enumerate(classes)}
+    class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+    # Save input image to temp file and get URL
+    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
+        cv2.imwrite(tmpfile.name, input_image)
+        image_url = client.upload_file(tmpfile.name)
+        os.remove(tmpfile.name)
+
+    # Run DINO-X detection
+    task = DinoxTask(
+        image_url=image_url,
+        prompts=[TextPrompt(text=input_text)]
+    )
+    client.run_task(task)
+    result = task.result
+    objects = result.objects
+
+    # Process detection results
+    input_boxes = []
+    confidences = []
+    class_names = []
+    class_ids = []
+
+    for obj in objects:
+        input_boxes.append(obj.bbox)
+        confidences.append(obj.score)
+        cls_name = obj.category.lower().strip()
+        class_names.append(cls_name)
+        class_ids.append(class_name_to_id[cls_name])
+
+    input_boxes = np.array(input_boxes)
+    class_ids = np.array(class_ids)
+
+    # Initialize SAM2
+    torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+    if torch.cuda.get_device_properties(0).major >= 8:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+    sam2_predictor = SAM2ImagePredictor(sam2_model)
+    sam2_predictor.set_image(input_image)
+
+    # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
+
+
+    # Get masks from SAM2
+    masks, scores, logits = sam2_predictor.predict(
+        point_coords=None,
+        point_labels=None,
+        box=input_boxes,
+        multimask_output=False,
+    )
+    if masks.ndim == 4:
+        masks = masks.squeeze(1)
+
+    # Create visualization
+    labels = [f"{class_name} {confidence:.2f}"
+              for class_name, confidence in zip(class_names, confidences)]
+
+    detections = sv.Detections(
+        xyxy=input_boxes,
+        mask=masks.astype(bool),
+        class_id=class_ids
+    )
+
+    box_annotator = sv.BoxAnnotator()
+    label_annotator = sv.LabelAnnotator()
+    mask_annotator = sv.MaskAnnotator()
+
+    annotated_frame = input_image.copy()
+    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+    # Create transparent mask for first detected object
+    if len(detections) > 0:
+        # Get first mask
+        first_mask = detections.mask[0]
+
+        # Get original RGB image
+        img = input_image.copy()
+        H, W, C = img.shape
+
+        # Create RGBA image
+        alpha = np.zeros((H, W, 1), dtype=np.uint8)
+        alpha[first_mask] = 255
+        rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+        # Crop to mask bounds to minimize image size
+        y_indices, x_indices = np.where(first_mask)
+        y_min, y_max = y_indices.min(), y_indices.max()
+        x_min, x_max = x_indices.min(), x_indices.max()
+
+        # Crop the RGBA image
+        cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+        # Set extracted foreground for mask mover
+        mask_mover.set_extracted_fg(cropped_rgba)
+
+        return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
+
+    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+
 
 block = gr.Blocks().queue()
 with block:
-    with gr.Tab("Text"
+    with gr.Tab("Text"):
         with gr.Row():
             gr.Markdown("## Product Placement from Text")
         with gr.Row():
             with gr.Column():
                 with gr.Row():
                     input_fg = gr.Image(type="numpy", label="Image", height=480)
-
+                with gr.Row():
+                    with gr.Group():
+                        find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
+                        text_prompt = gr.Textbox(
+                            label="Text Prompt",
+                            placeholder="Enter object classes separated by periods (e.g. 'car . person .')",
+                            value="couch . table ."
+                        )
+                        extract_button = gr.Button(value="(Option 2) Remove Background")
+                with gr.Row():
+                    extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+                    extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+
+                # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
                 with gr.Group():
                     prompt = gr.Textbox(label="Prompt")
                     bg_source = gr.Radio(choices=[e.value for e in BGSource],
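process_image ends by turning the first SAM2 mask into a transparent cutout. A self-contained sketch of just that step, on synthetic data with no DINO-X or SAM2 calls, to show what the alpha-and-crop logic does:

import numpy as np

def mask_to_cropped_rgba(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """img: (H, W, 3) uint8, mask: (H, W) bool -> cropped (h, w, 4) uint8 cutout."""
    alpha = np.zeros(mask.shape + (1,), dtype=np.uint8)
    alpha[mask] = 255                                  # opaque only where the mask is set
    rgba = np.dstack((img, alpha)).astype(np.uint8)    # append alpha as a 4th channel
    ys, xs = np.where(mask)
    return rgba[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

# Example with synthetic data:
demo_img = np.zeros((64, 64, 3), dtype=np.uint8)
demo_mask = np.zeros((64, 64), dtype=bool)
demo_mask[10:30, 20:40] = True
cutout = mask_to_cropped_rgba(demo_img, demo_mask)     # shape (20, 20, 4)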
@@ -811,14 +958,27 @@ with block:
         # run_on_click=True, examples_per_page=1024
         # )
         ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
-        relight_button.click(fn=process_relight, inputs=ips, outputs=[
+        relight_button.click(fn=process_relight, inputs=ips, outputs=[extracted_fg, result_gallery])
         example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
         example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
-
-
-
+        find_objects_button.click(
+            fn=process_image,
+            inputs=[input_fg, text_prompt],
+            outputs=[extracted_objects, extracted_fg]
+        )
 
+    with gr.Tab("Background", visible=False):
+        # empty cache
+
+        mask_mover = MaskMover()
 
+        # with torch.no_grad():
+        #     # Update the input channels to 12
+        #     new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)  # Changed from 8 to 12
+        #     new_conv_in.weight.zero_()
+        #     new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        #     new_conv_in.bias = unet.conv_in.bias
+        #     unet.conv_in = new_conv_in
         with gr.Row():
             gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
             gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
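The wiring added above follows the standard Blocks pattern: .click binds component values to the handler's arguments and maps its return values back onto components. A minimal, hedged stand-alone sketch (the component names and the segment handler are illustrative, not the demo's):

import gradio as gr
import numpy as np

def segment(image: np.ndarray, text: str):
    return image, image  # placeholder handler standing in for process_image

with gr.Blocks() as demo:
    inp = gr.Image(type="numpy", label="Image")
    prompt = gr.Textbox(label="Text Prompt", value="couch . table .")
    btn = gr.Button("Segment Object from text")
    annotated = gr.Image(type="numpy", label="Annotated")
    cutout = gr.Image(type="numpy", label="Extracted Foreground")
    btn.click(fn=segment, inputs=[inp, prompt], outputs=[annotated, cutout])

# demo.launch()  # uncomment to run locally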
@@ -937,11 +1097,11 @@ with block:
             outputs=[extracted_fg, x_slider, y_slider]
         )
 
-
-
-
-
-
+        find_objects_button.click(
+            fn=process_image,
+            inputs=[input_image, text_prompt],
+            outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
+        )
 
         get_depth_button.click(
             fn=get_depth,
@@ -1101,5 +1261,5 @@ with block:
 )
 
 
-
 block.launch(server_name='0.0.0.0', share=False)
+