add demo app code
Browse files- README.md +9 -5
- app.py +288 -0
- system_template.txt +32 -0
- user_template.txt +2 -0
README.md
CHANGED
@@ -1,13 +1,17 @@
|
|
1 |
---
|
2 |
title: OCTO
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: OCTO
|
3 |
+
emoji: 🐙
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.14.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
# OCTO+: A Suite for Automatic Open-Vocabulary Object Placement in Mixed Reality
|
14 |
+
|
15 |
+
[Open in Colab](https://colab.research.google.com/github/octo-pearl/octo-pearl/blob/main/demo.ipynb) [Project Page](https://octo-pearl.github.io/) [Paper](https://octo-pearl.github.io/)
|
16 |
+
|
17 |
+
This repo contains the code and data for the paper "[OCTO+: A Suite for Automatic Open-Vocabulary Object Placement in Mixed Reality](https://octo-pearl.github.io/)".
|
app.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Bootstrap step: the Space installs its own dependencies and fetches the
# model checkpoints at start-up, before the heavy third-party imports below.

# Checkpoints are downloaded into ./weights, so make sure it exists first.
os.makedirs("weights", exist_ok=True)

# Shell commands executed in order; exit codes are not inspected.
_SETUP_COMMANDS = (
    # Tooling and the sample image shown by default in the UI.
    "python -m pip install --upgrade pip",
    "wget https://raw.githubusercontent.com/asharma381/cs291I/main/backend/original_images/000749.png",
    # Model checkpoints (SAM, RAM++, GroundingDINO).
    "wget -q -O weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "wget -q -O weights/ram_plus_swin_large_14m.pth https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth",
    "wget -q -O weights/groundingdino_swint_ogc.pth https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
    # Python packages imported further down in this file.
    "pip install git+https://github.com/xinyu1205/recognize-anything.git",
    "pip install git+https://github.com/IDEA-Research/GroundingDINO.git",
    "pip install git+https://github.com/facebookresearch/segment-anything.git",
    "pip install openai==0.27.4",
    "pip install tenacity",
)

for _setup_command in _SETUP_COMMANDS:
    os.system(_setup_command)
|
24 |
+
|
25 |
+
|
26 |
+
from typing import List, Tuple

import cv2
import gradio as gr
import groundingdino.config.GroundingDINO_SwinT_OGC
import numpy as np
import openai
import torch
from groundingdino.util.inference import Model
from PIL import Image, ImageDraw
from ram import get_transform
from ram import inference_ram as inference
from ram.models import ram_plus
from scipy.spatial.distance import cdist
from segment_anything import SamPredictor, sam_model_registry
from supervision import Detections
from tenacity import retry, retry_if_not_exception_type, stop_after_attempt, wait_fixed
|
43 |
+
|
44 |
+
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Model handles are created lazily by the functions below and cached here.
ram_model = gdino_model = sam_model = sam_predictor = None

# Multiplier currently baked into ram_model.class_threshold (1 = untouched).
ram_threshold_multiplier = 1

print("CUDA Available:", torch.cuda.is_available())
|
52 |
+
|
53 |
+
|
54 |
+
def get_tags_ram(
    image: Image.Image, threshold_multiplier=0.8, weights_folder="weights"
) -> List[str]:
    """Tag *image* with the RAM++ model and return the tag names.

    The model is loaded from ``weights_folder`` on the first call and cached
    in the module-level ``ram_model``.

    Args:
        image: Picture to tag.
        threshold_multiplier: Factor applied to the model's per-class
            thresholds.
        weights_folder: Directory containing ``ram_plus_swin_large_14m.pth``.

    Returns:
        The ``|``-separated tags reported by the model, split and stripped.
    """
    global ram_model, ram_threshold_multiplier

    if ram_model is None:
        print("Loading RAM++ Model...")
        checkpoint = f"{weights_folder}/ram_plus_swin_large_14m.pth"
        ram_model = ram_plus(pretrained=checkpoint, vit="swin_l", image_size=384)
        ram_model.eval()
        ram_model = ram_model.to(device)

    # The cached model already carries the previous multiplier, so divide it
    # out while applying the new one; repeated calls then do not compound.
    rescale = threshold_multiplier / ram_threshold_multiplier
    ram_model.class_threshold *= rescale
    ram_threshold_multiplier = threshold_multiplier

    preprocess = get_transform()
    batch = preprocess(image).unsqueeze(0).to(device)
    raw_tags = inference(batch, ram_model)[0]
    return [tag.strip() for tag in raw_tags.split("|")]
|
75 |
+
|
76 |
+
|
77 |
+
def get_gdino_result(
    image: Image.Image,
    classes: List[str],
    box_threshold: float = 0.25,
    weights_folder="weights",
) -> Tuple[Detections, List[str]]:
    """Run GroundingDINO on *image* for the given class names.

    The model is built on first use from ``weights_folder`` and cached in the
    module-level ``gdino_model``.

    Args:
        image: Picture to search.
        classes: Phrases to look for; they are joined with ``", "`` into a
            single caption.
        box_threshold: Minimum box confidence passed to the model.
        weights_folder: Directory containing ``groundingdino_swint_ogc.pth``.

    Returns:
        ``(detections, phrases)`` exactly as returned by
        ``Model.predict_with_caption``.
    """
    global gdino_model

    if gdino_model is None:
        print("Loading GroundingDINO Model...")
        config_path = groundingdino.config.GroundingDINO_SwinT_OGC.__file__
        gdino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=f"{weights_folder}/groundingdino_swint_ogc.pth",
            device=device,
        )

    # Model.predict_with_caption follows the OpenCV convention: it expects a
    # BGR array and converts BGR -> RGB itself. Handing it PIL's RGB data
    # unchanged would swap the red and blue channels before inference.
    image_bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

    detections, phrases = gdino_model.predict_with_caption(
        image=image_bgr,
        caption=", ".join(classes),
        box_threshold=box_threshold,
        text_threshold=0.25,
    )

    return detections, phrases
