sezer91 commited on
Commit
bdad001
·
1 Parent(s): ea3ca30

Add application file

Browse files
Files changed (2) hide show
  1. app.py +50 -18
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,20 +1,52 @@
1
- import matplotlib.pyplot as plt
2
- from PIL import Image
 
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- def show_mask(mask, ax, random_color=False):
6
- if random_color:
7
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
8
- else:
9
- color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
10
- h, w = mask.shape[-2:]
11
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
12
- ax.imshow(mask_image)
13
-
14
-
15
- plt.imshow(np.array(raw_image))
16
- ax = plt.gca()
17
- for mask in outputs["masks"]:
18
- show_mask(mask, ax=ax, random_color=True)
19
- plt.axis("off")
20
- plt.show()
 
1
+ import gradio as gr
2
+ import torch
3
+ from segment_anything import sam_model_registry, SamPredictor
4
  import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+
8
+ import os
9
+ import urllib.request
10
+
11
+ MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
12
+ MODEL_PATH = "sam_vit_b.pth"
13
+
14
+ # Eğer model yoksa indir
15
+ if not os.path.exists(MODEL_PATH):
16
+ print("Model indiriliyor...")
17
+ urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
18
+ print("Model indirildi.")
19
+
20
+ # Model yükle
21
+ model_type = "vit_b"
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
25
+ sam.to(device=device)
26
+ predictor = SamPredictor(sam)
27
+
28
+ def segment(image, x, y):
29
+ image = np.array(image)
30
+ predictor.set_image(image)
31
+ input_point = np.array([[x, y]])
32
+ input_label = np.array([1])
33
+ masks, _, _ = predictor.predict(
34
+ point_coords=input_point,
35
+ point_labels=input_label,
36
+ multimask_output=False,
37
+ )
38
+ mask = masks[0]
39
+ masked_image = image.copy()
40
+ masked_image[~mask] = 0
41
+ return Image.fromarray(masked_image)
42
+
43
+ with gr.Blocks() as demo:
44
+ with gr.Row():
45
+ image_input = gr.Image(type="pil")
46
+ x = gr.Number(label="X")
47
+ y = gr.Number(label="Y")
48
+ btn = gr.Button("Segment")
49
+ output = gr.Image()
50
+ btn.click(fn=segment, inputs=[image_input, x, y], outputs=output)
51
 
52
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ opencv-python
4
+ numpy
5
+ Pillow
6
+ git+https://github.com/facebookresearch/segment-anything.git