Henry Scheible commited on
Commit
75541cb
·
1 Parent(s): 315bd25

resize to (1500, 1500)

Browse files
Files changed (2) hide show
  1. .idea/.gitignore +8 -0
  2. app.py +111 -5
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
app.py CHANGED
@@ -13,8 +13,56 @@ np.random.seed(12345)
13
 
14
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
17
- full_image_tensor = blank_image.type(torch.FloatTensor).unsqueeze(0)
18
  num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size) / filter_stride) + 1
19
  num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size) / filter_stride) + 1
20
  windows = torch.nn.functional.unfold(full_image_tensor, (filter_size, filter_size), stride=filter_stride).reshape(
@@ -25,6 +73,51 @@ def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
25
  return dataset
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  from torchvision.models.resnet import resnet50
29
  from torchvision.models.resnet import ResNet50_Weights
30
 
@@ -39,7 +132,7 @@ model.to(device)
39
  import gradio as gr
40
 
41
 
42
- def count_barnacles(raw_input_img, progress=gr.Progress()):
43
  progress(0, desc="Finding bounding wire")
44
 
45
  # crop image
@@ -91,7 +184,11 @@ def count_barnacles(raw_input_img, progress=gr.Progress()):
91
  label_img /= label_img.max()
92
  label_img = (label_img * 255).astype(np.uint8)
93
  mask = np.array(label_img > 180, np.uint8)
94
- contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
 
95
 
96
  def extract_contour_center(cnt):
97
  M = cv2.moments(cnt)
@@ -117,7 +214,16 @@ def count_barnacles(raw_input_img, progress=gr.Progress()):
117
  return blank_img_copy, len(points)
118
 
119
 
120
- demo = gr.Interface(count_barnacles, gr.Image(shape=(500, 500), type="numpy"),
121
- outputs=[gr.Image(shape=(500, 500), type="numpy", label="Annotated Image"), gr.Number(label="Number of Barnacles")])
 
 
 
 
 
 
 
 
 
122
  # examples="examples")
123
  demo.queue(concurrency_count=10).launch()
 
13
 
14
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15
 
16
+ def find_contours(img, color):
17
+ low = color - 10
18
+ high = color + 10
19
+
20
+ mask = cv2.inRange(img, low, high)
21
+ contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
22
+
23
+ print(f"Total Contours: {len(contours)}")
24
+ nonempty_contours = list()
25
+ for i in range(len(contours)):
26
+ if hierarchy[0,i,3] == -1 and cv2.contourArea(contours[i]) > cv2.arcLength(contours[i], True):
27
+ nonempty_contours += [contours[i]]
28
+ print(f"Nonempty Contours: {len(nonempty_contours)}")
29
+ contour_plot = img.copy()
30
+ contour_plot = cv2.drawContours(contour_plot, nonempty_contours, -1, (0,255,0), -1)
31
+
32
+ sorted_contours = sorted(nonempty_contours, key=cv2.contourArea, reverse= True)
33
+
34
+ bounding_rects = [cv2.boundingRect(cnt) for cnt in contours]
35
+
36
+ for (i,c) in enumerate(sorted_contours):
37
+ M= cv2.moments(c)
38
+ cx= int(M['m10']/M['m00'])
39
+ cy= int(M['m01']/M['m00'])
40
+ cv2.putText(contour_plot, text= str(i), org=(cx,cy),
41
+ fontFace= cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255,255,255),
42
+ thickness=1, lineType=cv2.LINE_AA)
43
+
44
+ N = len(sorted_contours)
45
+ H, W, C = img.shape
46
+ boxes_array_xywh = [cv2.boundingRect(cnt) for cnt in sorted_contours]
47
+ boxes_array_corners = [[x, y, x+w, y+h] for x, y, w, h in boxes_array_xywh]
48
+ boxes = torch.tensor(boxes_array_corners)
49
+
50
+ labels = torch.ones(N)
51
+ masks = np.zeros([N, H, W])
52
+ for idx in range(len(sorted_contours)):
53
+ cnt = sorted_contours[idx]
54
+ cv2.drawContours(masks[idx,:,:], [cnt], 0, (255), -1)
55
+ masks = masks / 255.0
56
+ masks = torch.tensor(masks)
57
+
58
+ # for box in boxes:
59
+ # cv2.rectangle(contour_plot, (box[0].item(), box[1].item()), (box[2].item(), box[3].item()), (255,0,0), 2)
60
+
61
+ return contour_plot, (boxes, masks)
62
+
63
+
64
  def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
