gaunernst committed on
Commit
97cd144
·
verified ·
1 Parent(s): 06a94fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import gradio as gr
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+ from PIL import Image, ImageOps
8
+
9
# Path to the ONNX keypoint model, expected next to this script.
MODEL_PATH = "model.onnx"
# Side length of the square input image fed to the model (see preprocess()).
IMAGE_SIZE = 480

# Load the model once at import time so every request reuses the same session.
SESSION = ort.InferenceSession(MODEL_PATH)
# Name of the model's first input tensor, used as the feed key in SESSION.run().
INPUT_NAME = SESSION.get_inputs()[0].name
14
+
15
+
16
def preprocess(img: Image.Image) -> np.ndarray:
    """Pad *img* to a square model input and normalize it to [-1, 1].

    The image is padded with `centering=(0, 0)` so the content stays
    anchored at the top-left corner; predict() relies on this when it
    maps predicted keypoints back to the original resolution.

    Returns:
        float32 CHW array of shape (3, IMAGE_SIZE, IMAGE_SIZE).
    """
    # Force 3 channels: grayscale or RGBA uploads would otherwise break
    # the HWC->CHW transpose / the model's expected input shape.
    # (No-op copy for images that are already RGB.)
    rgb = img.convert("RGB")
    resized_img = ImageOps.pad(rgb, (IMAGE_SIZE, IMAGE_SIZE), centering=(0, 0))
    # HWC uint8 -> CHW float32 in [0, 1]
    img_chw = np.array(resized_img).transpose(2, 0, 1).astype(np.float32) / 255
    # mean=0.5, std=0.5 normalization -> values in [-1, 1]
    img_chw = (img_chw - 0.5) / 0.5
    return img_chw
21
+
22
+
23
def distance(p1, p2):
    """Return the Euclidean distance between 2-D points *p1* and *p2*.

    math.hypot is used instead of the manual sqrt-of-sum-of-squares:
    it is clearer and avoids intermediate overflow/underflow.
    """
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])
25
+
26
+
27
# https://stackoverflow.com/a/1222855
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/11/Digital-Signal-Processing.pdf
def get_aspect_ratio_zhang(keypoints: np.ndarray, img_width: int, img_height: int) -> float:
    """Estimate the width/height ratio of the real-world rectangle whose
    projected corners are *keypoints*, following Zhang 2006, section 3.1.

    Args:
        keypoints: (4, 2) pixel coordinates ordered TL, TR, BR, BL
            (the order used by rectify()).
        img_width, img_height: image size; Zhang's derivation assumes
            square pixels and the principal point at the image center.

    Returns:
        Estimated w/h aspect ratio of the original rectangle.

    Raises:
        ValueError (math domain) or numpy floating errors when the
        geometry is degenerate, e.g. two opposite sides parallel in the
        image; rectify() catches this and falls back to a simpler estimate.
    """
    keypoints = keypoints[[3, 2, 0, 1]]  # re-arrange keypoint according to Zhang 2006 Figure 6
    keypoints = np.concatenate([keypoints, np.ones((4, 1))], axis=1)  # convert to homogeneous coordinates

    # equation (11) and (12)
    k2 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[2]) / np.cross(keypoints[1], keypoints[3]).dot(keypoints[2])
    k3 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[1]) / np.cross(keypoints[2], keypoints[3]).dot(keypoints[1])

    # equation (14) and (16)
    n2 = k2 * keypoints[1] - keypoints[0]
    n3 = k3 * keypoints[2] - keypoints[0]

    # equation (21): recover the squared focal length.
    # BUGFIX vs. previous revision (two transcription errors of eq. 21):
    #   - the u0 cross term is (n2[0]*n3[2] + n2[2]*n3[0]); it previously
    #     read (n2[0]*n3[2] + n2[2] + n3[0]) — a '+' where '*' belongs
    #   - BOTH bracketed terms belong inside the -(...)/ (n2[2]*n3[2])
    #     expression; the v0 term was previously added outside the division
    u0 = img_width / 2
    v0 = img_height / 2
    f2 = -(
        (n2[0] * n3[0] - (n2[0] * n3[2] + n2[2] * n3[0]) * u0 + n2[2] * n3[2] * u0 * u0)
        + (n2[1] * n3[1] - (n2[1] * n3[2] + n2[2] * n3[1]) * v0 + n2[2] * n3[2] * v0 * v0)
    ) / (n2[2] * n3[2])
    f = math.sqrt(f2)  # ValueError here if f2 < 0 (degenerate configuration)

    # equation (20): aspect ratio via the image of the absolute conic
    A = np.array([[f, 0, u0], [0, f, v0], [0, 0, 1]])
    A_inv = np.linalg.inv(A)
    mid = A_inv.T.dot(A_inv)
    wh_ratio2 = n2.dot(mid).dot(n2) / n3.dot(mid).dot(n3)

    return math.sqrt(wh_ratio2)