|
102 |
+
|
103 |
+
|
104 |
+
def get_sam_model(weights_folder="weights"):
    """Return the shared SAM ViT-H model, building it on the first call.

    Args:
        weights_folder: Directory containing ``sam_vit_h_4b8939.pth``; only
            consulted when the model has not been built yet.

    Returns:
        The module-level ``sam_model`` instance, moved to ``device``.
    """
    global sam_model

    if sam_model is not None:
        return sam_model

    build_vit_h = sam_model_registry["vit_h"]
    sam_model = build_vit_h(checkpoint=f"{weights_folder}/sam_vit_h_4b8939.pth")
    sam_model.to(device=device)
    return sam_model
|
111 |
+
|
112 |
+
|
113 |
+
def filter_tags_gdino(image: Image.Image, tags: List[str]) -> List[str]:
    """Keep only the tags that GroundingDINO can localise in *image*.

    A tag survives when at least one detected phrase contains it (substring
    match) and that detection's area is below 90% of the image area.

    Args:
        image: Picture the tags were generated for.
        tags: Candidate tag names.

    Returns:
        The surviving tags, in their original order.
    """
    detections, phrases = get_gdino_result(image, tags)

    def is_grounded(tag: str) -> bool:
        # Boxes covering (almost) the whole frame do not localise anything.
        return any(
            area < 0.9 * image.size[0] * image.size[1] and tag in phrase
            for phrase, area in zip(phrases, detections.area)
        )

    return [tag for tag in tags if is_grounded(tag)]
|
125 |
+
|
126 |
+
|
127 |
+
def read_file_to_string(file_path: str) -> str:
    """Read a UTF-8 text file and return its contents.

    Failures are reported on stdout rather than raised.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file's text, or an empty string if it could not be read.
    """
    try:
        with open(file_path, "r", encoding="utf8") as handle:
            return handle.read()
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except Exception as err:
        print(f"An error occurred while reading {file_path}: {err}")

    # Only reached when one of the handlers above ran.
    return ""
|
139 |
+
|
140 |
+
|
141 |
+
# Errors that another attempt cannot fix: a bad/missing API key or a request
# the API rejects as malformed. These are surfaced immediately.
_NON_RETRYABLE_OPENAI_ERRORS = (
    openai.error.AuthenticationError,
    openai.error.InvalidRequestError,
)


