ziqima commited on
Commit
5764d89
·
1 Parent(s): 07e191e

fix open3d zerogpu

Browse files
Files changed (3) hide show
  1. app.py +10 -4
  2. inference/inference.py +3 -3
  3. open3d_zerogpu_fix.py +7 -0
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  import re
3
  from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed, install_cuda_toolkit
4
  from inference.utils import get_legend
@@ -43,6 +45,10 @@ source_dict = {
43
  "chair": "wild"
44
  }
45
 
 
 
 
 
46
  def predict(pcd_path, inference_mode, part_queries):
47
  set_seed()
48
  xyz, rgb, normal = read_pcd(pcd_path)
@@ -52,14 +58,14 @@ def predict(pcd_path, inference_mode, part_queries):
52
  raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
53
  seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
54
  legend = get_legend(parts)
55
- return render_point_cloud(xyz, seg_rgb, legend=legend)
56
  elif inference_mode == "Localization":
57
  if "," in part_queries or ";" in part_queries or "." in part_queries:
58
  raise gr.Error("For localization mode, please provide only one part", duration=5)
59
  heatmap_rgb = get_heatmap(xyz, rgb, normal, part_queries).cpu().numpy()
60
- return render_point_cloud(xyz, heatmap_rgb)
61
  else:
62
- return None
63
 
64
  def on_select(evt: gr.SelectData):
65
  obj_name = evt.value['image']['orig_name'][:-4]
@@ -150,7 +156,7 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as d
150
  outputs=[input_point_cloud],
151
  )
152
  run_button.click(
153
- fn=predict,
154
  inputs=[file_upload, inference_mode, part_queries],
155
  outputs=[output_point_cloud],
156
  )
 
1
  import gradio as gr
2
+ import open3d_zerogpu_fix
3
+ import spaces
4
  import re
5
  from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed, install_cuda_toolkit
6
  from inference.utils import get_legend
 
45
  "chair": "wild"
46
  }
47
 
48
+ @spaces.GPU(duration=120)
49
+ def run_predict(*args):
50
+ yield from predict(*args)
51
+
52
  def predict(pcd_path, inference_mode, part_queries):
53
  set_seed()
54
  xyz, rgb, normal = read_pcd(pcd_path)
 
58
  raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
59
  seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
60
  legend = get_legend(parts)
61
+ yield render_point_cloud(xyz, seg_rgb, legend=legend)
62
  elif inference_mode == "Localization":
63
  if "," in part_queries or ";" in part_queries or "." in part_queries:
64
  raise gr.Error("For localization mode, please provide only one part", duration=5)
65
  heatmap_rgb = get_heatmap(xyz, rgb, normal, part_queries).cpu().numpy()
66
+ yield render_point_cloud(xyz, heatmap_rgb)
67
  else:
68
+ yield None
69
 
70
  def on_select(evt: gr.SelectData):
71
  obj_name = evt.value['image']['orig_name'][:-4]
 
156
  outputs=[input_point_cloud],
157
  )
158
  run_button.click(
159
+ fn=run_predict,
160
  inputs=[file_upload, inference_mode, part_queries],
161
  outputs=[output_point_cloud],
162
  )
inference/inference.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
5
- import spaces
6
 
7
  DEVICE = "cuda:0"
8
  #if torch.cuda.is_available():
@@ -77,7 +77,7 @@ def get_heatmap_rgb(model, data, N_CHUNKS=5): # evaluate loader can only have ba
77
  heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:,:3]).squeeze()
78
  return heatmap_rgb
79
 
80
- @spaces.GPU
81
  def segment_obj(xyz, rgb, normal, queries):
82
  model = load_model()
83
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
@@ -85,7 +85,7 @@ def segment_obj(xyz, rgb, normal, queries):
85
  seg_rgb = get_segmentation_rgb(model, data_dict)
86
  return seg_rgb
87
 
88
- @spaces.GPU
89
  def get_heatmap(xyz, rgb, normal, query):
90
  model = load_model()
91
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
5
+ #import spaces
6
 
7
  DEVICE = "cuda:0"
8
  #if torch.cuda.is_available():
 
77
  heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:,:3]).squeeze()
78
  return heatmap_rgb
79
 
80
+ #@spaces.GPU
81
  def segment_obj(xyz, rgb, normal, queries):
82
  model = load_model()
83
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
 
85
  seg_rgb = get_segmentation_rgb(model, data_dict)
86
  return seg_rgb
87
 
88
+ #@spaces.GPU
89
  def get_heatmap(xyz, rgb, normal, query):
90
  model = load_model()
91
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
open3d_zerogpu_fix.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import fileinput
2
+ import site
3
+ from pathlib import Path
4
+
5
+ with fileinput.FileInput(f'{site.getsitepackages()[0]}/open3d/__init__.py', inplace=True) as file:
6
+ for line in file:
7
+ print(line.replace('_pybind_cuda.open3d_core_cuda_device_count()', '1'), end='')