NHLOCAL commited on
Commit
d644cdb
ยท
1 Parent(s): cc9e884

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +29 -17
backend.py CHANGED
@@ -155,33 +155,45 @@ CONF_THRESHOLD = 0.2
155
  # -----------------------------
156
  # 4) ื”ื›ื ื” ืœ-SAM2
157
  # -----------------------------
158
- try:
159
- from sam2.build_sam import build_sam2
160
- from sam2.sam2_image_predictor import SAM2ImagePredictor
161
 
162
- SAM2_CHECKPOINT = "../../checkpoints/sam2.1_hiera_tiny.pt"
163
- SAM2_CONFIG = "../../configs/sam2.1/sam2.1_hiera_t.yaml"
 
 
 
 
 
164
 
165
- if not os.path.exists(SAM2_CHECKPOINT):
 
 
 
 
 
166
  print("[SAM2] ืžื•ื“ืœ SAM2 ืœื ื ืžืฆื. ืžื ืกื” ืœื”ื•ืจื™ื“ ืืช ื”ืžื•ื“ืœ...")
167
- sam2_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt"
168
- os.makedirs(os.path.dirname(SAM2_CHECKPOINT), exist_ok=True)
169
  response = requests.get(sam2_url)
170
- with open(SAM2_CHECKPOINT, 'wb') as f:
171
  f.write(response.content)
172
  print("[SAM2] ืžื•ื“ืœ SAM2 ื”ื•ืจื“ ื‘ื”ืฆืœื—ื”.")
173
-
174
- if not os.path.exists(SAM2_CONFIG):
175
  print("[SAM2] ืงื•ื‘ืฅ ื”ืงื•ื ืคื™ื’ SAM2 ืœื ื ืžืฆื. ืžื ืกื” ืœื”ื•ืจื™ื“ ืืช ื”ืงื•ื ืคื™ื’...")
176
- # ื›ืืŸ ื™ืฉ ืœื•ื•ื“ื ืฉืงื•ื‘ืฅ ื”ืงื•ื ืคื™ื’ ื–ืžื™ืŸ, ืื• ืœื›ืœื•ืœ ืื•ืชื• ื‘ืžืื’ืจ ืฉืœืš
177
- raise FileNotFoundError(f"ืœื ื ืžืฆื ืงื•ื‘ืฅ ืงื•ื ืคื™ื’ SAM2 ื‘ื ืชื™ื‘: {SAM2_CONFIG}")
 
 
 
 
178
 
179
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
180
 
181
- sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
182
- sam2_predictor = SAM2ImagePredictor(sam2_model)
 
183
  sam2_predictor.model.to(device)
184
-
185
  except Exception as e:
186
  print(f"[SAM2] ืœื ืžืฆืœื™ื— ืœื˜ืขื•ืŸ ืืช SAM2: {e}")
187
  sam2_predictor = None
 
155
  # -----------------------------
156
  # 4) ื”ื›ื ื” ืœ-SAM2
157
  # -----------------------------
 
 
 
158
 
159
+ from typing import Any
160
+ import supervision as sv
161
+ from sam2.build_sam import build_sam2
162
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
163
+
164
+ SAM2_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
165
+ SAM2_CONFIG = "sam2_hiera_s.yaml"
166
 
167
+ def load_sam_image_model(
168
+ device: torch.device,
169
+ config: str = SAM2_CONFIG,
170
+ checkpoint: str = SAM2_CHECKPOINT
171
+ ) -> SAM2ImagePredictor:
172
+ if not os.path.exists(checkpoint):
173
  print("[SAM2] ืžื•ื“ืœ SAM2 ืœื ื ืžืฆื. ืžื ืกื” ืœื”ื•ืจื™ื“ ืืช ื”ืžื•ื“ืœ...")
174
+ sam2_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2_hiera_small.pt"
175
+ os.makedirs(os.path.dirname(checkpoint), exist_ok=True)
176
  response = requests.get(sam2_url)
177
+ with open(checkpoint, 'wb') as f:
178
  f.write(response.content)
179
  print("[SAM2] ืžื•ื“ืœ SAM2 ื”ื•ืจื“ ื‘ื”ืฆืœื—ื”.")
180
+
181
+ if not os.path.exists(config):
182
  print("[SAM2] ืงื•ื‘ืฅ ื”ืงื•ื ืคื™ื’ SAM2 ืœื ื ืžืฆื. ืžื ืกื” ืœื”ื•ืจื™ื“ ืืช ื”ืงื•ื ืคื™ื’...")
183
+ sam2_config_url = "https://path_to_your_config/sam2_hiera_s.yaml" # ืขื“ื›ืŸ ืืช ื”ืงื™ืฉื•ืจ ื”ืžืชืื™ื
184
+ os.makedirs(os.path.dirname(config), exist_ok=True)
185
+ response = requests.get(sam2_config_url)
186
+ with open(config, 'wb') as f:
187
+ f.write(response.content)
188
+ print("[SAM2] ืงื•ื‘ืฅ ื”ืงื•ื ืคื™ื’ SAM2 ื”ื•ืจื“ ื‘ื”ืฆืœื—ื”.")
189
 
190
+ model = build_sam2(config, checkpoint, device=device)
191
+ return SAM2ImagePredictor(sam_model=model)
192
 
193
+ try:
194
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
195
+ sam2_predictor = load_sam_image_model(device=device)
196
  sam2_predictor.model.to(device)
 
197
  except Exception as e:
198
  print(f"[SAM2] ืœื ืžืฆืœื™ื— ืœื˜ืขื•ืŸ ืืช SAM2: {e}")
199
  sam2_predictor = None