65
+ full_image_tensor = torch.tensor(blank_image).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(0)
66
  num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size) / filter_stride) + 1
67
  num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size) / filter_stride) + 1
68
  windows = torch.nn.functional.unfold(full_image_tensor, (filter_size, filter_size), stride=filter_stride).reshape(
 
73
  return dataset
74
 
75
 
76
+ def get_dataset(labeled_image, blank_image, color, filter_size=50, filter_stride=2, label_size=5):
77
+ contour_plot, (blue_boxes, blue_masks) = find_contours(labeled_image, color)
78
+
79
+ mask = torch.sum(blue_masks, 0)
80
+
81
+ label_dim = int((labeled_image.shape[0] - filter_size) / filter_stride + 1)
82
+ labels = torch.zeros(label_dim, label_dim)
83
+ mask_labels = torch.zeros(label_dim, label_dim, filter_size, filter_size)
84
+
85
+ for lx in range(label_dim):
86
+ for ly in range(label_dim):
87
+ mask_labels[lx, ly, :, :] = mask[
88
+ lx * filter_stride: lx * filter_stride + filter_size,
89
+ ly * filter_stride: ly * filter_stride + filter_size
90
+ ]
91
+
92
+ print(labels.shape)
93
+ for box in blue_boxes:
94
+ x = int((box[0] + box[2]) / 2)
95
+ y = int((box[1] + box[3]) / 2)
96
+
97
+ window_x = int((x - int(filter_size / 2)) / filter_stride)
98
+ window_y = int((y - int(filter_size / 2)) / filter_stride)
99
+
100
+ clamp = lambda n, minn, maxn: max(min(maxn, n), minn)
101
+
102
+ labels[
103
+ clamp(window_y - label_size, 0, labels.shape[0] - 1):clamp(window_y + label_size, 0, labels.shape[0] - 1),
104
+ clamp(window_x - label_size, 0, labels.shape[0] - 1):clamp(window_x + label_size, 0, labels.shape[0] - 1),
105
+ ] = 1
106
+
107
+ positive_labels = labels.flatten() / labels.max()
108
+ negative_labels = 1 - positive_labels
109
+ pos_mask_labels = torch.flatten(mask_labels, end_dim=1)
110
+ neg_mask_labels = 1 - pos_mask_labels
111
+ mask_labels = torch.stack([pos_mask_labels, neg_mask_labels], dim=1)
112
+ dataset_labels = torch.tensor(list(zip(positive_labels, negative_labels)))
113
+ dataset = list(zip(
114
+ get_dataset_x(blank_image, filter_size=filter_size, filter_stride=filter_stride),
115
+ dataset_labels,
116
+ mask_labels
117
+ ))
118
+ return dataset, (labels, mask_labels)
119
+
120
+
121
  from torchvision.models.resnet import resnet50
122
  from torchvision.models.resnet import ResNet50_Weights
123
 
 
132
  import gradio as gr
133
 
134
 
135
+ def count_barnacles(raw_input_img, labeled_input_img, progress=gr.Progress()):
136
  progress(0, desc="Finding bounding wire")
137
 
138
  # crop image
 
184
  label_img /= label_img.max()
185
  label_img = (label_img * 255).astype(np.uint8)
186
  mask = np.array(label_img > 180, np.uint8)
187
+ contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\
188
+
189
+ gt_contours = find_contours(labeled_input_img[x:x+w, y:y+h], cropped_img, np.array([59, 76, 160]))
190
+
191
+
192
 
193
  def extract_contour_center(cnt):
194
  M = cv2.moments(cnt)
 
214
  return blank_img_copy, len(points)
215
 
216
 
217
+ demo = gr.Interface(count_barnacles,
218
+ inputs=[
219
+ gr.Image(shape=(500, 500), type="numpy", label="Input Image"),
220
+ gr.Image(shape=(500, 500), type="numpy", label="Masked Input Image")
221
+ ],
222
+ outputs=[
223
+ gr.Image(shape=(500, 500), type="numpy", label="Annotated Image"),
224
+ gr.Number(label="Predicted Number of Barnacles"),
225
+ gr.Number(label="Actual Number of Barnacles"),
226
+ gr.Number(label="Custom Metric")
227
+ ])
228
  # examples="examples")
229
  demo.queue(concurrency_count=10).launch()