HERIUN commited on
Commit
47a9de9
ยท
1 Parent(s): 6a0c0c0
Files changed (1) hide show
  1. demo.py +186 -0
demo.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import math
4
+ import torch
5
+
6
+ from PIL import Image, ImageDraw
7
+ from rect_main import docscanner_rec, load_docscanner_model
8
+ from data_utils.image_utils import unwarp, mask2point, get_corner, _rotate_90_degrees
9
+ from config import Config
10
+
11
+
12
+ config = Config()
13
+ cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+
16
+ docscanner = load_docscanner_model(
17
+ cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
18
+ )
19
+
20
+ # ์ขŒํ‘œ๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ํ•จ์ˆ˜
21
+ def reset_points(image, state):
22
+ state = []
23
+ return image, state
24
+
25
+ def cutting_image(image, state):
26
+ min_x = min(point[0] for point in state)
27
+ max_x = max(point[0] for point in state)
28
+ min_y = min(point[1] for point in state)
29
+ max_y = max(point[1] for point in state)
30
+
31
+ cutted_image = image[min_y:max_y, min_x:max_x]
32
+ state = []
33
+ return cutted_image, cutted_image, state
34
+
35
+ def rotate_image(image):
36
+ rotated_image = _rotate_90_degrees(image)
37
+ state = []
38
+ return rotated_image, state
39
+
40
+ def reset_image(image, state):
41
+ out_image, msk_np = docscanner_rec(image, docscanner)
42
+ state = list(get_corner(mask2point(mask=msk_np)))
43
+
44
+ img = Image.fromarray(image)
45
+ area = image.shape[0]*image.shape[1]
46
+ radius=max(5, round(area**0.5 / 120))
47
+ # ์ขŒํ‘œ๊ฐ€ ์ตœ์†Œ 3๊ฐœ ์ด์ƒ์ผ ๋•Œ๋งŒ ํด๋ฆฌ๊ณค ๊ทธ๋ฆฌ๊ธฐ
48
+ draw = ImageDraw.Draw(img)
49
+ for pt in state:
50
+ left_up_point = (pt[0] - radius, pt[1] - radius)
51
+ right_down_point = (pt[0] + radius, pt[1] + radius)
52
+ draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")
53
+
54
+ center = (sum(p[0] for p in state) / len(state), sum(p[1] for p in state) / len(state))
55
+ # ๊ฐ๋„์— ๋”ฐ๋ผ ์ ๋“ค์„ ์ •๋ ฌ
56
+ sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
57
+ draw.polygon(sorted_points, outline="red", fill=None, width=round(radius/2))
58
+
59
+ return img, state
60
+
61
+ def auto_point_detect(image):
62
+ out_image, msk_np = docscanner_rec(image, docscanner)
63
+ state = list(get_corner(mask2point(mask=msk_np)))
64
+
65
+ img = Image.fromarray(image)
66
+ area = image.shape[0]*image.shape[1]
67
+ radius=max(5, round(area**0.5 / 120))
68
+ # ์ขŒํ‘œ๊ฐ€ ์ตœ์†Œ 3๊ฐœ ์ด์ƒ์ผ ๋•Œ๋งŒ ํด๋ฆฌ๊ณค ๊ทธ๋ฆฌ๊ธฐ
69
+ draw = ImageDraw.Draw(img)
70
+ for pt in state:
71
+ left_up_point = (pt[0] - radius, pt[1] - radius)
72
+ right_down_point = (pt[0] + radius, pt[1] + radius)
73
+ draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")
74
+
75
+ center = (sum(p[0] for p in state) / len(state), sum(p[1] for p in state) / len(state))
76
+ # ๊ฐ๋„์— ๋”ฐ๋ผ ์ ๋“ค์„ ์ •๋ ฌ
77
+ sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
78
+ draw.polygon(sorted_points, outline="red", fill=None, width=round(radius/2))
79
+
80
+ return img, state
81
+
82
+
83
+ def calculate_angle(point, center):
84
+ return math.atan2(point[1] - center[1], point[0] - center[0])
85
+
86
+ # ์ขŒํ‘œ๋ฅผ ๋ฐ›์•„์„œ ํด๋ฆฌ๊ณค์„ ๊ทธ๋ฆฌ๋Š” ํ•จ์ˆ˜
87
+ def draw_polygon_on_image(image, evt: gr.SelectData, state):
88
+ img = Image.fromarray(image)
89
+
90
+ pt = (evt.index[0], evt.index[1])
91
+ state.append(pt)
92
+ # ํด๋ฆญํ•œ ์ขŒํ‘œ๋ฅผ ์ €์žฅ
93
+ area = image.shape[0]*image.shape[1]
94
+ radius=max(5, round(area**0.5 / 120))
95
+
96
+ draw = ImageDraw.Draw(img)
97
+ for pt in state:
98
+ left_up_point = (pt[0] - radius, pt[1] - radius)
99
+ right_down_point = (pt[0] + radius, pt[1] + radius)
100
+ draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")
101
+
102
+ if len(state) == 2:
103
+ draw.line([state[0], state[1]], fill="red", width=round(radius/2))
104
+ if len(state) >= 3: # ์ขŒํ‘œ๊ฐ€ ์ตœ์†Œ 3๊ฐœ ์ด์ƒ์ผ ๋•Œ๋งŒ ํด๋ฆฌ๊ณค ๊ทธ๋ฆฌ๊ธฐ
105
+ center = (sum(p[0] for p in state) / len(state), sum(p[1] for p in state) / len(state))
106
+ # ๊ฐ๋„์— ๋”ฐ๋ผ ์ ๋“ค์„ ์ •๋ ฌ
107
+ sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
108
+ draw.polygon(sorted_points, outline="red", fill=None, width=round(radius/2))
109
+
110
+ return img, state
111
+
112
+ def sort_corners(corners):
113
+ # ๊ฐ ์ขŒํ‘œ๋ฅผ (x, y) ํ˜•ํƒœ๋กœ ๋ฐ›๋Š”๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค.
114
+ # corners = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
115
+
116
+ if len(corners) != 4:
117
+ raise ValueError("Input should contain exactly four coordinates.")
118
+
119
+ # ์ขŒํ‘œ๋ฅผ y ๊ธฐ์ค€์œผ๋กœ ์ •๋ ฌํ•˜์—ฌ ๊ฐ ์ขŒํ‘œ๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
120
+ sorted_by_y = sorted(corners, key=lambda p: p[1]) # y ๊ธฐ์ค€์œผ๋กœ ์ •๋ ฌ
121
+ lt, rt = sorted(sorted_by_y[:2], key=lambda p: p[0])
122
+ lb, rb = sorted(sorted_by_y[2:], key=lambda p: p[0])
123
+
124
+ return lt, rt, rb, lb
125
+
126
+ def convert(image, state):
127
+ h,w = image.shape[:2]
128
+ if len(state) < 4:
129
+ out_image, msk_np = docscanner_rec(image, docscanner)
130
+ out_image = out_image[:,:,::-1]
131
+ elif len(state) ==4:
132
+ state = list(sort_corners(state))
133
+ src = np.array(state).astype(np.float32)
134
+ dst = np.float32([
135
+ (0, 0),
136
+ (w - 1, 0),
137
+ (w - 1, h - 1),
138
+ (0, h - 1)
139
+ ])
140
+ out_image, M = unwarp(image, src, dst)
141
+ return out_image
142
+ css = """
143
+ .image-container {
144
+ padding: 20px;
145
+ background-color: #f0f0f0;
146
+ }
147
+ """
148
+ # Gradio Blocks ์ปจํ…์ŠคํŠธ์—์„œ ์ธํ„ฐํŽ˜์ด์Šค ๊ตฌ์„ฑ
149
+ with gr.Blocks(css=css) as demo:
150
+ state = gr.State([])
151
+ with gr.Row():
152
+ with gr.Column():
153
+ text = gr.Textbox("์ž…๋ ฅ ์ด๋ฏธ์ง€(์ฝ”๋„ˆ๋ฅผ ํด๋ฆญํ•˜์„ธ์š”)", show_label=False)
154
+ image_input = gr.Image(show_label=False, interactive=True, elem_classes="image-container")
155
+ clear_button = gr.Button("Clear Points")
156
+ cutting_button = gr.Button("Cutting Image(need more than 2 points)")
157
+ rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
158
+ auto_button = gr.Button("Auto Points detection")
159
+ convert_button = gr.Button("Convert Image")
160
+ with gr.Column():
161
+ text = gr.Textbox("๋ณ€ํ™˜๋  ์˜์—ญ", show_label=False)
162
+ image_output = gr.Image(show_label=False)
163
+ # state_display = gr.Textbox(label="Current State")
164
+ # coordinates_text = gr.Textbox(label="Coordinates", placeholder="Enter coordinates (x, y) for each point")
165
+ # update_coords_button = gr.Button("Update Coordinates")
166
+ with gr.Column():
167
+ text = gr.Textbox("๊ฒฐ๊ณผ ์ด๋ฏธ์ง€", show_label=False)
168
+ result_image = gr.Image(show_label=False, format="png")
169
+
170
+ # # ์ด๋ฏธ์ง€ ์œ„์—์„œ ํด๋ฆญ ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
171
+ image_input.select(draw_polygon_on_image, inputs=[image_input,state], outputs=[image_output,state])
172
+
173
+
174
+
175
+ # ์ขŒํ‘œ ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ์ขŒํ‘œ ๋ฆฌ์…‹
176
+ clear_button.click(fn=reset_points, inputs=[image_input,state], outputs=[image_output,state])
177
+ # ์ด๋ฏธ์ง€ ์ž๋ฅด๊ธฐ ํŽธ์ง‘
178
+ cutting_button.click(fn=cutting_image, inputs=[image_input,state], outputs=[image_input, image_output, state])
179
+ # ์ด๋ฏธ์ง€ ํšŒ์ „
180
+ rotating_button.click(fn=rotate_image, inputs=[image_input], outputs=[image_input, state])
181
+ # ์ž๋™ ๊ฒ€์ถœ ๋ฒ„ํŠผ
182
+ auto_button.click(fn=auto_point_detect, inputs=image_input, outputs=[image_output,state])
183
+ # ๋ณ€ํ™˜ ๋ฒ„ํŠผ
184
+ convert_button.click(fn=convert, inputs=[image_input,state], outputs=result_image)
185
+
186
+ demo.launch(share=True)