aiknowyou-nic committed
Commit 3d0afce · 1 Parent(s): 7fb5b53

new models
Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +74 -2
  3. models/sam_vit_h_4b8939.pth +3 -0
  4. requirements.txt +9 -1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,12 +1,84 @@
import gradio as gr
+ from segment_anything import build_sam, SamAutomaticMaskGenerator
+ from PIL import Image, ImageDraw
+ import clip
+ import torch
+ import numpy as np
+
+ mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="./models/sam_vit_h_4b8939.pth"))
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model, preprocess = clip.load("ViT-B/32", device=device)
+
+ def convert_box_xywh_to_xyxy(box):
+     x1 = box[0]
+     y1 = box[1]
+     x2 = box[0] + box[2]
+     y2 = box[1] + box[3]
+     return [x1, y1, x2, y2]
+
+ def segment_image(image, segmentation_mask):
+     image_array = np.array(image)
+     segmented_image_array = np.zeros_like(image_array)
+     segmented_image_array[segmentation_mask] = image_array[segmentation_mask]
+     segmented_image = Image.fromarray(segmented_image_array)
+     black_image = Image.new("RGB", image.size, (0, 0, 0))
+     transparency_mask = np.zeros_like(segmentation_mask, dtype=np.uint8)
+     transparency_mask[segmentation_mask] = 255
+     transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+     black_image.paste(segmented_image, mask=transparency_mask_image)
+     return black_image
+
+ @torch.no_grad()
+ def retriev(elements: list[Image.Image], search_text: str) -> int:
+     preprocessed_images = [preprocess(image).to(device) for image in elements]
+     tokenized_text = clip.tokenize([search_text]).to(device)
+     stacked_images = torch.stack(preprocessed_images)
+     image_features = model.encode_image(stacked_images)
+     text_features = model.encode_text(tokenized_text)
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     probs = 100. * image_features @ text_features.T
+     return probs[:, 0].softmax(dim=0)
+
+ def get_indices_of_values_above_threshold(values, threshold):
+     return [i for i, v in enumerate(values) if v > threshold]
+

def pred(search_string, img):
+     # original_image = img.copy()
+
+     # masks = mask_generator.generate(img)
+     # # Cut out all masks
+     # cropped_boxes = []
+
+     # for mask in masks:
+     #     cropped_boxes.append(segment_image(img, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+     # scores = retriev(cropped_boxes, "kiwi")
+     # indices = get_indices_of_values_above_threshold(scores, 0.05)
+
+     # segmentation_masks = []
+
+     # for seg_idx in indices:
+     #     segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+     #     segmentation_masks.append(segmentation_mask_image)
+
+     # overlay_image = Image.new('RGBA', img.size, (0, 0, 0, 0))
+     # overlay_color = (255, 0, 0, 200)
+
+     # draw = ImageDraw.Draw(overlay_image)
+     # for segmentation_mask_image in segmentation_masks:
+     #     draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+     # result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+
    return img

iface = gr.Interface(
    fn=pred,
-     inputs=["text", "image"],
-     outputs="image",
+     inputs=["text", gr.inputs.Image(type="pil")],
+     outputs=gr.outputs.Image(type="pil"),
    examples = [
        ["banana", "./imgs/test_1.jpg"],
        ["orange", "./imgs/test_1.jpg"],
models/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
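The ~2.4 GB SAM ViT-H checkpoint itself lives in Git LFS; only the pointer above is committed. A small optional sanity check after cloning is to compare the file on disk against the oid and size recorded in the pointer, so a missed `git lfs pull` (which leaves only the tiny pointer file at that path) fails loudly before `build_sam` tries to load it. The script below is a hypothetical check of that kind, not part of this commit.

```python
import hashlib
from pathlib import Path

# Values copied from the LFS pointer committed above.
EXPECTED_SHA256 = "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e"
EXPECTED_SIZE = 2564550879

def verify_checkpoint(path="./models/sam_vit_h_4b8939.pth"):
    """Check that the LFS-tracked checkpoint was actually downloaded and intact."""
    p = Path(path)
    size = p.stat().st_size
    if size != EXPECTED_SIZE:
        raise ValueError(f"unexpected size {size}, expected {EXPECTED_SIZE} "
                         "(did `git lfs pull` run?)")
    digest = hashlib.sha256()
    with p.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    if digest.hexdigest() != EXPECTED_SHA256:
        raise ValueError(f"checksum mismatch for {p}")
    return True

if __name__ == "__main__":
    verify_checkpoint()
```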
requirements.txt CHANGED
@@ -1 +1,9 @@
- gradio
+ gradio
+
+ torch
+ opencv-python
+ Pillow
+
+ git+https://github.com/openai/CLIP.git
+
+ git+https://github.com/facebookresearch/segment-anything.git