aiknowyou-nic commited on
Commit
5875f7d
·
1 Parent(s): 3d0afce
Files changed (2) hide show
  1. app.py +22 -20
  2. packages.txt +1 -0
app.py CHANGED
@@ -5,6 +5,7 @@ import clip
5
  import torch
6
  import numpy as np
7
 
 
8
  mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="./models/sam_vit_h_4b8939.pth"))
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -46,39 +47,40 @@ def get_indices_of_values_above_threshold(values, threshold):
46
 
47
 
48
  def pred(search_string, img):
49
- # original_image = img.copy()
50
 
51
- # masks = mask_generator.generate(img)
52
- # # Cut out all masks
53
- # cropped_boxes = []
 
54
 
55
- # for mask in masks:
56
- # cropped_boxes.append(segment_image(img, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
57
 
58
- # scores = retriev(cropped_boxes, "kiwi")
59
- # indices = get_indices_of_values_above_threshold(scores, 0.05)
60
 
61
- # segmentation_masks = []
62
 
63
- # for seg_idx in indices:
64
- # segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
65
- # segmentation_masks.append(segmentation_mask_image)
66
 
67
- # overlay_image = Image.new('RGBA', img.size, (0, 0, 0, 0))
68
- # overlay_color = (255, 0, 0, 200)
69
 
70
- # draw = ImageDraw.Draw(overlay_image)
71
- # for segmentation_mask_image in segmentation_masks:
72
- # draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
73
 
74
- # result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
75
 
76
- return img
77
 
78
  iface = gr.Interface(
79
  fn=pred,
80
  inputs=["text", gr.inputs.Image(type="pil")],
81
- outputs=gr.outputs.Image(type="pil"),
82
  examples = [
83
  ["banana", "./imgs/test_1.jpg"],
84
  ["orange", "./imgs/test_1.jpg"],
 
5
  import torch
6
  import numpy as np
7
 
8
+ # preso spunto da https://github.com/maxi-w/CLIP-SAM/blob/main/main.ipynb
9
  mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="./models/sam_vit_h_4b8939.pth"))
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
47
 
48
 
49
  def pred(search_string, img):
50
+ original_image = img.copy()
51
 
52
+ open_cv_image = np.array(img)[:, :, ::-1]
53
+ masks = mask_generator.generate(open_cv_image)
54
+ # Cut out all masks
55
+ cropped_boxes = []
56
 
57
+ for mask in masks:
58
+ cropped_boxes.append(segment_image(img, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
59
 
60
+ scores = retriev(cropped_boxes, search_string)
61
+ indices = get_indices_of_values_above_threshold(scores, 0.05)
62
 
63
+ segmentation_masks = []
64
 
65
+ for seg_idx in indices:
66
+ segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
67
+ segmentation_masks.append(segmentation_mask_image)
68
 
69
+ overlay_image = Image.new('RGBA', img.size, (0, 0, 0, 0))
70
+ overlay_color = (255, 0, 0, 200)
71
 
72
+ draw = ImageDraw.Draw(overlay_image)
73
+ for segmentation_mask_image in segmentation_masks:
74
+ draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
75
 
76
+ result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
77
 
78
+ return result_image, overlay_image
79
 
80
  iface = gr.Interface(
81
  fn=pred,
82
  inputs=["text", gr.inputs.Image(type="pil")],
83
+ outputs=[gr.outputs.Image(type="pil"), gr.outputs.Image(type="pil")],
84
  examples = [
85
  ["banana", "./imgs/test_1.jpg"],
86
  ["orange", "./imgs/test_1.jpg"],
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv