Commit 3d0afce
Parent(s): 7fb5b53
new models

Files changed:
- .gitattributes +1 -0
- app.py +74 -2
- models/sam_vit_h_4b8939.pth +3 -0
- requirements.txt +9 -1
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -1,12 +1,84 @@
 import gradio as gr
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+from PIL import Image, ImageDraw
+import clip
+import torch
+import numpy as np
+
+mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="./models/sam_vit_h_4b8939.pth"))
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model, preprocess = clip.load("ViT-B/32", device=device)
+
+def convert_box_xywh_to_xyxy(box):
+    x1 = box[0]
+    y1 = box[1]
+    x2 = box[0] + box[2]
+    y2 = box[1] + box[3]
+    return [x1, y1, x2, y2]
+
+def segment_image(image, segmentation_mask):
+    image_array = np.array(image)
+    segmented_image_array = np.zeros_like(image_array)
+    segmented_image_array[segmentation_mask] = image_array[segmentation_mask]
+    segmented_image = Image.fromarray(segmented_image_array)
+    black_image = Image.new("RGB", image.size, (0, 0, 0))
+    transparency_mask = np.zeros_like(segmentation_mask, dtype=np.uint8)
+    transparency_mask[segmentation_mask] = 255
+    transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+    black_image.paste(segmented_image, mask=transparency_mask_image)
+    return black_image
+
+@torch.no_grad()
+def retriev(elements: list[Image.Image], search_text: str) -> int:
+    preprocessed_images = [preprocess(image).to(device) for image in elements]
+    tokenized_text = clip.tokenize([search_text]).to(device)
+    stacked_images = torch.stack(preprocessed_images)
+    image_features = model.encode_image(stacked_images)
+    text_features = model.encode_text(tokenized_text)
+    image_features /= image_features.norm(dim=-1, keepdim=True)
+    text_features /= text_features.norm(dim=-1, keepdim=True)
+    probs = 100. * image_features @ text_features.T
+    return probs[:, 0].softmax(dim=0)
+
+def get_indices_of_values_above_threshold(values, threshold):
+    return [i for i, v in enumerate(values) if v > threshold]
+
 
 def pred(search_string, img):
+    # original_image = img.copy()
+
+    # masks = mask_generator.generate(img)
+    # # Cut out all masks
+    # cropped_boxes = []
+
+    # for mask in masks:
+    #     cropped_boxes.append(segment_image(img, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+    # scores = retriev(cropped_boxes, "kiwi")
+    # indices = get_indices_of_values_above_threshold(scores, 0.05)
+
+    # segmentation_masks = []
+
+    # for seg_idx in indices:
+    #     segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    #     segmentation_masks.append(segmentation_mask_image)
+
+    # overlay_image = Image.new('RGBA', img.size, (0, 0, 0, 0))
+    # overlay_color = (255, 0, 0, 200)
+
+    # draw = ImageDraw.Draw(overlay_image)
+    # for segmentation_mask_image in segmentation_masks:
+    #     draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+    # result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+
     return img
 
 iface = gr.Interface(
     fn=pred,
-    inputs=["text", "
-    outputs="
+    inputs=["text", gr.inputs.Image(type="pil")],
+    outputs=gr.outputs.Image(type="pil"),
     examples = [
         ["banana", "./imgs/test_1.jpg"],
         ["orange", "./imgs/test_1.jpg"],
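
The new pred() body arrives fully commented out, so the Space still just echoes the input image. Below is a minimal sketch, not part of the commit, of how that pipeline could be switched on. It assumes the helpers defined above in app.py (mask_generator, model, preprocess, segment_image, retriev, convert_box_xywh_to_xyxy, get_indices_of_values_above_threshold), a PIL image coming from the gr.inputs.Image(type="pil") input, and two small changes to the commented code: SamAutomaticMaskGenerator.generate is given an RGB numpy array rather than the PIL image, and the hard-coded "kiwi" query is replaced by the search_string argument.

# Hedged sketch of an enabled pred(); not part of the commit.
def pred(search_string, img):
    image_array = np.array(img.convert("RGB"))  # SAM expects an HxWx3 uint8 array
    masks = mask_generator.generate(image_array)

    # Cut each mask out of the image and crop it to its bounding box.
    cropped_boxes = [
        segment_image(img, m["segmentation"]).crop(convert_box_xywh_to_xyxy(m["bbox"]))
        for m in masks
    ]

    # Score every crop against the text query with CLIP and keep the strong matches.
    scores = retriev(cropped_boxes, search_string)
    indices = get_indices_of_values_above_threshold(scores, 0.05)

    # Paint the selected masks as a translucent red overlay on the original image.
    overlay_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay_image)
    for seg_idx in indices:
        mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype("uint8") * 255)
        draw.bitmap((0, 0), mask_image, fill=(255, 0, 0, 200))

    return Image.alpha_composite(img.convert("RGBA"), overlay_image)

The 0.05 score threshold and the (255, 0, 0, 200) overlay colour are taken directly from the commented-out lines.
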
models/sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
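
The checkpoint itself is stored through Git LFS (per the .gitattributes change above), so this file in the repository is only a small pointer. A hedged sketch, not part of the commit, for fetching the real ViT-H weights into ./models/ when only the pointer is present; the download URL is assumed to be the one published in the segment-anything README.

import os
import urllib.request

# Assumed official download location from the segment-anything README.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = "./models/sam_vit_h_4b8939.pth"

# An LFS pointer file is only a few hundred bytes; the real checkpoint is ~2.56 GB.
if not os.path.exists(CHECKPOINT_PATH) or os.path.getsize(CHECKPOINT_PATH) < 1_000_000:
    os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)
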
requirements.txt
CHANGED
@@ -1 +1,9 @@
-gradio
+gradio
+
+torch
+opencv-python
+Pillow
+
+git+https://github.com/openai/CLIP.git
+
+git+https://github.com/facebookresearch/segment-anything.git