import cv2
import numpy as np
import math
import torch
import random
from torch.utils.data import DataLoader
from torchvision.transforms import Resize
torch.manual_seed(12345)
random.seed(12345)
np.random.seed(12345)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def find_contours(img, color):
    low = color - 10
    high = color + 10
    mask = cv2.inRange(img, low, high)
    contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    print(f"Total Contours: {len(contours)}")
    # Keep top-level contours whose area exceeds their perimeter, a heuristic
    # that discards hairline and degenerate blobs.
    nonempty_contours = []
    for i in range(len(contours)):
        if hierarchy[0, i, 3] == -1 and cv2.contourArea(contours[i]) > cv2.arcLength(contours[i], True):
            nonempty_contours.append(contours[i])
    print(f"Nonempty Contours: {len(nonempty_contours)}")
    contour_plot = img.copy()
    contour_plot = cv2.drawContours(contour_plot, nonempty_contours, -1, (0, 255, 0), -1)
    sorted_contours = sorted(nonempty_contours, key=cv2.contourArea, reverse=True)
    for i, c in enumerate(sorted_contours):
        M = cv2.moments(c)
        cx = int(M['m10'] / M['m00'])
        cy = int(M['m01'] / M['m00'])
        cv2.putText(contour_plot, text=str(i), org=(cx, cy),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255, 255, 255),
                    thickness=1, lineType=cv2.LINE_AA)
    N = len(sorted_contours)
    H, W, C = img.shape
    boxes_array_xywh = [cv2.boundingRect(cnt) for cnt in sorted_contours]
    boxes_array_corners = [[x, y, x + w, y + h] for x, y, w, h in boxes_array_xywh]
    boxes = torch.tensor(boxes_array_corners)
    labels = torch.ones(N)
    masks = np.zeros([N, H, W])
    for idx, cnt in enumerate(sorted_contours):
        cv2.drawContours(masks[idx, :, :], [cnt], 0, 255, -1)
    masks = torch.tensor(masks / 255.0)
    # for box in boxes:
    #     cv2.rectangle(contour_plot, (box[0].item(), box[1].item()), (box[2].item(), box[3].item()), (255, 0, 0), 2)
    return contour_plot, (boxes, masks)
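# find_contours returns the annotated plot plus, per kept contour, a bounding
# box in an (N, 4) tensor of (x1, y1, x2, y2) corners and a filled mask in an
# (N, H, W) tensor scaled to [0, 1].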
def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
    # Slide a filter_size x filter_size window across the image and return one
    # tensor per window position (the classifier scores each window).
    full_image_tensor = torch.tensor(blank_image).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(0)
    num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size) / filter_stride) + 1
    num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size) / filter_stride) + 1
    windows = torch.nn.functional.unfold(
        full_image_tensor, (filter_size, filter_size), stride=filter_stride
    ).reshape(
        [1, 3, filter_size, filter_size, num_windows_h * num_windows_w]
    ).permute([0, 4, 1, 2, 3]).squeeze()
    return [windows[idx] for idx in range(len(windows))]
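# Sanity check on the defaults: for a 1500x1500 RGB input, windows per side =
# floor((1500 - 50) / 2) + 1 = 726, so get_dataset_x returns 726 * 726 =
# 527,076 window tensors, each of shape (3, 50, 50).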
def get_dataset(labeled_image, blank_image, color, filter_size=50, filter_stride=2, label_size=5):
    contour_plot, (blue_boxes, blue_masks) = find_contours(labeled_image, color)
    mask = torch.sum(blue_masks, 0)
    label_dim = int((labeled_image.shape[0] - filter_size) / filter_stride + 1)
    labels = torch.zeros(label_dim, label_dim)
    mask_labels = torch.zeros(label_dim, label_dim, filter_size, filter_size)
    for lx in range(label_dim):
        for ly in range(label_dim):
            mask_labels[lx, ly, :, :] = mask[
                lx * filter_stride: lx * filter_stride + filter_size,
                ly * filter_stride: ly * filter_stride + filter_size
            ]
    print(labels.shape)
    # Mark a square of window positions (2 * label_size per side) around each
    # box center as positive.
    clamp = lambda n, minn, maxn: max(min(maxn, n), minn)
    for box in blue_boxes:
        x = int((box[0] + box[2]) / 2)
        y = int((box[1] + box[3]) / 2)
        window_x = int((x - int(filter_size / 2)) / filter_stride)
        window_y = int((y - int(filter_size / 2)) / filter_stride)
        labels[
            clamp(window_y - label_size, 0, labels.shape[0] - 1):clamp(window_y + label_size, 0, labels.shape[0] - 1),
            clamp(window_x - label_size, 0, labels.shape[1] - 1):clamp(window_x + label_size, 0, labels.shape[1] - 1),
        ] = 1
    # labels already holds 0/1 values; dividing by labels.max() was a no-op
    # that turned into NaNs whenever no boxes were found.
    positive_labels = labels.flatten()
    negative_labels = 1 - positive_labels
    pos_mask_labels = torch.flatten(mask_labels, end_dim=1)
    neg_mask_labels = 1 - pos_mask_labels
    mask_labels = torch.stack([pos_mask_labels, neg_mask_labels], dim=1)
    # torch.tensor(list(zip(...))) fails on tensor elements; stack the two
    # label columns instead.
    dataset_labels = torch.stack([positive_labels, negative_labels], dim=1)
    dataset = list(zip(
        get_dataset_x(blank_image, filter_size=filter_size, filter_stride=filter_stride),
        dataset_labels,
        mask_labels
    ))
    return dataset, (labels, mask_labels)
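# Example usage (hypothetical file names; the BGR color matches the one used
# for the ground-truth contours further down):
# labeled = cv2.imread("labeled.png")
# blank = cv2.imread("blank.png")
# train_set, (labels, mask_labels) = get_dataset(labeled, blank, np.array([59, 76, 160]))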
from torchvision.models.resnet import resnet50
from torchvision.models.resnet import ResNet50_Weights
print("Loading resnet...")
# ImageNet-pretrained ResNet-50 with the classifier head swapped for two
# classes (barnacle / background), then restored from the fine-tuned checkpoint.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
hidden_state_size = model.fc.in_features
model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
model.load_state_dict(torch.load("model_best_epoch_4_59.62.pth", map_location=device))
model.to(device)
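# Optional sanity check: a single 50x50 window maps to two logits.
# with torch.no_grad():
#     print(model(torch.zeros(1, 3, 50, 50, device=device)).shape)  # torch.Size([1, 2])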
import gradio as gr
def count_barnacles(raw_input_img, labeled_input_img, progress=gr.Progress()):
    progress(0, desc="Finding bounding wire")
    # Crop to the region enclosed by the bounding wire: mask the wire color in
    # HSV, invert, and take the largest contour that contains the image center.
    h, w = raw_input_img.shape[:2]
    imghsv = cv2.cvtColor(raw_input_img, cv2.COLOR_RGB2HSV)
    hsvblur = cv2.GaussianBlur(imghsv, (9, 9), 0)
    lower = np.array([70, 20, 20])
    upper = np.array([130, 255, 255])
    color_mask = cv2.inRange(hsvblur, lower, upper)
    invert = cv2.bitwise_not(color_mask)
    contours, _ = cv2.findContours(invert, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    max_contour = contours[0]
    largest_area = 0
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > largest_area and cv2.pointPolygonTest(contour, (w / 2, h / 2), False) == 1:
            largest_area = area
            max_contour = contour
    x, y, w, h = cv2.boundingRect(max_contour)
    progress(0, desc="Resizing Image")
    # NumPy indexes rows (y) first, then columns (x).
    cropped_img = raw_input_img[y:y+h, x:x+w]
    # HWC -> CHW for torchvision's Resize, then back to a contiguous HWC array
    # for windowing. Applying the Resize also guarantees a square window grid,
    # which the sqrt-based reshape below relies on.
    cropped_image_tensor = torch.tensor(cropped_img).permute(2, 0, 1)
    resize = Resize((1500, 1500))
    input_img = resize(cropped_image_tensor).permute(1, 2, 0).contiguous().numpy()
    blank_img_copy = input_img.copy()
    progress(0, desc="Generating Windows")
    test_dataset = get_dataset_x(input_img)
    test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
    model.eval()
    predicted_labels_list = []
    for data in progress.tqdm(test_dataloader):
        with torch.no_grad():
            data = data.to(device)
            predicted_labels_list.append(model(data))
    predicted_labels = torch.cat(predicted_labels_list)
    # Fold the flat predictions back into a side x side label image. Use a new
    # name here: reusing x would clobber the crop offset still needed below.
    side = int(math.sqrt(predicted_labels.shape[0]))
    predicted_labels = predicted_labels.reshape([side, side, 2]).detach()
    label_img = predicted_labels[:, :, 0].cpu().numpy()
    label_img -= label_img.min()
    label_img /= label_img.max()
    label_img = (label_img * 255).astype(np.uint8)
    mask = np.array(label_img > 180, np.uint8)
    contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Ground-truth contours from the labeled image, cropped to the same region
    # (find_contours takes only the image and the target color).
    _, (gt_boxes, _) = find_contours(labeled_input_img[y:y+h, x:x+w], np.array([59, 76, 160]))
    def extract_contour_center(cnt):
        M = cv2.moments(cnt)
        cx = int(M['m10'] / M['m00'])
        cy = int(M['m01'] / M['m00'])
        return cx, cy

    filter_width = 50
    filter_stride = 2

    def rev_window_transform(point):
        # Map a window-grid coordinate back to the pixel at the window center.
        wx, wy = point
        x = int(filter_width / 2) + wx * filter_stride
        y = int(filter_width / 2) + wy * filter_stride
        return x, y
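    # For example, window (0, 0) maps back to pixel (25, 25), and window
    # (10, 3) maps to (25 + 10 * 2, 25 + 3 * 2) = (45, 31).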
    nonempty_contours = filter(lambda cnt: cv2.contourArea(cnt) != 0, contours)
    windows = map(extract_contour_center, nonempty_contours)
    points = list(map(rev_window_transform, windows))
    for x, y in points:
        blank_img_copy = cv2.circle(blank_img_copy, (x, y), radius=4, color=(255, 0, 0), thickness=-1)
    print(f"pointlist: {len(points)}")
    # The interface declares four outputs, so return the ground-truth count and
    # a metric alongside the annotated image and the prediction. The metric is
    # an assumption (relative count error); the original never computed one.
    actual = len(gt_boxes)
    custom_metric = abs(len(points) - actual) / max(actual, 1)
    return blank_img_copy, len(points), actual, custom_metric
demo = gr.Interface(count_barnacles,
                    inputs=[
                        gr.Image(shape=(500, 500), type="numpy", label="Input Image"),
                        gr.Image(shape=(500, 500), type="numpy", label="Masked Input Image")
                    ],
                    outputs=[
                        gr.Image(shape=(500, 500), type="numpy", label="Annotated Image"),
                        gr.Number(label="Predicted Number of Barnacles"),
                        gr.Number(label="Actual Number of Barnacles"),
                        gr.Number(label="Custom Metric")
                    ])
                    # examples="examples")
demo.queue(concurrency_count=10).launch()