Create app.py
app.py
ADDED
@@ -0,0 +1,332 @@
import os
import subprocess
import gradio as gr
from PIL import Image as PILImage
import torchvision.transforms.functional as TF
import numpy as np
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import re
import io
import base64
import cv2
from typing import List, Tuple, Optional

# ------------------ Install the SAM2 dependency ------------------
def install_sam2():
    """Check for SAM2 and install it together with its checkpoints if missing."""
    sam2_dir = "third_party/sam2"
    if not os.path.exists(sam2_dir):
        print("Installing SAM2...")
        os.makedirs("third_party", exist_ok=True)
        subprocess.run(["git", "clone", "https://github.com/facebookresearch/sam2.git", sam2_dir], check=True)

        # Install the package
        subprocess.run(["pip", "install", "-e", sam2_dir], check=True)

        # Download the checkpoints (run the repo's script from inside the checkpoints directory)
        checkpoints_dir = os.path.join(sam2_dir, "checkpoints")
        os.makedirs(checkpoints_dir, exist_ok=True)
        subprocess.run(["bash", "download_ckpts.sh"], cwd=checkpoints_dir, check=True)
        print("SAM2 installed successfully!")
    else:
        print("SAM2 already installed.")

# Make sure SAM2 is available before the models are loaded
install_sam2()

# ------------------ Initialize models ------------------
# Use relative paths
MODEL_PATH = "geshang/Seg-R1-COD"
SAM_CHECKPOINT = "third_party/sam2/checkpoints/sam2.1_hiera_large.pt"

# Auto-detect the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESIZE_SIZE = (768, 768)

# Load the Qwen model
try:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if DEVICE == "cuda" else None,
    )
    if DEVICE != "cuda":
        # Only move the model manually when it was not dispatched via device_map
        model = model.to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    print(f"Qwen model loaded on {DEVICE}")
except Exception as e:
    print(f"Error loading Qwen model: {e}")
    # Fall back to None so the app can still start and report the failure
    model = None
    processor = None

# SAM Wrapper
class CustomSAMWrapper:
    def __init__(self, model_path: str, device: str = DEVICE):
        try:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            self.device = torch.device(device)
            model_cfg = os.path.join("third_party/sam2", "configs/sam2.1/sam2.1_hiera_l.yaml")
            sam_model = build_sam2(model_cfg, model_path)
            sam_model = sam_model.to(self.device)
            self.predictor = SAM2ImagePredictor(sam_model)
            self.last_mask = None
            print(f"SAM model loaded on {device}")
        except Exception as e:
            print(f"Error loading SAM model: {e}")
            self.predictor = None

    def predict(self, image: PILImage.Image,
                points: Optional[List[Tuple[int, int]]],
                labels: Optional[List[int]],
                bbox: Optional[List[List[int]]] = None) -> Tuple[np.ndarray, float]:
        if not self.predictor:
            return np.zeros((image.height, image.width), dtype=bool), 0.0

        try:
            input_points = np.array(points) if points else None
            input_labels = np.array(labels) if labels else None
            input_bboxes = np.array(bbox) if bbox else None

            # PIL images are already RGB, so the array can be used directly
            rgb_image = np.array(image.convert("RGB"))

            self.predictor.set_image(rgb_image)

            mask_pred, score, logits = self.predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                box=input_bboxes,
                multimask_output=False,
            )

            self.last_mask = mask_pred[0]
            return mask_pred[0], score[0]
        except Exception as e:
            print(f"SAM prediction error: {e}")
            return np.zeros((image.height, image.width), dtype=bool), 0.0

# Initialize the SAM wrapper
sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)

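# Illustrative usage of the wrapper above (not invoked directly; run_pipeline below drives it):
# point labels use 1 for foreground and 0 for background, and each box is [x1, y1, x2, y2]
# in pixel coordinates of the image passed in. Coordinates here are made up for the example.
#   mask, score = sam_wrapper.predict(img, [(320, 240)], [1], bbox=[[100, 80, 600, 500]])
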
# ------------------ Inference helpers ------------------

def parse_custom_format(content: str):
    point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
    label_pattern = r"<labels>\s*(\[\s*(?:\d+\s*,?\s*)+\])\s*</labels>"
    bbox_pattern = r"<bbox>\s*(\[\s*\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+\s*\])\s*</bbox>"

    point_match = re.search(point_pattern, content)
    label_match = re.search(label_pattern, content)
    bbox_matches = re.findall(bbox_pattern, content)

    try:
        # The captured groups contain only digits, commas and brackets, so eval is constrained here
        points = np.array(eval(point_match.group(1))) if point_match else None
        labels = np.array(eval(label_match.group(1))) if label_match else None

        if points is not None and labels is not None:
            if not (len(points.shape) == 2 and points.shape[1] == 2 and len(labels) == points.shape[0]):
                points, labels = None, None

        bboxes = []
        for bbox_str in bbox_matches:
            bbox = np.array(eval(bbox_str))
            if len(bbox.shape) == 1 and bbox.shape[0] == 4:
                bboxes.append(bbox)

        bboxes = np.stack(bboxes, axis=0) if bboxes else None

        return points, labels, bboxes

    except Exception as e:
        print("Error parsing content:", e)
        return None, None, None

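# An illustrative (made-up) model response that the parser above accepts:
#   <think>The camouflaged animal blends into the rocks near the centre.</think>
#   <bbox>[120, 80, 540, 600]</bbox> <points>[[300, 260], [40, 50]]</points> <labels>[1, 0]</labels>
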
def prepare_test_messages(image, prompt):
    buffered = io.BytesIO()
    image = TF.resize(image, RESIZE_SIZE)
    image.save(buffered, format="JPEG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

    if "segment" in prompt or "mask" in prompt:
        SYSTEM_PROMPT = (
            "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
            "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
            "process should enclosed within <think> </think> tags, and the bounding box, points and points labels should be enclosed within <bbox></bbox>, <points></points>, and <labels></labels>, respectively. i.e., "
            "<think> reasoning process here </think> <bbox>[x1,y1,x2,y2]</bbox>, <points>[[x3,y3],[x4,y4],...]</points>, <labels>[1,0,...]</labels>"
            "Where 1 indicates a foreground (object) point, and 0 indicates a background point."
        )
    else:
        SYSTEM_PROMPT = "You're a helpful visual assistant."

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image/jpeg;base64,{img_base64}"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    return [messages]

def answer_question(batch_messages):
    if not model or not processor:
        return ["Model not loaded. Please check logs."]

    try:
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(text=text, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True).to(DEVICE)
        outputs = model.generate(**inputs, use_cache=True, max_new_tokens=1024)
        trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
        return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    except Exception as e:
        print(f"Error generating answer: {e}")
        return ["Error generating response"]

def visualize_masks_on_image(
    image: PILImage.Image,
    masks_np: list,
    colors=[(255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
            (128, 128, 255)],
    alpha=0.5,
):
    if not masks_np:
        return image

    # Force RGB so uploads with an alpha channel or grayscale images blend correctly
    image_np = np.array(image.convert("RGB"))
    color_mask = np.zeros((image_np.shape[0], image_np.shape[1], 3), dtype=np.uint8)

    for i, mask in enumerate(masks_np):
        color = colors[i % len(colors)]
        mask = mask.astype(np.uint8)

        if mask.shape[:2] != image_np.shape[:2]:
            mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))

        color_mask[:, :, 0] = color_mask[:, :, 0] | (mask * color[0])
        color_mask[:, :, 1] = color_mask[:, :, 1] | (mask * color[1])
        color_mask[:, :, 2] = color_mask[:, :, 2] | (mask * color[2])

    blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
    return PILImage.fromarray(blended)

def run_pipeline(image: PILImage.Image, prompt: str):
    if not model or not processor:
        return "Models not loaded. Please check logs.", None

    try:
        img_original = image.copy()
        img_resized = TF.resize(image, RESIZE_SIZE)

        messages = prepare_test_messages(img_resized, prompt)
        output_text = answer_question(messages)[0]
        print(f"Model output: {output_text}")

        points, labels, bbox = parse_custom_format(output_text)

        mask_pred = None
        final_mask = np.zeros(RESIZE_SIZE[::-1], dtype=bool)

        if (points is not None and labels is not None) or (bbox is not None):
            img = img_resized

            if bbox is not None and len(bbox.shape) == 2:
                # One SAM call per predicted box, keeping only the points that fall inside it
                for b in bbox:
                    b = b.tolist()
                    if points is not None and labels is not None:
                        in_bbox_mask = (
                            (points[:, 0] >= b[0]) & (points[:, 0] <= b[2]) &
                            (points[:, 1] >= b[1]) & (points[:, 1] <= b[3])
                        )
                        selected_points = points[in_bbox_mask]
                        selected_labels = labels[in_bbox_mask]
                    else:
                        selected_points, selected_labels = None, None

                    try:
                        mask, _ = sam_wrapper.predict(
                            img,
                            selected_points.tolist() if selected_points is not None and len(selected_points) > 0 else None,
                            selected_labels.tolist() if selected_labels is not None and len(selected_labels) > 0 else None,
                            b
                        )
                        final_mask |= (mask > 0)
                    except Exception as e:
                        print(f"Mask prediction error for bbox: {e}")
                        continue

                mask_pred = final_mask
            else:
                try:
                    mask_pred, _ = sam_wrapper.predict(
                        img,
                        points.tolist() if points is not None else None,
                        labels.tolist() if labels is not None else None,
                        bbox.tolist() if bbox is not None else None
                    )
                    mask_pred = mask_pred > 0
                except Exception as e:
                    print(f"Mask prediction error: {e}")
                    mask_pred = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
        else:
            return output_text, None

        # Resize the mask back to the original image size
        mask_np = mask_pred
        mask_img = PILImage.fromarray((mask_np * 255).astype(np.uint8)).resize(img_original.size)
        mask_img = mask_img.convert("L")
        mask_np = np.array(mask_img) > 128

        # Visualize the result
        visualized_img = visualize_masks_on_image(
            img_original,
            masks_np=[mask_np],
            alpha=0.6
        )
        return output_text, visualized_img
    except Exception as e:
        print(f"Pipeline error: {e}")
        return f"Error processing request: {str(e)}", None

# ------------------ Launch Gradio ------------------

with gr.Blocks(title="Seg-R1") as demo:
    gr.Markdown("# Seg-R1: Visual Segmentation Assistant")
    gr.Markdown("Upload an image and ask questions about segmentation.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(lines=2, label="Question", placeholder="Ask about objects in the image...")
            submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column():
            text_output = gr.Textbox(label="Model Response", interactive=False)
            image_output = gr.Image(type="pil", label="Segmentation Result", interactive=False)

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_input, text_input],
        outputs=[text_output, image_output]
    )

    # gr.Examples(
    #     examples=[
    #         ["examples/dog.jpg", "Segment the dog in the image"],
    #         ["examples/street.jpg", "Find all cars in the image"],
    #         ["examples/fruits.jpg", "Identify the apples"]
    #     ],
    #     inputs=[image_input, text_input],
    #     outputs=[text_output, image_output],
    #     fn=run_pipeline,
    #     cache_examples=False
    # )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)