salma-remyx commited on
Commit
d11db70
·
2 Parent(s): 5b7dfe2 52d1bbb

Merge branch 'main' of https://huggingface.co/spaces/remyxai/SpaceMantis into main

Browse files
Files changed (2) hide show
  1. .gitattributes +38 -0
  2. local_app.py +469 -0
.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/ filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/depth_pro.pt filter=lfs diff=lfs merge=lfs -text
38
+ extra_deps/flash_attn-2.7.0.post2-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
local_app.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import uuid
4
+ import torch
5
+ import random
6
+ import numpy as np
7
+ from PIL import Image
8
+ import open3d as o3d
9
+ import matplotlib.pyplot as plt
10
+
11
+ from transformers import AutoProcessor, AutoModelForCausalLM
12
+ from transformers import SamModel, SamProcessor
13
+
14
+ import depth_pro
15
+
16
+ import spacy
17
+ import gradio as gr
18
+
19
+ try:
20
+ nlp = spacy.load("en_core_web_sm")
21
+ except OSError:
22
+ # Download the model if it's not already available
23
+ from spacy.cli import download
24
+ download("en_core_web_sm")
25
+ nlp = spacy.load("en_core_web_sm")
26
+
27
+ def find_subject(doc):
28
+ for token in doc:
29
+ # Check if the token is a subject
30
+ if "subj" in token.dep_:
31
+ return token.text, token.head
32
+ return None, None
33
+
34
+ def extract_descriptions(doc, head):
35
+ descriptions = []
36
+ for chunk in doc.noun_chunks:
37
+ # Check if the chunk is directly related to the subject's verb or is an attribute
38
+ if chunk.root.head == head or chunk.root.dep_ == 'attr':
39
+ descriptions.append(chunk.text)
40
+ return descriptions
41
+
42
+ def caption_refiner(caption):
43
+ doc = nlp(caption)
44
+ subject, action_verb = find_subject(doc)
45
+ if action_verb:
46
+ descriptions = extract_descriptions(doc, action_verb)
47
+ return ', '.join(descriptions)
48
+ else:
49
+ return caption
50
+
51
+ def sam2(image, input_boxes, model_id="facebook/sam-vit-base"):
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+ model = SamModel.from_pretrained(model_id).to(device)
54
+ processor = SamProcessor.from_pretrained(model_id)
55
+ inputs = processor(image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
56
+ with torch.no_grad():
57
+ outputs = model(**inputs)
58
+
59
+ masks = processor.image_processor.post_process_masks(
60
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
61
+ )
62
+ return masks
63
+
64
+ def load_florence2(model_id="microsoft/Florence-2-base-ft", device='cuda'):
65
+ torch_dtype = torch.float16 if device == 'cuda' else torch.float32
66
+ florence_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, trust_remote_code=True).to(device)
67
+ florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
68
+ return florence_model, florence_processor
69
+
70
+ def florence2(image, prompt="", task="<OD>"):
71
+ device = florence_model.device
72
+ torch_dtype = florence_model.dtype
73
+ inputs = florence_processor(text=task + prompt, images=image, return_tensors="pt").to(device, torch_dtype)
74
+ generated_ids = florence_model.generate(
75
+ input_ids=inputs["input_ids"],
76
+ pixel_values=inputs["pixel_values"],
77
+ max_new_tokens=1024,
78
+ num_beams=3,
79
+ do_sample=False
80
+ )
81
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
82
+ parsed_answer = florence_processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
83
+ return parsed_answer[task]
84
+
85
+
86
+ def depth_estimation(image_path):
87
+ model.eval()
88
+ image, _, f_px = depth_pro.load_rgb(image_path)
89
+ image = transform(image)
90
+
91
+ # Run inference.
92
+ prediction = model.infer(image, f_px=f_px)
93
+ depth = prediction["depth"] # Depth in [m].
94
+ focallength_px = prediction["focallength_px"] # Focal length in pixels.
95
+ depth = depth.cpu().numpy()
96
+ return depth, focallength_px
97
+
98
+
99
+ def create_point_cloud_from_rgbd(rgb, depth, intrinsic_parameters):
100
+ rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
101
+ o3d.geometry.Image(rgb),
102
+ o3d.geometry.Image(depth),
103
+ depth_scale=10.0,
104
+ depth_trunc=100.0,
105
+ convert_rgb_to_intensity=False
106
+ )
107
+ intrinsic = o3d.camera.PinholeCameraIntrinsic()
108
+ intrinsic.set_intrinsics(intrinsic_parameters['width'], intrinsic_parameters['height'],
109
+ intrinsic_parameters['fx'], intrinsic_parameters['fy'],
110
+ intrinsic_parameters['cx'], intrinsic_parameters['cy'])
111
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic)
112
+ return pcd
113
+
114
+
115
+ def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3):
116
+ # Segment the largest plane, assumed to be the floor
117
+ plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
118
+
119
+ canonicalized = False
120
+ if len(inliers) / len(pcd.points) > canonicalize_threshold:
121
+ canonicalized = True
122
+
123
+ # Ensure the plane normal points upwards
124
+ if np.dot(plane_model[:3], [0, 1, 0]) < 0:
125
+ plane_model = -plane_model
126
+
127
+ # Normalize the plane normal vector
128
+ normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
129
+
130
+ # Compute the new basis vectors
131
+ new_y = normal
132
+ new_x = np.cross(new_y, [0, 0, -1])
133
+ new_x /= np.linalg.norm(new_x)
134
+ new_z = np.cross(new_x, new_y)
135
+
136
+ # Create the transformation matrix
137
+ transformation = np.identity(4)
138
+ transformation[:3, :3] = np.vstack((new_x, new_y, new_z)).T
139
+ transformation[:3, 3] = -np.dot(transformation[:3, :3], pcd.points[inliers[0]])
140
+
141
+
142
+ # Apply the transformation
143
+ pcd.transform(transformation)
144
+
145
+ # Additional 180-degree rotation around the Z-axis
146
+ rotation_z_180 = np.array([[np.cos(np.pi), -np.sin(np.pi), 0],
147
+ [np.sin(np.pi), np.cos(np.pi), 0],
148
+ [0, 0, 1]])
149
+ pcd.rotate(rotation_z_180, center=(0, 0, 0))
150
+
151
+ return pcd, canonicalized, transformation
152
+ else:
153
+ return pcd, canonicalized, None
154
+
155
+
156
+ def compute_iou(box1, box2):
157
+ # Extract the coordinates
158
+ x1_min, y1_min, x1_max, y1_max = box1
159
+ x2_min, y2_min, x2_max, y2_max = box2
160
+
161
+ # Compute the intersection rectangle
162
+ x_inter_min = max(x1_min, x2_min)
163
+ y_inter_min = max(y1_min, y2_min)
164
+ x_inter_max = min(x1_max, x2_max)
165
+ y_inter_max = min(y1_max, y2_max)
166
+
167
+ # Intersection width and height
168
+ inter_width = max(0, x_inter_max - x_inter_min)
169
+ inter_height = max(0, y_inter_max - y_inter_min)
170
+
171
+ # Intersection area
172
+ inter_area = inter_width * inter_height
173
+
174
+ # Boxes areas
175
+ box1_area = (x1_max - x1_min) * (y1_max - y1_min)
176
+ box2_area = (x2_max - x2_min) * (y2_max - y2_min)
177
+
178
+ # Union area
179
+ union_area = box1_area + box2_area - inter_area
180
+
181
+ # Intersection over Union
182
+ iou = inter_area / union_area if union_area != 0 else 0
183
+
184
+ return iou
185
+
186
+
187
+ def human_like_distance(distance_meters, scale_factor=10):
188
+ # Define the choices with units included, focusing on the 0.1 to 10 meters range
189
+ distance_meters *= scale_factor
190
+ if distance_meters < 1: # For distances less than 1 meter
191
+ choices = [
192
+ (
193
+ round(distance_meters * 100, 2),
194
+ "centimeters",
195
+ 0.2,
196
+ ), # Centimeters for very small distances
197
+ (
198
+ round(distance_meters, 2),
199
+ "inches",
200
+ 0.8,
201
+ ), # Inches for the majority of cases under 1 meter
202
+ ]
203
+ elif distance_meters < 3: # For distances less than 3 meters
204
+ choices = [
205
+ (round(distance_meters, 2), "meters", 0.5),
206
+ (
207
+ round(distance_meters, 2),
208
+ "feet",
209
+ 0.5,
210
+ ), # Feet as a common unit within indoor spaces
211
+ ]
212
+ else: # For distances from 3 up to 10 meters
213
+ choices = [
214
+ (
215
+ round(distance_meters, 2),
216
+ "meters",
217
+ 0.7,
218
+ ), # Meters for clarity and international understanding
219
+ (
220
+ round(distance_meters, 2),
221
+ "feet",
222
+ 0.3,
223
+ ), # Feet for additional context
224
+ ]
225
+ # Normalize probabilities and make a selection
226
+ total_probability = sum(prob for _, _, prob in choices)
227
+ cumulative_distribution = []
228
+ cumulative_sum = 0
229
+ for value, unit, probability in choices:
230
+ cumulative_sum += probability / total_probability # Normalize probabilities
231
+ cumulative_distribution.append((cumulative_sum, value, unit))
232
+
233
+ # Randomly choose based on the cumulative distribution
234
+ r = random.random()
235
+ for cumulative_prob, value, unit in cumulative_distribution:
236
+ if r < cumulative_prob:
237
+ return f"{value} {unit}"
238
+
239
+ # Fallback to the last choice if something goes wrong
240
+ return f"{choices[-1][0]} {choices[-1][1]}"
241
+
242
+
243
+ def filter_bboxes(data, iou_threshold=0.5):
244
+ filtered_bboxes = []
245
+ filtered_labels = []
246
+
247
+ for i in range(len(data['bboxes'])):
248
+ current_box = data['bboxes'][i]
249
+ current_label = data['labels'][i]
250
+ is_duplicate = False
251
+
252
+ for j in range(len(filtered_bboxes)):
253
+ if current_label == filtered_labels[j]:# and compute_iou(current_box, filtered_bboxes[j]) > iou_threshold:
254
+ is_duplicate = True
255
+ break
256
+
257
+ if not is_duplicate:
258
+ filtered_bboxes.append(current_box)
259
+ filtered_labels.append(current_label)
260
+
261
+ return {'bboxes': filtered_bboxes, 'labels': filtered_labels, 'caption': data['caption']}
262
+
263
+ def process_image(image_path: str):
264
+ depth, fx = depth_estimation(image_path)
265
+
266
+ img = Image.open(image_path).convert('RGB')
267
+ width, height = img.size
268
+
269
+ description = florence2(img, task="<MORE_DETAILED_CAPTION>")
270
+ print(description)
271
+
272
+ regions = []
273
+ for cap in description.split('.'):
274
+ if cap:
275
+ roi = florence2(img, prompt=" " + cap, task="<CAPTION_TO_PHRASE_GROUNDING>")
276
+ roi["caption"] = caption_refiner(cap.lower())
277
+ roi = filter_bboxes(roi)
278
+ if len(roi['bboxes']) > 1:
279
+ flip = random.choice(['heads', 'tails'])
280
+ if flip == 'heads':
281
+ idx = random.randint(1, len(roi['bboxes']) - 1)
282
+ else:
283
+ idx = 0
284
+ if idx > 0: # test bbox IOU
285
+ roi['caption'] = roi['labels'][idx].lower() + ' with ' + roi['labels'][0].lower()
286
+ roi['bboxes'] = [roi['bboxes'][idx]]
287
+ roi['labels'] = [roi['labels'][idx]]
288
+
289
+ if roi['bboxes']:
290
+ regions.append(roi)
291
+ print(roi)
292
+
293
+ bboxes = [item['bboxes'][0] for item in regions]
294
+ n = len(bboxes)
295
+ distance_matrix = np.zeros((n, n))
296
+ for i in range(n):
297
+ for j in range(n):
298
+ if i != j:
299
+ distance_matrix[i][j] = 1 - compute_iou(bboxes[i], bboxes[j])
300
+
301
+ scores = np.sum(distance_matrix, axis=1)
302
+ selected_indices = np.argsort(scores)[-3:]
303
+ regions = [(regions[i]['bboxes'][0], regions[i]['caption']) for i in selected_indices][:2]
304
+
305
+ # Create point cloud
306
+ camera_intrinsics = intrinsic_parameters = {
307
+ 'width': width,
308
+ 'height': height,
309
+ 'fx': fx,
310
+ 'fy': fx * height / width,
311
+ 'cx': width / 2,
312
+ 'cy': height / 2,
313
+ }
314
+
315
+ pcd = create_point_cloud_from_rgbd(np.array(img).copy(), depth, camera_intrinsics)
316
+ normed_pcd, canonicalized, transformation = canonicalize_point_cloud(pcd)
317
+
318
+
319
+ masks = []
320
+ for box, cap in regions:
321
+ masks.append((cap, sam2(img, box)))
322
+
323
+
324
+ point_clouds = []
325
+ for cap, mask in masks:
326
+ m = mask[0].numpy()[0].squeeze().transpose((1, 2, 0))
327
+ mask = np.any(m, axis=2)
328
+
329
+ try:
330
+ points = np.asarray(normed_pcd.points)
331
+ colors = np.asarray(normed_pcd.colors)
332
+ masked_points = points[mask.ravel()]
333
+ masked_colors = colors[mask.ravel()]
334
+
335
+ masked_point_cloud = o3d.geometry.PointCloud()
336
+ masked_point_cloud.points = o3d.utility.Vector3dVector(masked_points)
337
+ masked_point_cloud.colors = o3d.utility.Vector3dVector(masked_colors)
338
+
339
+ point_clouds.append((cap, masked_point_cloud))
340
+ except:
341
+ pass
342
+
343
+ boxes3D = []
344
+ centers = []
345
+ pcd = o3d.geometry.PointCloud()
346
+ for cap, pc in point_clouds[:2]:
347
+ cl, ind = pc.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
348
+ inlier_cloud = pc.select_by_index(ind)
349
+ pcd += inlier_cloud
350
+ obb = inlier_cloud.get_axis_aligned_bounding_box()
351
+ obb.color = (1, 0, 0)
352
+ centers.append(obb.get_center())
353
+ boxes3D.append(obb)
354
+
355
+
356
+ lines = [[0, 1]]
357
+ points = [centers[0], centers[1]]
358
+ distance = human_like_distance(np.asarray(point_clouds[0][1].compute_point_cloud_distance(point_clouds[-1][1])).mean())
359
+ text_output = "Distance between {} and {} is: {}".format(point_clouds[0][0], point_clouds[-1][0], distance)
360
+ print(text_output)
361
+
362
+ colors = [[1, 0, 0] for i in range(len(lines))] # Red color for lines
363
+ line_set = o3d.geometry.LineSet(
364
+ points=o3d.utility.Vector3dVector(points),
365
+ lines=o3d.utility.Vector2iVector(lines)
366
+ )
367
+ line_set.colors = o3d.utility.Vector3dVector(colors)
368
+
369
+ boxes3D.append(line_set)
370
+
371
+
372
+ uuid_out = str(uuid.uuid4())
373
+ ply_file = f"output_{uuid_out}.ply"
374
+ obj_file = f"output_{uuid_out}.obj"
375
+ o3d.io.write_point_cloud(ply_file, pcd)
376
+
377
+ mesh = o3d.io.read_triangle_mesh(ply_file)
378
+
379
+ o3d.io.write_triangle_mesh(obj_file, mesh)
380
+
381
+ return obj_file, text_output
382
+
383
+
384
+
385
+ def custom_draw_geometry_with_rotation(pcd):
386
+
387
+ def rotate_view(vis):
388
+ ctr = vis.get_view_control()
389
+ vis.get_render_option().background_color = [0, 0, 0]
390
+ ctr.rotate(1.0, 0.0)
391
+ # https://github.com/isl-org/Open3D/issues/1483
392
+ #parameters = o3d.io.read_pinhole_camera_parameters("ScreenCamera_2024-10-24-10-03-57.json")
393
+ #ctr.convert_from_pinhole_camera_parameters(parameters)
394
+ return False
395
+
396
+ o3d.visualization.draw_geometries_with_animation_callback([pcd] + boxes3D,
397
+ rotate_view)
398
+
399
+
400
+ def build_demo():
401
+ with gr.Blocks() as demo:
402
+ # Title and introductory Markdown
403
+ gr.Markdown("""
404
+ # Synthesizing SpatialVQA Samples with VQASynth
405
+ This space helps test the full [VQASynth](https://github.com/remyxai/VQASynth) scene reconstruction pipeline on a single image with visualizations.
406
+
407
+ ### [Github](https://github.com/remyxai/VQASynth) | [Collection](https://huggingface.co/collections/remyxai/spacevlms-66a3dbb924756d98e7aec678)
408
+ """)
409
+
410
+ # Description for users
411
+ gr.Markdown("""
412
+ ## Instructions
413
+ Upload an image, and the tool will generate a corresponding 3D point cloud visualization of the objects found and an example prompt and response describing a spatial relationship between the objects.
414
+ """)
415
+
416
+ with gr.Row():
417
+ # Left Column: Inputs
418
+ with gr.Column():
419
+ # Image upload and processing button in the left column
420
+ image_input = gr.Image(type="filepath", label="Upload an Image")
421
+ generate_button = gr.Button("Generate")
422
+
423
+ # Right Column: Outputs
424
+ with gr.Column():
425
+ # 3D Model and Caption Outputs
426
+ model_output = gr.Model3D(label="3D Point Cloud") # Only used as output
427
+ caption_output = gr.Text(label="Caption")
428
+
429
+ # Link the button to process the image and display the outputs
430
+ generate_button.click(
431
+ process_image, # Your processing function
432
+ inputs=image_input,
433
+ outputs=[model_output, caption_output]
434
+ )
435
+
436
+ # Examples section at the bottom
437
+ gr.Examples(
438
+ examples=[
439
+ ["./examples/warehouse_rgb.jpg"], ["./examples/spooky_doggy.png"], ["./examples/bee_and_flower.jpg"], ["./examples/road-through-dense-forest.jpg"], ["./examples/gears.png"] # Update with the path to your example image
440
+ ],
441
+ inputs=image_input,
442
+ label="Example Images",
443
+ examples_per_page=5
444
+ )
445
+
446
+ # Citations
447
+ gr.Markdown("""
448
+ ## Citation
449
+ ```
450
+ @article{chen2024spatialvlm,
451
+ title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities},
452
+ author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei},
453
+ journal = {arXiv preprint arXiv:2401.12168},
454
+ year = {2024},
455
+ url = {https://arxiv.org/abs/2401.12168},
456
+ }
457
+ ```
458
+ """)
459
+
460
+ return demo
461
+
462
+ if __name__ == "__main__":
463
+ global model, transform, florence_model, florence_processor
464
+ model, transform = depth_pro.create_model_and_transforms(device='cuda')
465
+ florence_model, florence_processor = load_florence2(device='cuda')
466
+
467
+
468
+ demo = build_demo()
469
+ demo.launch(share=True)