Spaces:
Sleeping
Sleeping
HERIUN
commited on
Commit
ยท
47a9de9
1
Parent(s):
6a0c0c0
demo.py
Browse files
demo.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
from rect_main import docscanner_rec, load_docscanner_model
|
8 |
+
from data_utils.image_utils import unwarp, mask2point, get_corner, _rotate_90_degrees
|
9 |
+
from config import Config
|
10 |
+
|
11 |
+
|
12 |
+
config = Config()
|
13 |
+
cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
+
|
15 |
+
|
16 |
+
docscanner = load_docscanner_model(
|
17 |
+
cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
|
18 |
+
)
|
19 |
+
|
20 |
+
# ์ขํ๋ฅผ ์ด๊ธฐํํ๋ ํจ์
|
21 |
+
def reset_points(image, state):
|
22 |
+
state = []
|
23 |
+
return image, state
|
24 |
+
|
25 |
+
def cutting_image(image, state):
|
26 |
+
min_x = min(point[0] for point in state)
|
27 |
+
max_x = max(point[0] for point in state)
|
28 |
+
min_y = min(point[1] for point in state)
|
29 |
+
max_y = max(point[1] for point in state)
|
30 |
+
|
31 |
+
cutted_image = image[min_y:max_y, min_x:max_x]
|
32 |
+
state = []
|
33 |
+
return cutted_image, cutted_image, state
|
34 |
+
|
35 |
+
def rotate_image(image):
|
36 |
+
rotated_image = _rotate_90_degrees(image)
|
37 |
+
state = []
|
38 |
+
return rotated_image, state
|
39 |
+
|
40 |
+
def reset_image(image, state):
|
41 |
+
out_image, msk_np = docscanner_rec(image, docscanner)
|
42 |
+
state = list(get_corner(mask2point(mask=msk_np)))
|
43 |
+
|
44 |
+
img = Image.fromarray(image)
|
45 |
+
area = image.shape[0]*image.shape[1]
|
46 |
+
radius=max(5, round(area**0.5 / 120))
|
47 |
+
# ์ขํ๊ฐ ์ต์ 3๊ฐ ์ด์์ผ ๋๋ง ํด๋ฆฌ๊ณค ๊ทธ๋ฆฌ๊ธฐ
|
48 |
+
draw = ImageDraw.Draw(img)
|
49 |
+
for pt in state:
|
50 |
+
left_up_point = (pt[0] - radius, pt[1] - radius)
|
51 |
+
right_down_point = (pt[0] + radius, pt[1] + radius)
|
52 |
+
draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")
|
53 |
+
|
54 |
+
center = (sum(p[0] for p in state) / len(state), sum(p[1] for p in state) / len(state))
|
55 |
+
# ๊ฐ๋์ ๋ฐ๋ผ ์ ๋ค์ ์ ๋ ฌ
|
56 |
+
sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
|
57 |
+
draw.polygon(sorted_points, outline="red", fill=None, width=round(radius/2))
|
58 |
+
|
59 |
+
return img, state
|
60 |
+
|
61 |
+
def auto_point_detect(image):
|
62 |
+
out_image, msk_np = docscanner_rec(image, docscanner)
|
63 |
+
state = list(get_corner(mask2point(mask=msk_np)))
|
64 |
+
|
65 |
+
img = Image.fromarray(image)
|
66 |
+
area = image.shape[0]*image.shape[1]
|
67 |
+
radius=max(5, round(area**0.5 / 120))
|
68 |
+
# ์ขํ๊ฐ ์ต์ 3๊ฐ ์ด์์ผ ๋๋ง ํด๋ฆฌ๊ณค ๊ทธ๋ฆฌ๊ธฐ
|
69 |
+
draw = ImageDraw.Draw(img)
|
70 |
+
for pt in state:
|
71 |
+
left_up_point = (pt[0] - radius, pt[1] - radius)
|
72 |
+
right_down_point = (pt[0] + radius, pt[1] + radius)
|
73 |
+
draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")
|
74 |
+
|
75 |
+
center = (sum(p[0] for p in state) / len(state), sum(p[1] for p in state) / len(state))
|
76 |
+
# ๊ฐ๋์ ๋ฐ๋ผ ์ ๋ค์ ์ ๋ ฌ
|
77 |
+
sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
|
78 |
+
draw.polygon(sorted_points, outline="red", fill=None, width=round(radius/2))
|
79 |
+
|
80 |
+
return img, state
|
81 |
+
|
82 |
+
|
83 |
+
def calculate_angle(point, center):
|
84 |
+
return math.atan2(point[1] - center[1], point[0] - center[0])
|
85 |
+
|
86 |
+
# ์ขํ๋ฅผ ๋ฐ์์ ํด๋ฆฌ๊ณค์ ๊ทธ๋ฆฌ๋ ํจ์
|
87 |
+
def draw_polygon_on_image(image, evt: gr.SelectData, state):
|
88 |
+
img = Image.fromarray(image)
|
89 |
+
|
90 |
+
pt = (evt.index[0], evt.index[1])
|
91 |
+
state.append(pt)
|
92 |
+
# ํด๋ฆญํ ์ขํ๋ฅผ ์ ์ฅ
|
93 |
+
area = image.shape[0]*image.shape[1]
|
94 |
+
radius=max(5, round(area**0.5 / 120))
|
95 |
+
|
96 |
+
draw = ImageDraw.Draw(img)
|
97 |
+
for pt in state:
|
98 |
+
left_up_point = (pt[0] - radius, pt[1] - radius)
|
99 |
+
right_down_point = (pt[0] + radius, pt[1] + radius)
|
100 |
+
draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")
|
101 |
+
|
102 |
+
if len(state) == 2:
|
103 |
+
draw.line([state[0], state[1]], fill="red", width=round(radius/2))
|
104 |
+
if len(state) >= 3: # ์ขํ๊ฐ ์ต์ 3๊ฐ ์ด์์ผ ๋๋ง ํด๋ฆฌ๊ณค ๊ทธ๋ฆฌ๊ธฐ
|
105 |
+
center = (sum(p[0] for p in state) / len(state), sum(p[1] for p in state) / len(state))
|
106 |
+
# ๊ฐ๋์ ๋ฐ๋ผ ์ ๋ค์ ์ ๋ ฌ
|
107 |
+
sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
|
108 |
+
draw.polygon(sorted_points, outline="red", fill=None, width=round(radius/2))
|
109 |
+
|
110 |
+
return img, state
|
111 |
+
|
112 |
+
def sort_corners(corners):
|
113 |
+
# ๊ฐ ์ขํ๋ฅผ (x, y) ํํ๋ก ๋ฐ๋๋ค๊ณ ๊ฐ์ ํฉ๋๋ค.
|
114 |
+
# corners = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
|
115 |
+
|
116 |
+
if len(corners) != 4:
|
117 |
+
raise ValueError("Input should contain exactly four coordinates.")
|
118 |
+
|
119 |
+
# ์ขํ๋ฅผ y ๊ธฐ์ค์ผ๋ก ์ ๋ ฌํ์ฌ ๊ฐ ์ขํ๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
|
120 |
+
sorted_by_y = sorted(corners, key=lambda p: p[1]) # y ๊ธฐ์ค์ผ๋ก ์ ๋ ฌ
|
121 |
+
lt, rt = sorted(sorted_by_y[:2], key=lambda p: p[0])
|
122 |
+
lb, rb = sorted(sorted_by_y[2:], key=lambda p: p[0])
|
123 |
+
|
124 |
+
return lt, rt, rb, lb
|
125 |
+
|
126 |
+
def convert(image, state):
|
127 |
+
h,w = image.shape[:2]
|
128 |
+
if len(state) < 4:
|
129 |
+
out_image, msk_np = docscanner_rec(image, docscanner)
|
130 |
+
out_image = out_image[:,:,::-1]
|
131 |
+
elif len(state) ==4:
|
132 |
+
state = list(sort_corners(state))
|
133 |
+
src = np.array(state).astype(np.float32)
|
134 |
+
dst = np.float32([
|
135 |
+
(0, 0),
|
136 |
+
(w - 1, 0),
|
137 |
+
(w - 1, h - 1),
|
138 |
+
(0, h - 1)
|
139 |
+
])
|
140 |
+
out_image, M = unwarp(image, src, dst)
|
141 |
+
return out_image
|
142 |
+
css = """
|
143 |
+
.image-container {
|
144 |
+
padding: 20px;
|
145 |
+
background-color: #f0f0f0;
|
146 |
+
}
|
147 |
+
"""
|
148 |
+
# Gradio Blocks ์ปจํ
์คํธ์์ ์ธํฐํ์ด์ค ๊ตฌ์ฑ
|
149 |
+
with gr.Blocks(css=css) as demo:
|
150 |
+
state = gr.State([])
|
151 |
+
with gr.Row():
|
152 |
+
with gr.Column():
|
153 |
+
text = gr.Textbox("์
๋ ฅ ์ด๋ฏธ์ง(์ฝ๋๋ฅผ ํด๋ฆญํ์ธ์)", show_label=False)
|
154 |
+
image_input = gr.Image(show_label=False, interactive=True, elem_classes="image-container")
|
155 |
+
clear_button = gr.Button("Clear Points")
|
156 |
+
cutting_button = gr.Button("Cutting Image(need more than 2 points)")
|
157 |
+
rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
|
158 |
+
auto_button = gr.Button("Auto Points detection")
|
159 |
+
convert_button = gr.Button("Convert Image")
|
160 |
+
with gr.Column():
|
161 |
+
text = gr.Textbox("๋ณํ๋ ์์ญ", show_label=False)
|
162 |
+
image_output = gr.Image(show_label=False)
|
163 |
+
# state_display = gr.Textbox(label="Current State")
|
164 |
+
# coordinates_text = gr.Textbox(label="Coordinates", placeholder="Enter coordinates (x, y) for each point")
|
165 |
+
# update_coords_button = gr.Button("Update Coordinates")
|
166 |
+
with gr.Column():
|
167 |
+
text = gr.Textbox("๊ฒฐ๊ณผ ์ด๋ฏธ์ง", show_label=False)
|
168 |
+
result_image = gr.Image(show_label=False, format="png")
|
169 |
+
|
170 |
+
# # ์ด๋ฏธ์ง ์์์ ํด๋ฆญ ์ด๋ฒคํธ ์ฒ๋ฆฌ
|
171 |
+
image_input.select(draw_polygon_on_image, inputs=[image_input,state], outputs=[image_output,state])
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
# ์ขํ ์ด๊ธฐํ ๋ฒํผ ํด๋ฆญ ์ ์ขํ ๋ฆฌ์
|
176 |
+
clear_button.click(fn=reset_points, inputs=[image_input,state], outputs=[image_output,state])
|
177 |
+
# ์ด๋ฏธ์ง ์๋ฅด๊ธฐ ํธ์ง
|
178 |
+
cutting_button.click(fn=cutting_image, inputs=[image_input,state], outputs=[image_input, image_output, state])
|
179 |
+
# ์ด๋ฏธ์ง ํ์
|
180 |
+
rotating_button.click(fn=rotate_image, inputs=[image_input], outputs=[image_input, state])
|
181 |
+
# ์๋ ๊ฒ์ถ ๋ฒํผ
|
182 |
+
auto_button.click(fn=auto_point_detect, inputs=image_input, outputs=[image_output,state])
|
183 |
+
# ๋ณํ ๋ฒํผ
|
184 |
+
convert_button.click(fn=convert, inputs=[image_input,state], outputs=result_image)
|
185 |
+
|
186 |
+
demo.launch(share=True)
|