56
+
57
+
58
def rectify(img_np: np.ndarray, keypoints: np.ndarray):
    """Perspective-crop the quadrilateral *keypoints* out of *img_np*.

    Args:
        img_np: HWC image array.
        keypoints: (4, 2) pixel coordinates ordered TL, TR, BR, BL.

    Returns:
        The rectified (fronto-parallel) crop as an HWC array, with a
        1-pixel border around the mapped quad.
    """
    img_height, img_width = img_np.shape[:2]

    # Output height: average length of the two vertical edges, in pixels.
    h1 = distance(keypoints[0], keypoints[3])
    h2 = distance(keypoints[1], keypoints[2])
    h = (h1 + h2) * 0.5

    # Estimating the true aspect ratio may fail when two sides are
    # parallel in the image (division by ~0 / negative f^2 inside).
    try:
        wh_ratio = get_aspect_ratio_zhang(keypoints, img_width, img_height)
        w = h * wh_ratio

    # BUGFIX: was a bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit; keep the same best-effort fallback.
    except Exception:
        print("Failed to estimate aspect ratio from perspective")
        # Fallback: average length of the two horizontal edges.
        w1 = distance(keypoints[0], keypoints[1])
        w2 = distance(keypoints[3], keypoints[2])
        w = (w1 + w2) * 0.5

    # Map the detected quad to an axis-aligned rectangle offset by 1px.
    target_kpts = np.array([[1, 1], [w + 1, 1], [w + 1, h + 1], [1, h + 1]], dtype=np.float32)
    # cv2.getPerspectiveTransform requires float32 point arrays; the model
    # output may arrive as float64 after rescaling.
    src_kpts = np.asarray(keypoints, dtype=np.float32)
    transform = cv2.getPerspectiveTransform(src_kpts, target_kpts)
    cropped = cv2.warpPerspective(img_np, transform, (round(w) + 2, round(h) + 2), flags=cv2.INTER_CUBIC)
    return cropped
80
+
81
+
82
def predict(img: Image.Image):
    """Detect the document quad in *img* and rectify it.

    Returns:
        A pair (cropped, annotated): the perspective-rectified crop (or
        None when any corner is below the confidence threshold) and a
        copy of the input with the detected quad outlined in green.
    """
    # The model consumes a single-image batch: (1, 3, IMAGE_SIZE, IMAGE_SIZE).
    batch = preprocess(img)[None]
    keypoints = SESSION.run(None, {INPUT_NAME: batch})[0][0]

    # Keypoints are predicted in padded-square coordinates; rescale to the
    # original resolution (preprocess pads relative to the longer side).
    corners_xy = keypoints[:, :2] * max(img.size) / IMAGE_SIZE

    annotated = np.array(img)
    cv2.polylines(annotated, [corners_xy.astype(int)], True, (0, 255, 0), thickness=5, lineType=cv2.LINE_AA)

    # Rectify only when all four corners clear the confidence threshold;
    # hand rectify() a fresh copy so the drawn outline isn't in the crop.
    confident = (keypoints[:, 2] >= 0.25).all()
    cropped = rectify(np.array(img), corners_xy) if confident else None

    return cropped, annotated
97
+
98
+
99
# Expose the model as a minimal Gradio demo: one image input, two image
# outputs (the rectified crop and the annotated original).
demo = gr.Interface(
    predict,
    inputs=[gr.Image(type="pil")],
    outputs=["image", "image"],
)
# Bind to all interfaces so the app is reachable from outside the container.
demo.launch(server_name="0.0.0.0")