Spaces:
Running
Running
Commit
·
8f260ec
1
Parent(s):
a77f1f0
2 files changed
Browse files
- .gitignore +9 -0
- configs/__init__.py +0 -0
- configs/config.py +14 -0
- models/__init__.py +0 -0
- models/checkpoint/FastSAM-s.pt +3 -0
- models/checkpoint/FastSAM-x.pt +3 -0
- models/checkpoint/FastSAM.pt +3 -0
- models/checkpoint/__init__.py +0 -0
- models/model.py +37 -0
- models/preprocess.py +17 -0
- tools/__init__.py +0 -0
- tools/tools.py +10 -0
- visualize/__init__.py +0 -0
- visualize/visualize.py +25 -0
.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/__pycache__
|
2 |
+
/__pycache__/*
|
3 |
+
*/__pycache__/*
|
4 |
+
*/__pycache__*
|
5 |
+
|
6 |
+
|
7 |
+
/.ipynb_checkpoints
|
8 |
+
/.ipynb_checkpoints/*
|
9 |
+
/notebook/.ipynb_checkpoints
|
configs/__init__.py
ADDED
File without changes
|
configs/config.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Central configuration for the FastSAM backend."""
import torch
import os

# Inference device; fixed to CPU for this deployment.
DEVICE = torch.device("cpu")

# Filesystem layout: this module lives in <backend>/configs/, so the backend
# root is one directory up.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.abspath(os.path.join(BASE_DIR, '..'))

# Directory holding the FastSAM checkpoint files (*.pt).
MODEL_PATH = os.path.join(BACKEND_DIR, 'models', 'checkpoint')

# FastSAM inference parameters.
FAST_SAM_IMGSZ = 1024          # input image size fed to the model
FAST_SAM_CONF = 0.5            # confidence threshold
FAST_SAM_IOU = 0.6             # IoU threshold
FAST_SAM_RETINA_MASKS = True   # request high-resolution masks
|
models/__init__.py
ADDED
File without changes
|
models/checkpoint/FastSAM-s.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c9f78716a81c7aff0d608ccc73e1b82ab3aaad86005049f6a92106a0be6d0844
|
3 |
+
size 23851578
|
models/checkpoint/FastSAM-x.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:752cadc2828edb1cd4bc4f9eb587100631af06ea2108f4c9ed56df4755701e76
|
3 |
+
size 144972346
|
models/checkpoint/FastSAM.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
|
3 |
+
size 144943063
|
models/checkpoint/__init__.py
ADDED
File without changes
|
models/model.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List
|
2 |
+
from fastsam import FastSAM
|
3 |
+
from PIL import Image
|
4 |
+
from fastsam import FastSAM, FastSAMPrompt
|
5 |
+
import numpy as np
|
6 |
+
from configs.config import DEVICE, FAST_SAM_CONF, FAST_SAM_IMGSZ, FAST_SAM_IOU, FAST_SAM_RETINA_MASKS, MODEL_PATH
|
7 |
+
from PIL import Image
|
8 |
+
import os
|
9 |
+
|
10 |
+
|
11 |
+
def loadModel(model_name: str = "FastSAM-x.pt") -> Any:
    """Load a FastSAM checkpoint, downloading it first if it is missing.

    Args:
        model_name: Checkpoint filename looked up under ``MODEL_PATH``.

    Returns:
        A ``FastSAM`` model instance loaded from the checkpoint.
    """
    # Build the path once with os.path.join (the original built it twice,
    # with string concatenation here and an f-string in the return).
    path: str = os.path.join(MODEL_PATH, model_name)
    if not os.path.exists(path):
        # NOTE(review): downloadModel always fetches FastSAM-x.pt regardless
        # of `path` — confirm this is intended for other model names.
        downloadModel(path)
    return FastSAM(path)
|
16 |
+
|
17 |
+
|
18 |
+
def getMask(image_path: Image.Image, fast_sam: FastSAM, point: List[List[int]], point_label: List[int]) -> Any:
    """Run FastSAM on an image and return the mask selected by point prompts.

    Args:
        image_path: the input image (the name suggests a filesystem path may
            also be accepted by FastSAM — TODO confirm).
        fast_sam: a loaded FastSAM model (see loadModel).
        point: prompt coordinates, ``[[x, y], ...]``.
        point_label: one label per point (foreground/background flag).

    Returns:
        The mask(s) produced by ``FastSAMPrompt.point_prompt``.
    """
    # Full inference pass; all thresholds come from configs.config.
    result: Any = fast_sam(
        source=image_path,
        device=DEVICE,
        retina_masks=FAST_SAM_RETINA_MASKS,
        imgsz=FAST_SAM_IMGSZ,
        conf=FAST_SAM_CONF,
        iou=FAST_SAM_IOU,
    )
    # Post-process the raw results, keeping only regions hit by the prompts.
    prompt_process = FastSAMPrompt(image_path, result, device=DEVICE)
    return prompt_process.point_prompt(points=point, pointlabel=point_label)
|
29 |
+
|
30 |
+
|
31 |
+
def downloadModel(model_name):
    """Download the FastSAM-x checkpoint and write it to ``model_name``.

    NOTE(review): the URL is fixed to FastSAM-x.pt; ``model_name`` only
    controls where the file is written — confirm this is intended when
    callers request other checkpoints.

    Args:
        model_name: Destination file path for the downloaded checkpoint.

    Returns:
        The destination path, unchanged.

    Raises:
        urllib.error.HTTPError: on a non-2xx HTTP response (the original
            silently wrote the error body to disk).
    """
    import shutil
    import urllib.request

    url = "https://firebasestorage.googleapis.com/v0/b/lexicons-5.appspot.com/o/FastSam-Models%2FFastSAM-x.pt?alt=media&token=64b65560-17d6-47b0-8a2b-8e2ee096da64"
    # Stream to disk instead of buffering the ~145 MB checkpoint in memory;
    # stdlib urllib replaces the third-party `requests` dependency and a
    # timeout prevents hanging forever on a dead connection.
    with urllib.request.urlopen(url, timeout=60) as resp, open(model_name, 'wb') as f:
        shutil.copyfileobj(resp, f)
    return model_name
|
models/preprocess.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Tuple

import numpy as np

from configs.config import DEVICE, FAST_SAM_CONF, FAST_SAM_IMGSZ, FAST_SAM_IOU, FAST_SAM_RETINA_MASKS
from fastsam import FastSAM, FastSAMPrompt
|
5 |
+
|
6 |
+
def preprocess(points_data: List[Dict]) -> Tuple[List[List[int]], List[int]]:
    """Convert raw point dicts into FastSAM point-prompt inputs.

    Args:
        points_data: dicts carrying 'x_'/'y_' coordinates and a 'flag_'
            label; values may arrive as strings and are coerced to int.

    Returns:
        ``(points, labels)`` where points is ``[[x, y], ...]`` and labels is
        the matching list of int flags. (The original ``Any`` annotation
        hid this 2-tuple contract.)
    """
    input_points = [[int(point['x_']), int(point['y_'])] for point in points_data]
    input_labels = [int(point['flag_']) for point in points_data]
    return input_points, input_labels
|
16 |
+
|
17 |
+
|
tools/__init__.py
ADDED
File without changes
|
tools/tools.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import base64
|
3 |
+
import io
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
def convertToBuffer(image : Any) -> str:
    """Encode an image as a base64 PNG string.

    Args:
        image: any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The PNG bytes base64-encoded and decoded to an ASCII ``str``
        (the original ``bytes`` annotation was wrong: ``.decode`` yields str).
    """
    buffered : io.BytesIO = io.BytesIO()
    image.save(buffered, format="PNG")
    # b64encode returns bytes; decode to str for JSON-friendly transport.
    img_base64 : str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return img_base64
|
visualize/__init__.py
ADDED
File without changes
|
visualize/visualize.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from typing import Any, Dict, List
|
6 |
+
|
7 |
+
def removeBgFromSegmentImage(og_image : Any, og_mask : Any, color: tuple = (0, 0, 255, 255), opacity: float = 0.2) -> Image.Image:
    """Render the segmentation mask as a semi-transparent colored overlay.

    Args:
        og_image: PIL image; only its pixel dimensions are used.
        og_mask: segmentation mask array (non-zero = segmented pixel).
        color: RGBA color applied to masked pixels.
        opacity: multiplier applied to the color's alpha channel.

    Returns:
        An RGBA PIL image that is ``color`` (with scaled alpha) where the
        mask is set and fully transparent everywhere else.
    """
    img = np.array(og_image.convert('RGB'))
    # Resize the mask to the image dimensions, matching removeOnlyBg below;
    # previously a size mismatch raised on the boolean indexing.
    mask = cv2.resize(og_mask.astype(np.uint8), (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

    rgba_image = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)

    color_with_opacity = (color[0], color[1], color[2], int(color[3] * opacity))
    rgba_image[mask > 0] = color_with_opacity

    return Image.fromarray(rgba_image)
|
17 |
+
|
18 |
+
def removeOnlyBg(og_image : Any, og_mask : Any) -> Image:
    """Cut the segmented object out of the image by zeroing the alpha of
    every background pixel.

    Returns an RGBA PIL image: original colors where the mask is set,
    fully transparent background elsewhere.
    """
    rgb = np.array(og_image.convert('RGB'))
    height, width = rgb.shape[0], rgb.shape[1]
    # Nearest-neighbour keeps the mask binary while matching the image size.
    alpha_mask = cv2.resize(og_mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST)
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[..., :3] = rgb
    rgba[..., 3] = alpha_mask * 255
    return Image.fromarray(rgba)
|