Commit 5875f7d · 1 Parent(s): 3d0afce
coded

- app.py +22 -20
- packages.txt +1 -0
app.py
CHANGED
@@ -5,6 +5,7 @@ import clip
 import torch
 import numpy as np
 
+# inspired by https://github.com/maxi-w/CLIP-SAM/blob/main/main.ipynb
 mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="./models/sam_vit_h_4b8939.pth"))
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -46,39 +47,40 @@ def get_indices_of_values_above_threshold(values, threshold):
 
 
 def pred(search_string, img):
-
+    original_image = img.copy()
 
-
-
-    #
+    open_cv_image = np.array(img)[:, :, ::-1]
+    masks = mask_generator.generate(open_cv_image)
+    # Cut out all masks
+    cropped_boxes = []
 
-
-
+    for mask in masks:
+        cropped_boxes.append(segment_image(img, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
 
-
-
+    scores = retriev(cropped_boxes, search_string)
+    indices = get_indices_of_values_above_threshold(scores, 0.05)
 
-
+    segmentation_masks = []
 
-
-
-
+    for seg_idx in indices:
+        segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+        segmentation_masks.append(segmentation_mask_image)
 
-
-
+    overlay_image = Image.new('RGBA', img.size, (0, 0, 0, 0))
+    overlay_color = (255, 0, 0, 200)
 
-
-
-
+    draw = ImageDraw.Draw(overlay_image)
+    for segmentation_mask_image in segmentation_masks:
+        draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
 
-
+    result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
 
-    return
+    return result_image, overlay_image
 
 iface = gr.Interface(
     fn=pred,
     inputs=["text", gr.inputs.Image(type="pil")],
-    outputs=gr.outputs.Image(type="pil"),
+    outputs=[gr.outputs.Image(type="pil"), gr.outputs.Image(type="pil")],
     examples = [
         ["banana", "./imgs/test_1.jpg"],
         ["orange", "./imgs/test_1.jpg"],
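Note: the new pred() body leans on four helpers (convert_box_xywh_to_xyxy, segment_image, retriev, get_indices_of_values_above_threshold) that are defined earlier in app.py, outside these hunks. For context, a minimal sketch of what they plausibly look like, modeled on the CLIP-SAM notebook credited in the added comment; the CLIP variant ("ViT-B/32") and the exact masking behavior are assumptions, not shown in this diff.

# Hypothetical sketch, not part of this commit: plausible definitions of the
# helpers pred() calls, following the linked CLIP-SAM notebook.
import clip
import numpy as np
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, preprocess = clip.load("ViT-B/32", device=device)  # assumed CLIP variant

def convert_box_xywh_to_xyxy(box):
    # SAM reports bounding boxes as [x, y, w, h]; PIL's Image.crop
    # expects (x1, y1, x2, y2).
    x, y, w, h = box
    return [x, y, x + w, y + h]

def segment_image(image, segmentation_mask):
    # Keep only the pixels inside the mask; everything else goes black.
    image_array = np.array(image)
    segmented = np.zeros_like(image_array)
    segmented[segmentation_mask] = image_array[segmentation_mask]
    return Image.fromarray(segmented)

@torch.no_grad()
def retriev(elements, search_text):
    # CLIP-score every cropped segment against the text query, softmaxed
    # over the crops so the scores sum to 1 across segments.
    preprocessed = torch.stack([preprocess(img) for img in elements]).to(device)
    tokens = clip.tokenize([search_text]).to(device)
    image_features = model.encode_image(preprocessed)
    text_features = model.encode_text(tokens)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    logits = 100.0 * image_features @ text_features.T
    return logits[:, 0].softmax(dim=0)

def get_indices_of_values_above_threshold(values, threshold):
    return [i for i, v in enumerate(values) if v > threshold]

With helpers along these lines, retriev returns a distribution over the crops, which is why the fixed 0.05 threshold in pred() picks out only the few segments that clearly match the query.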
packages.txt
ADDED
@@ -0,0 +1 @@
+python3-opencv
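On Hugging Face Spaces, packages.txt lists Debian packages to install with apt at build time, so this one-line file pulls in python3-opencv. The likely motivation (an assumption, not stated in the commit) is that SamAutomaticMaskGenerator's mask post-processing imports cv2, which the default Space image does not provide.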