Create app.py
app.py
ADDED
import math

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageOps
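
# Assumed runtime dependencies (a guess at this Space's requirements.txt, not
# confirmed by the diff): gradio, numpy, opencv-python-headless, onnxruntime,
# Pillow.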
MODEL_PATH = "model.onnx"
IMAGE_SIZE = 480

# load the ONNX corner-detection model once at import time
SESSION = ort.InferenceSession(MODEL_PATH)
INPUT_NAME = SESSION.get_inputs()[0].name


def preprocess(img: Image.Image) -> np.ndarray:
    # letterbox to a square canvas; centering=(0, 0) anchors the image at the
    # top-left, so all padding goes to the right/bottom
    resized_img = ImageOps.pad(img, (IMAGE_SIZE, IMAGE_SIZE), centering=(0, 0))
    # HWC uint8 -> CHW float32 in [0, 1], then normalize to [-1, 1]
    img_chw = np.array(resized_img).transpose(2, 0, 1).astype(np.float32) / 255
    img_chw = (img_chw - 0.5) / 0.5
    return img_chw
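
# Because the padding is right/bottom only, a predicted point maps back to the
# original image with the single scale factor max(img.size) / IMAGE_SIZE used
# in predict() below. Worked example (hypothetical input size): for a
# 1600x1200 photo the scale is 1600 / 480, so a model output at (120, 90)
# lands at (400, 300) in original pixels.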


def distance(p1, p2):
    # Euclidean distance between two 2-D points
    return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5


# Recover the true width/height ratio of the photographed rectangle from its
# four projected corners, following Zhang 2006 (whiteboard scanning):
# https://stackoverflow.com/a/1222855
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/11/Digital-Signal-Processing.pdf
def get_aspect_ratio_zhang(keypoints: np.ndarray, img_width: int, img_height: int):
    keypoints = keypoints[[3, 2, 0, 1]]  # re-arrange keypoints according to Zhang 2006, Figure 6
    keypoints = np.concatenate([keypoints, np.ones((4, 1))], axis=1)  # convert to homogeneous coordinates

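    # (In homogeneous coordinates, np.cross(p, q) is the line through points p
    # and q, and line.dot(point) == 0 tests incidence; the ratios below are
    # built entirely from these two identities.)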
    # equations (11) and (12)
    k2 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[2]) / np.cross(keypoints[1], keypoints[3]).dot(keypoints[2])
    k3 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[1]) / np.cross(keypoints[2], keypoints[3]).dot(keypoints[1])

    # equations (14) and (16); n2/n3 end in a zero last component when the
    # projection is affine (opposite sides parallel in the image)
    n2 = k2 * keypoints[1] - keypoints[0]
    n3 = k3 * keypoints[2] - keypoints[0]

    # equation (21): focal length, assuming the principal point sits at the
    # image center; undefined when n2[2] * n3[2] == 0 (fronto-parallel view),
    # in which case the result goes non-finite or sqrt raises, and the caller
    # falls back to a simpler estimate
    u0 = img_width / 2
    v0 = img_height / 2
    f2 = -(
        n2[0] * n3[0] - (n2[0] * n3[2] + n2[2] * n3[0]) * u0 + n2[2] * n3[2] * u0 * u0
        + n2[1] * n3[1] - (n2[1] * n3[2] + n2[2] * n3[1]) * v0 + n2[2] * n3[2] * v0 * v0
    ) / (n2[2] * n3[2])
    f = math.sqrt(f2)

    # equation (20): measure the side vectors through the camera matrix A
    A = np.array([[f, 0, u0], [0, f, v0], [0, 0, 1]])
    A_inv = np.linalg.inv(A)
    mid = A_inv.T.dot(A_inv)
    wh_ratio2 = n2.dot(mid).dot(n2) / n3.dot(mid).dot(n3)

    return math.sqrt(wh_ratio2)
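
# The chain above, in symbols: k2 and k3 from equations (11)-(12) measure the
# perspective foreshortening of the corner pairs; n2 = k2*m1 - m0 and
# n3 = k3*m2 - m0 are the rectangle's side directions; and the squared
# width/height ratio is (n2^T A^-T A^-1 n2) / (n3^T A^-T A^-1 n3).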


def rectify(img_np: np.ndarray, keypoints: np.ndarray):
    img_height, img_width = img_np.shape[:2]

    # height: average the two vertical sides of the detected quad
    h1 = distance(keypoints[0], keypoints[3])
    h2 = distance(keypoints[1], keypoints[2])
    h = (h1 + h2) * 0.5

    # this may fail if two lines are parallel: the ratio comes back non-finite
    # (NumPy divides by zero without raising) or f2 goes negative (ValueError)
    try:
        wh_ratio = get_aspect_ratio_zhang(keypoints, img_width, img_height)
        if not math.isfinite(wh_ratio):
            raise ValueError("non-finite aspect ratio")
        w = h * wh_ratio
    except Exception:
        print("Failed to estimate aspect ratio from perspective")
        # fall back to the in-image width; this ignores foreshortening but
        # keeps the output usable
        w1 = distance(keypoints[0], keypoints[1])
        w2 = distance(keypoints[3], keypoints[2])
        w = (w1 + w2) * 0.5

    # map the detected quad onto an axis-aligned w x h rectangle (1 px margin);
    # getPerspectiveTransform expects float32 point arrays
    target_kpts = np.array([[1, 1], [w + 1, 1], [w + 1, h + 1], [1, h + 1]], dtype=np.float32)
    transform = cv2.getPerspectiveTransform(keypoints.astype(np.float32), target_kpts)
    cropped = cv2.warpPerspective(img_np, transform, (round(w) + 2, round(h) + 2), flags=cv2.INTER_CUBIC)
    return cropped
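
# Worked size example (hypothetical numbers): with h = 600 and an estimated
# wh_ratio of 0.75, rectify() maps the quad onto a 450x600 rectangle and the
# output canvas is (round(450) + 2) x (round(600) + 2) = 452 x 602 pixels,
# matching the 1 px margin baked into target_kpts.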


def predict(img: Image.Image):
    # gradio may hand over RGBA or grayscale images; normalize to RGB
    img = img.convert("RGB")
    img_chw = preprocess(img)

    # single forward pass; the indexing assumes an output of shape (1, 4, 3):
    # four corners as (x, y, confidence) rows
    pred_kpts = SESSION.run(None, {INPUT_NAME: img_chw[None]})[0][0]
    # map keypoints from the padded 480x480 space back to original pixels
    kpts_xy = pred_kpts[:, :2] * max(img.size) / IMAGE_SIZE

    # draw the detected quad on a copy of the input for visual feedback
    img_np = np.array(img)
    cv2.polylines(img_np, [kpts_xy.astype(np.int32)], True, (0, 255, 0), thickness=5, lineType=cv2.LINE_AA)

    # rectify only when all four corners are detected confidently
    if (pred_kpts[:, 2] >= 0.25).all():
        cropped = rectify(np.array(img), kpts_xy)
    else:
        cropped = None

    return cropped, img_np


gr.Interface(
    predict,
    inputs=[gr.Image(type="pil")],
    outputs=["image", "image"],  # rectified crop (or None) and the annotated input
).launch(server_name="0.0.0.0")
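
# Quick local smoke test without the UI (hypothetical file name):
#
#   from PIL import Image
#   crop, overlay = predict(Image.open("receipt.jpg"))
#
# crop is None whenever any corner confidence falls below 0.25; overlay always
# shows the predicted quad.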