Tharuneshwar commited on
Commit
8f260ec
·
1 Parent(s): a77f1f0
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /__pycache__
2
+ /__pycache__/*
3
+ */__pycache__/*
4
+ */__pycache__*
5
+
6
+
7
+ /.ipynb_checkpoints
8
+ /.ipynb_checkpoints/*
9
+ /notebook/.ipynb_checkpoints
configs/__init__.py ADDED
File without changes
configs/config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+
4
+ DEVICE = torch.device("cpu")
5
+
6
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
+ BACKEND_DIR = os.path.abspath(os.path.join(BASE_DIR, '..'))
8
+
9
+ MODEL_PATH = os.path.join(BACKEND_DIR, 'models', 'checkpoint')
10
+
11
+ FAST_SAM_IMGSZ = 1024
12
+ FAST_SAM_CONF = 0.5
13
+ FAST_SAM_IOU = 0.6
14
+ FAST_SAM_RETINA_MASKS = True
models/__init__.py ADDED
File without changes
models/checkpoint/FastSAM-s.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9f78716a81c7aff0d608ccc73e1b82ab3aaad86005049f6a92106a0be6d0844
3
+ size 23851578
models/checkpoint/FastSAM-x.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:752cadc2828edb1cd4bc4f9eb587100631af06ea2108f4c9ed56df4755701e76
3
+ size 144972346
models/checkpoint/FastSAM.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
3
+ size 144943063
models/checkpoint/__init__.py ADDED
File without changes
models/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+ from fastsam import FastSAM
3
+ from PIL import Image
4
+ from fastsam import FastSAM, FastSAMPrompt
5
+ import numpy as np
6
+ from configs.config import DEVICE, FAST_SAM_CONF, FAST_SAM_IMGSZ, FAST_SAM_IOU, FAST_SAM_RETINA_MASKS, MODEL_PATH
7
+ from PIL import Image
8
+ import os
9
+
10
+
11
+ def loadModel(model_name: str = "FastSAM-x.pt") -> Any:
12
+ path: str = MODEL_PATH + "/" + model_name
13
+ if not os.path.exists(path):
14
+ downloadModel(path)
15
+ return FastSAM(f"{MODEL_PATH}/{model_name}")
16
+
17
+
18
+ def getMask(image_path: Image, fast_sam: FastSAM, point: List[List[int]], point_label: List[int]) -> Any:
19
+ result: Any = fast_sam(
20
+ source=image_path,
21
+ device=DEVICE,
22
+ retina_masks=FAST_SAM_RETINA_MASKS,
23
+ imgsz=FAST_SAM_IMGSZ,
24
+ conf=FAST_SAM_CONF,
25
+ iou=FAST_SAM_IOU,
26
+ )
27
+ prompt_process = FastSAMPrompt(image_path, result, device=DEVICE)
28
+ return prompt_process.point_prompt(points=point, pointlabel=point_label)
29
+
30
+
31
+ def downloadModel(model_name):
32
+ import requests
33
+ url = "https://firebasestorage.googleapis.com/v0/b/lexicons-5.appspot.com/o/FastSam-Models%2FFastSAM-x.pt?alt=media&token=64b65560-17d6-47b0-8a2b-8e2ee096da64"
34
+ r = requests.get(url)
35
+ with open(model_name, 'wb') as f:
36
+ f.write(r.content)
37
+ return model_name
models/preprocess.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ from fastsam import FastSAM, FastSAMPrompt
3
+ import numpy as np
4
+ from configs.config import DEVICE, FAST_SAM_CONF, FAST_SAM_IMGSZ, FAST_SAM_IOU, FAST_SAM_RETINA_MASKS
5
+
6
+ def preprocess(points_data: List[Dict]) -> Any:
7
+
8
+ input_points = []
9
+ input_labels = []
10
+
11
+ for point in points_data:
12
+ input_points.append([int(point['x_']), int(point['y_'])])
13
+ input_labels.append(int(point['flag_']))
14
+
15
+ return input_points, input_labels
16
+
17
+
tools/__init__.py ADDED
File without changes
tools/tools.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import base64
3
+ import io
4
+ from typing import Any
5
+
6
+ def convertToBuffer(image : Any) -> bytes:
7
+ buffered : io.BytesIO = io.BytesIO()
8
+ image.save(buffered, format="PNG")
9
+ img_base64 : bytes = base64.b64encode(buffered.getvalue()).decode('utf-8')
10
+ return img_base64
visualize/__init__.py ADDED
File without changes
visualize/visualize.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from typing import Any, Dict, List
6
+
7
+ def removeBgFromSegmentImage(og_image : Any, og_mask : Any, color: tuple = (0, 0, 255, 255), opacity: float = 0.2) -> Image:
8
+ og_image = np.array(og_image.convert('RGB'))
9
+ mask = og_mask.astype(np.uint8) * 255 # Convert to 0 or 255
10
+
11
+ rgba_image = np.zeros((og_image.shape[0], og_image.shape[1], 4), dtype=np.uint8)
12
+
13
+ color_with_opacity = (color[0], color[1], color[2], int(color[3] * opacity))
14
+ rgba_image[mask > 0] = color_with_opacity
15
+
16
+ return Image.fromarray(rgba_image)
17
+
18
+ def removeOnlyBg(og_image : Any, og_mask : Any) -> Image:
19
+ img = np.array(og_image.convert('RGB'))
20
+ mask = cv2.resize(og_mask.astype(np.uint8), (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
21
+ rgba_image = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
22
+ rgba_image[..., :3] = img
23
+ rgba_image[..., 3] = mask * 255
24
+
25
+ return Image.fromarray(rgba_image)