Henry Scheible commited on
Commit
15ec046
·
1 Parent(s): 26f78d3

add examples

Browse files
Files changed (2) hide show
  1. app.py +6 -4
  2. examples/new_blank_image.png +0 -0
app.py CHANGED
@@ -30,7 +30,8 @@ print("Loading resnet...")
30
  model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
31
  hidden_state_size = model.fc.in_features
32
  model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
33
- model.load_state_dict(torch.load("model_best_epoch_4_59.62.pth"))
 
34
  model.to("cuda")
35
 
36
  import gradio as gr
@@ -49,7 +50,7 @@ def count_barnacles(input_img, progress=gr.Progress()):
49
  predicted_labels = torch.cat(predicted_labels_list)
50
  x = int(math.sqrt(predicted_labels.shape[0]))
51
  predicted_labels = predicted_labels.reshape([x, x, 2]).detach()
52
- label_img = predicted_labels[:, :, :1].cpu().numpy()
53
  label_img -= label_img.min()
54
  label_img /= label_img.max()
55
  label_img = (label_img * 255).astype(np.uint8)
@@ -78,9 +79,10 @@ def count_barnacles(input_img, progress=gr.Progress()):
78
  blank_img_copy = input_img.copy()
79
  for x, y in points:
80
  blank_img_copy = cv2.circle(blank_img_copy, (x, y), radius=4, color=(255, 0, 0), thickness=-1)
81
- return blank_img_copy, len(list(points))
82
 
83
 
84
  demo = gr.Interface(count_barnacles, gr.Image(shape=(500, 500), type="numpy"),
85
- outputs=[gr.Image(type="numpy"), "number"])
 
86
  demo.queue(concurrency_count=10).launch()
 
30
  model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
31
  hidden_state_size = model.fc.in_features
32
  model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
33
+ model.to("cuda")
34
+ model.load_state_dict(torch.load("model_best_epoch_4_59.62.pth", map_location=torch.device("cuda")))
35
  model.to("cuda")
36
 
37
  import gradio as gr
 
50
  predicted_labels = torch.cat(predicted_labels_list)
51
  x = int(math.sqrt(predicted_labels.shape[0]))
52
  predicted_labels = predicted_labels.reshape([x, x, 2]).detach()
53
+ label_img = predicted_labels[:, :, :1].cuda().numpy()
54
  label_img -= label_img.min()
55
  label_img /= label_img.max()
56
  label_img = (label_img * 255).astype(np.uint8)
 
79
  blank_img_copy = input_img.copy()
80
  for x, y in points:
81
  blank_img_copy = cv2.circle(blank_img_copy, (x, y), radius=4, color=(255, 0, 0), thickness=-1)
82
+ return blank_img_copy, int(len(list(points)))
83
 
84
 
85
  demo = gr.Interface(count_barnacles, gr.Image(shape=(500, 500), type="numpy"),
86
+ outputs=["image", "number"],
87
+ examples="examples")
88
  demo.queue(concurrency_count=10).launch()
examples/new_blank_image.png ADDED