geshang committed
Commit 7dacac2 · verified · 1 parent: 64d205a

Create app.py

Files changed (1)
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
+ import os
+ import subprocess
+ import gradio as gr
+ from PIL import Image as PILImage
+ import torchvision.transforms.functional as TF
+ import numpy as np
+ import torch
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+ from qwen_vl_utils import process_vision_info
+ import re
+ import io
+ import base64
+ import cv2
+ from typing import List, Tuple, Optional
+
+ # ------------------ Install the SAM2 dependency ------------------
+ def install_sam2():
+     """Check for SAM2 and install it (with its dependencies) if missing."""
+     sam2_dir = "third_party/sam2"
+     if not os.path.exists(sam2_dir):
+         print("Installing SAM2...")
+         os.makedirs("third_party", exist_ok=True)
+         subprocess.run(["git", "clone", "https://github.com/facebookresearch/sam2.git", sam2_dir], check=True)
+
+         # Install dependencies
+         subprocess.run(["pip", "install", "-e", sam2_dir], check=True)
+
+         # Download checkpoints using the script shipped with the SAM2 repo
+         checkpoints_dir = os.path.join(sam2_dir, "checkpoints")
+         os.makedirs(checkpoints_dir, exist_ok=True)
+         subprocess.run(["bash", "download_ckpts.sh"], cwd=checkpoints_dir, check=True)
+         print("SAM2 installed successfully!")
+     else:
+         print("SAM2 already installed.")
+
+ # Make sure SAM2 is installed before the models are loaded
+ install_sam2()
+
+ # ------------------ Initialize models ------------------
+ # Use relative paths
+ MODEL_PATH = "geshang/Seg-R1-COD"
+ SAM_CHECKPOINT = "third_party/sam2/checkpoints/sam2.1_hiera_large.pt"
+
+ # Auto-detect device
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ RESIZE_SIZE = (768, 768)
+
+ # Load the Qwen model
+ try:
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         MODEL_PATH,
+         torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
+         device_map="auto" if DEVICE == "cuda" else None,
+     )
+     if DEVICE != "cuda":
+         # device_map="auto" already handles placement on GPU; only move the model explicitly on CPU
+         model = model.to(DEVICE)
+     processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+     print(f"Qwen model loaded on {DEVICE}")
+ except Exception as e:
+     print(f"Error loading Qwen model: {e}")
+     # Leave placeholders (None) so the app can keep running and report the error
+     model = None
+     processor = None
+
+ # SAM Wrapper
+ class CustomSAMWrapper:
+     def __init__(self, model_path: str, device: str = DEVICE):
+         try:
+             from sam2.build_sam import build_sam2
+             from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+             self.device = torch.device(device)
+             model_cfg = os.path.join("third_party/sam2", "configs/sam2.1/sam2.1_hiera_l.yaml")
+             sam_model = build_sam2(model_cfg, model_path)
+             sam_model = sam_model.to(self.device)
+             self.predictor = SAM2ImagePredictor(sam_model)
+             self.last_mask = None
+             print(f"SAM model loaded on {device}")
+         except Exception as e:
+             print(f"Error loading SAM model: {e}")
+             self.predictor = None
+
+     def predict(self, image: PILImage.Image,
+                 points: List[Tuple[int, int]],
+                 labels: List[int],
+                 bbox: Optional[List[List[int]]] = None) -> Tuple[np.ndarray, float]:
+         if not self.predictor:
+             return np.zeros((image.height, image.width), dtype=bool), 0.0
+
+         try:
+             input_points = np.array(points) if points else None
+             input_labels = np.array(labels) if labels else None
+             input_bboxes = np.array(bbox) if bbox else None
+
+             # PIL images are already RGB, which is what SAM2's set_image expects
+             image_np = np.array(image)
+
+             self.predictor.set_image(image_np)
+
+             mask_pred, score, logits = self.predictor.predict(
+                 point_coords=input_points,
+                 point_labels=input_labels,
+                 box=input_bboxes,
+                 multimask_output=False,
+             )
+
+             self.last_mask = mask_pred[0]
+             return mask_pred[0], score[0]
+         except Exception as e:
+             print(f"SAM prediction error: {e}")
+             return np.zeros((image.height, image.width), dtype=bool), 0.0
+
+ # Initialize the SAM wrapper
+ sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
+
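+ # Illustrative standalone use of the wrapper (hypothetical coordinates):
+ #   mask, score = sam_wrapper.predict(img, points=[[320, 240]], labels=[1], bbox=[[100, 80, 500, 400]])
+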
+ # ------------------ Inference functions ------------------
+
+ def parse_custom_format(content: str):
+     point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
+     label_pattern = r"<labels>\s*(\[\s*(?:\d+\s*,?\s*)+\])\s*</labels>"
+     bbox_pattern = r"<bbox>\s*(\[\s*\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+\s*\])\s*</bbox>"
+
+     point_match = re.search(point_pattern, content)
+     label_match = re.search(label_pattern, content)
+     bbox_matches = re.findall(bbox_pattern, content)
+
+     try:
+         points = np.array(eval(point_match.group(1))) if point_match else None
+         labels = np.array(eval(label_match.group(1))) if label_match else None
+
+         # Points must be an (N, 2) array with exactly one label per point
+         if points is not None and labels is not None:
+             if not (len(points.shape) == 2 and points.shape[1] == 2 and len(labels) == points.shape[0]):
+                 points, labels = None, None
+
+         bboxes = []
+         for bbox_str in bbox_matches:
+             bbox = np.array(eval(bbox_str))
+             if len(bbox.shape) == 1 and bbox.shape[0] == 4:
+                 bboxes.append(bbox)
+
+         bboxes = np.stack(bboxes, axis=0) if bboxes else None
+
+         return points, labels, bboxes
+
+     except Exception as e:
+         print("Error parsing content:", e)
+         return None, None, None
+
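+ # Example of the tagged output parse_custom_format() looks for (illustrative coordinates only):
+ #   <think> ... </think> <bbox>[120, 80, 400, 360]</bbox> <points>[[200, 150], [60, 40]]</points> <labels>[1, 0]</labels>
+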
+ def prepare_test_messages(image, prompt):
+     buffered = io.BytesIO()
+     image = TF.resize(image, RESIZE_SIZE).convert("RGB")  # JPEG cannot carry an alpha channel
+     image.save(buffered, format="JPEG")
+     img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+     if "segment" in prompt or "mask" in prompt:
+         SYSTEM_PROMPT = (
+             "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+             "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+             "process should be enclosed within <think> </think> tags, and the bounding box, points and points labels should be enclosed within <bbox></bbox>, <points></points>, and <labels></labels>, respectively. i.e., "
+             "<think> reasoning process here </think> <bbox>[x1,y1,x2,y2]</bbox>, <points>[[x3,y3],[x4,y4],...]</points>, <labels>[1,0,...]</labels> "
+             "Where 1 indicates a foreground (object) point, and 0 indicates a background point."
+         )
+     else:
+         SYSTEM_PROMPT = "You're a helpful visual assistant."
+
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"data:image/jpeg;base64,{img_base64}"},
+                 {"type": "text", "text": prompt},
+             ],
+         },
+     ]
+     return [messages]
+
+ def answer_question(batch_messages):
+     if model is None or processor is None:
+         return ["Model not loaded. Please check logs."]
+
+     try:
+         text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+         image_inputs, video_inputs = process_vision_info(batch_messages)
+         inputs = processor(text=text, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True).to(DEVICE)
+         outputs = model.generate(**inputs, use_cache=True, max_new_tokens=1024)
+         # Strip the prompt tokens so only the newly generated answer is decoded
+         trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
+         return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+     except Exception as e:
+         print(f"Error generating answer: {e}")
+         return ["Error generating response"]
+
+ def visualize_masks_on_image(
+     image: PILImage.Image,
+     masks_np: list,
+     colors=[(255, 0, 0), (0, 255, 0), (0, 0, 255),
+             (255, 255, 0), (255, 0, 255), (0, 255, 255),
+             (128, 128, 255)],
+     alpha=0.5,
+ ):
+     if not masks_np:
+         return image
+
+     # Ensure a 3-channel RGB array so blending with the color overlay works for any upload
+     image_np = np.array(image.convert("RGB"))
+     color_mask = np.zeros((image_np.shape[0], image_np.shape[1], 3), dtype=np.uint8)
+
+     for i, mask in enumerate(masks_np):
+         color = colors[i % len(colors)]
+         mask = mask.astype(np.uint8)
+
+         if mask.shape[:2] != image_np.shape[:2]:
+             mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
+
+         color_mask[:, :, 0] = color_mask[:, :, 0] | (mask * color[0])
+         color_mask[:, :, 1] = color_mask[:, :, 1] | (mask * color[1])
+         color_mask[:, :, 2] = color_mask[:, :, 2] | (mask * color[2])
+
+     blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
+     return PILImage.fromarray(blended)
+
+ def run_pipeline(image: PILImage.Image, prompt: str):
+     if model is None or processor is None:
+         return "Models not loaded. Please check logs.", None
+
+     try:
+         img_original = image.copy()
+         img_resized = TF.resize(image, RESIZE_SIZE)
+
+         messages = prepare_test_messages(img_resized, prompt)
+         output_text = answer_question(messages)[0]
+         print(f"Model output: {output_text}")
+
+         points, labels, bbox = parse_custom_format(output_text)
+
+         mask_pred = None
+         final_mask = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
+
+         if (points is not None and labels is not None) or (bbox is not None):
+             img = img_resized
+
+             if bbox is not None and len(bbox.shape) == 2:
+                 # One SAM call per predicted box, keeping only the points that fall inside it
+                 for b in bbox:
+                     b = b.tolist()
+                     if points is not None and labels is not None:
+                         in_bbox_mask = (
+                             (points[:, 0] >= b[0]) & (points[:, 0] <= b[2]) &
+                             (points[:, 1] >= b[1]) & (points[:, 1] <= b[3])
+                         )
+                         selected_points = points[in_bbox_mask]
+                         selected_labels = labels[in_bbox_mask]
+                     else:
+                         selected_points, selected_labels = None, None
+
+                     try:
+                         mask, _ = sam_wrapper.predict(
+                             img,
+                             selected_points.tolist() if selected_points is not None and len(selected_points) > 0 else None,
+                             selected_labels.tolist() if selected_labels is not None and len(selected_labels) > 0 else None,
+                             b
+                         )
+                         final_mask |= (mask > 0)
+                     except Exception as e:
+                         print(f"Mask prediction error for bbox: {e}")
+                         continue
+
+                 mask_pred = final_mask
+             else:
+                 try:
+                     mask_pred, _ = sam_wrapper.predict(
+                         img,
+                         points.tolist() if points is not None else None,
+                         labels.tolist() if labels is not None else None,
+                         bbox.tolist() if bbox is not None else None
+                     )
+                     mask_pred = mask_pred > 0
+                 except Exception as e:
+                     print(f"Mask prediction error: {e}")
+                     mask_pred = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
+         else:
+             return output_text, None
+
+         # Resize the mask back to the original image size
+         mask_np = mask_pred
+         mask_img = PILImage.fromarray((mask_np * 255).astype(np.uint8)).resize(img_original.size)
+         mask_img = mask_img.convert("L")
+         mask_np = np.array(mask_img) > 128
+
+         # Visualize the result
+         visualized_img = visualize_masks_on_image(
+             img_original,
+             masks_np=[mask_np],
+             alpha=0.6
+         )
+         return output_text, visualized_img
+     except Exception as e:
+         print(f"Pipeline error: {e}")
+         return f"Error processing request: {str(e)}", None
+
+ # ------------------ Launch Gradio ------------------
+
+ with gr.Blocks(title="Seg-R1") as demo:
+     gr.Markdown("# Seg-R1: Visual Segmentation Assistant")
+     gr.Markdown("Upload an image and ask questions about segmentation.")
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Upload Image")
+             text_input = gr.Textbox(lines=2, label="Question", placeholder="Ask about objects in the image...")
+             submit_btn = gr.Button("Submit", variant="primary")
+
+         with gr.Column():
+             text_output = gr.Textbox(label="Model Response", interactive=False)
+             image_output = gr.Image(type="pil", label="Segmentation Result", interactive=False)
+
+     submit_btn.click(
+         fn=run_pipeline,
+         inputs=[image_input, text_input],
+         outputs=[text_output, image_output]
+     )
+
+     # gr.Examples(
+     #     examples=[
+     #         ["examples/dog.jpg", "Segment the dog in the image"],
+     #         ["examples/street.jpg", "Find all cars in the image"],
+     #         ["examples/fruits.jpg", "Identify the apples"]
+     #     ],
+     #     inputs=[image_input, text_input],
+     #     outputs=[text_output, image_output],
+     #     fn=run_pipeline,
+     #     cache_examples=False
+     # )
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
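
Once the first start has cloned SAM2 and downloaded its checkpoints, the app serves the Gradio UI on port 7860 (server_name "0.0.0.0"), so a local `python app.py` run should be reachable at http://localhost:7860.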