@retry(
    wait=wait_fixed(2),
    stop=stop_after_attempt(5),
    retry=retry_if_not_exception_type(_NON_RETRYABLE_OPENAI_ERRORS),
    reraise=True,
)
def completion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create`` with a bounded retry policy.

    Failed calls are retried every 2 seconds, for at most 5 attempts in
    total. Without a stop condition a permanently failing call (for example
    an invalid API key) would be retried forever and hang the request.

    Args:
        **kwargs: Forwarded unchanged to ``openai.ChatCompletion.create``.

    Returns:
        The response object returned by the OpenAI client.

    Raises:
        Exception: The last error raised by the OpenAI client, either
            immediately for non-retryable errors or once the attempts are
            exhausted.
    """
    return openai.ChatCompletion.create(**kwargs)
|
144 |
+
|
145 |
+
|
146 |
+
def gpt4(
    usr_prompt: str, sys_prompt: str = "", api_key: str = "", model: str = "gpt-4"
) -> str:
    """Send one system + user exchange to an OpenAI chat model.

    Args:
        usr_prompt: Content of the user message.
        sys_prompt: Content of the system message.
        api_key: OpenAI key; stored on the global ``openai.api_key``.
        model: Name of the chat model to query.

    Returns:
        The text content of the first returned choice.
    """
    openai.api_key = api_key

    reply = completion_with_backoff(
        model=model,
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": usr_prompt},
        ],
        temperature=0.2,
        max_tokens=1000,
        frequency_penalty=0.0,
    )

    first_choice = reply["choices"][0]
    return first_choice["message"]["content"]
|
165 |
+
|
166 |
+
|
167 |
+
def select_best_tag(
    filtered_tags: List[str], object_to_place: str, api_key: str = ""
) -> str:
    """Ask the chat model which tag is the best place for the object.

    The user prompt is ``user_template.txt`` with ``{object}`` filled in,
    followed by the candidate tags one per line; the system prompt is the
    content of ``system_template.txt``.

    Args:
        filtered_tags: Candidate locations to choose from.
        object_to_place: Name of the object to be placed.
        api_key: OpenAI API key forwarded to ``gpt4``.

    Returns:
        The model's reply text.
    """
    question = read_file_to_string("user_template.txt").format(object=object_to_place)
    options = "\n".join(filtered_tags)
    instructions = read_file_to_string("system_template.txt")
    return gpt4(question + options, instructions, api_key=api_key)
|
174 |
+
|
175 |
+
|
176 |
+
def get_location_gsam(
    image: Image.Image, prompt: str, weights_folder="weights"
) -> Tuple[int, int]:
    """Find the point inside the region matching *prompt* farthest from its edge.

    GroundingDINO proposes boxes for ``prompt`` (the box threshold is lowered
    in steps of 0.02 until at least one box is returned), SAM turns every box
    into a mask, and the union of those masks is searched, at 1/3 resolution,
    for the pixel with the largest distance to the mask boundary.

    Args:
        image: Picture to search.
        prompt: Phrase describing the target region.
        weights_folder: Directory holding the GroundingDINO and SAM
            checkpoints.

    Returns:
        ``(x, y)`` pixel coordinates in the original image. If the mask is
        empty after downsampling, the centre of the first detected box is
        returned instead.
    """
    global sam_predictor

    RESIZE_RATIO = 3  # the distance search runs on a 3x downsampled mask
    box_threshold = 0.25

    detections, _ = get_gdino_result(
        image=image,
        classes=[prompt],
        box_threshold=box_threshold,
        weights_folder=weights_folder,
    )

    # Relax the confidence threshold until GroundingDINO reports a box.
    while len(detections.xyxy) == 0:
        box_threshold -= 0.02
        detections, _ = get_gdino_result(
            image=image,
            classes=[prompt],
            box_threshold=box_threshold,
            weights_folder=weights_folder,
        )

    sam_model = get_sam_model(weights_folder)

    if sam_predictor is None:
        print("Loading SAM Model...")
        sam_predictor = SamPredictor(sam_model)

    sam_predictor.set_image(np.array(image))

    # For every box keep the SAM proposal with the highest score.
    best_masks = []
    for box in detections.xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        best_masks.append(masks[np.argmax(scores)])

    # Union of all per-box masks.
    combined_mask = np.any(np.array(best_masks), axis=0)

    small_mask = cv2.resize(
        combined_mask.astype("uint8"),
        (
            combined_mask.shape[1] // RESIZE_RATIO,
            combined_mask.shape[0] // RESIZE_RATIO,
        ),
    )

    # Padding gives masks that touch the image border an outside boundary.
    # The 3x3 windows over the 2-padded mask line up with the 1-padded mask.
    mask_2_pad = np.pad(small_mask, pad_width=2, mode="constant", constant_values=0)
    mask_1_pad = np.pad(small_mask, pad_width=1, mode="constant", constant_values=0)

    windows = np.lib.stride_tricks.sliding_window_view(mask_2_pad, (3, 3))
    neighbourhood_empty = (windows == 0).all(axis=(2, 3))

    # 2 = background away from the mask, 0 = background touching the mask
    # (the boundary ring), 1 = mask.
    labelled = np.where(neighbourhood_empty, 2, mask_1_pad)
    boundary_coords = np.argwhere(labelled == 0)
    region_coords = np.argwhere(labelled == 1)

    if len(region_coords) == 0:
        # The mask vanished (e.g. thinner than RESIZE_RATIO pixels); argmax
        # over an empty array would raise, so fall back to the box centre.
        x_min, y_min, x_max, y_max = detections.xyxy[0]
        return int((x_min + x_max) / 2), int((y_min + y_max) / 2)

    distances = cdist(region_coords, boundary_coords, "euclidean")
    farthest = np.argmax(np.min(distances, axis=1))
    y, x = region_coords[farthest]

    # Indices refer to the 1-pixel-padded mask: remove the pad, scale back to
    # full resolution and aim at the centre of the downsampled cell.
    half_cell = RESIZE_RATIO // 2
    return (
        (int(x) - 1) * RESIZE_RATIO + half_cell,
        (int(y) - 1) * RESIZE_RATIO + half_cell,
    )
|
238 |
+
|
239 |
+
|
240 |
+
def run_octo_pipeline(input_image, object, api_key):
    """Run the three OCTO+ stages and mark the chosen location on the image.

    Stage 1 tags the image with RAM++ and filters the tags with
    GroundingDINO, stage 2 lets the chat model pick one tag for ``object``,
    and stage 3 converts that tag into a pixel location.

    Args:
        input_image: Image supplied by the UI; it is converted to RGB.
        object: Name of the object to place.
        api_key: OpenAI API key used in stage 2.

    Returns:
        A one-element list holding the RGB image with a red dot of radius 10
        drawn at the selected location.
    """
    print("Inside run_octo_pipeline with input_image=", input_image, "object=", object)

    print("Loading Image...")
    image = input_image.convert("RGB")

    print("Stage 1...")
    tags = get_tags_ram(image, threshold_multiplier=0.8)
    print("RAM++ Tags", tags)
    filtered_tags = filter_tags_gdino(image, tags)
    print("Filtered Tags", filtered_tags)

    print("Stage 2...")
    selected_tag = select_best_tag(filtered_tags, object, api_key=api_key)
    print("GPT-4 Selected Tag", selected_tag)

    print("Stage 3...")
    x, y = get_location_gsam(image, selected_tag)
    print("G-SAM Location", f"({x},{y})")

    # Mark the location with a filled red circle.
    radius = 10
    marker_box = (x - radius, y - radius, x + radius, y + radius)
    ImageDraw.Draw(image).ellipse(marker_box, fill="red")
    return [image]
|
265 |
+
|
266 |
+
|
267 |
+
# Gradio front end: inputs in the left column, result gallery on the right.
block = gr.Blocks()

with block:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", value="000749.png")
            # NOTE: this name shadows the builtin `object`; it is kept
            # unchanged because it is a module-level name of this script.
            object = gr.Textbox(label="Object", placeholder="Enter an object")
            # The key is a secret, so mask it instead of echoing it on screen.
            api_key = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Enter OpenAI API Key",
                type="password",
            )

        with gr.Column():
            gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                preview=True,
                object_fit="scale-down",
            )

# NOTE(review): the original indentation was lost; the Interface is assumed
# to be created at module level rather than inside `with block:` - confirm.
iface = gr.Interface(
    fn=run_octo_pipeline, inputs=[input_image, object, api_key], outputs=gallery
)
iface.launch()
|
system_template.txt
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are an expert in determining where objects should be placed in a scene. You will be given a list of objects in a scene, and the name of a new object to be placed in the scene. Your task is to select the most natural location for the new object to be placed out of the options provided. Write one of your answers and write it exactly character for character as it appears in the list of possible answers. Provide a one-word response. Here are some examples.
|
2 |
+
Question: Where would be the most natural location for a banana to be placed?
|
3 |
+
Possible Answers:
|
4 |
+
floor
|
5 |
+
table
|
6 |
+
computer
|
7 |
+
sink
|
8 |
+
couch
|
9 |
+
|
10 |
+
table
|
11 |
+
|
12 |
+
Question: Where would be the most natural location for a marker to be placed?
|
13 |
+
Possible Answers:
|
14 |
+
bed
|
15 |
+
counter
|
16 |
+
computer
|
17 |
+
sink
|
18 |
+
desk
|
19 |
+
couch
|
20 |
+
|
21 |
+
desk
|
22 |
+
|
23 |
+
Question: Where would be the most natural location for a suitcase to be placed?
|
24 |
+
Possible Answers:
|
25 |
+
desk
|
26 |
+
floor
|
27 |
+
table
|
28 |
+
sink
|
29 |
+
computer
|
30 |
+
couch
|
31 |
+
|
32 |
+
floor
|
user_template.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Question: Where would be the most natural location for a {object} to be placed?
|
2 |
+
Possible Answers:
|