stevenbucaille committed
Commit d00bf12
1 Parent(s): 626932d

add app.py and requirements.txt

Files changed (2)
  1. app.py +263 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,263 @@
+ import time
+
+ import gradio as gr
+ import matplotlib.cm as cm
+ import numpy as np
+ import plotly.graph_objects as go
+ import spaces
+ import torch
+ from PIL import Image
+ from transformers import AutoImageProcessor, AutoModelForKeypointMatching
+ from transformers.image_utils import to_numpy_array
+
+
+ @spaces.GPU
+ def process_images(image1, image2, model_name):
+     """
+     Process two images and return a plot of the matching keypoints.
+     """
+     if image1 is None or image2 is None:
+         return None
+
+     images = [image1, image2]
+
+     processor = AutoImageProcessor.from_pretrained(model_name)
+     model = AutoModelForKeypointMatching.from_pretrained(model_name, device_map="auto")
+     inputs = processor(images, return_tensors="pt")
+     inputs = inputs.to(model.device)
+     print(
+         f"Model {model_name} is on device: {model.device} and inputs are on device: {inputs['pixel_values'].device}"
+     )
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     image_sizes = [[(image.height, image.width) for image in images]]
+     outputs = processor.post_process_keypoint_matching(
+         outputs, image_sizes, threshold=0.2
+     )
+     output = outputs[0]
+
+     image1 = to_numpy_array(image1)
+     image2 = to_numpy_array(image2)
+
+     height0, width0 = image1.shape[:2]
+     height1, width1 = image2.shape[:2]
+
+     # Create PIL images from the numpy arrays (used as plot backgrounds)
+     pil_img = Image.fromarray(image1.astype(np.uint8))
+     pil_img2 = Image.fromarray(image2.astype(np.uint8))
+
+     fig = go.Figure()
+
+     # Create colormap (red-yellow-green: red for low scores, green for high scores)
+     colormap = cm.RdYlGn
+
+     # Get keypoints
+     keypoints0_x, keypoints0_y = output["keypoints0"].unbind(1)
+     keypoints1_x, keypoints1_y = output["keypoints1"].unbind(1)
+
+     # Add a separate trace for each match (line + markers) to enable highlighting
+     for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+         keypoints0_x,
+         keypoints0_y,
+         keypoints1_x,
+         keypoints1_y,
+         output["matching_scores"],
+     ):
+         color_val = matching_score.item()
+         rgba_color = colormap(color_val)
+
+         # Convert to rgba string with transparency
+         color = f"rgba({int(rgba_color[0] * 255)}, {int(rgba_color[1] * 255)}, {int(rgba_color[2] * 255)}, 0.8)"
+
+         hover_text = (
+             f"Score: {matching_score.item():.3f}<br>"
+             f"Point 1: ({keypoint0_x.item():.1f}, {keypoint0_y.item():.1f})<br>"
+             f"Point 2: ({keypoint1_x.item():.1f}, {keypoint1_y.item():.1f})"
+         )
+
+         fig.add_trace(
+             go.Scatter(
+                 x=[keypoint0_x.item(), keypoint1_x.item() + width0],
+                 y=[keypoint0_y.item(), keypoint1_y.item()],
+                 mode="lines+markers",
+                 line=dict(color=color, width=2),
+                 marker=dict(color=color, size=5, opacity=0.8),
+                 hoverinfo="text",
+                 hovertext=hover_text,
+                 showlegend=False,
+             )
+         )
+
+     # Update layout to use the two images as a side-by-side background
+     fig.update_layout(
+         xaxis=dict(
+             range=[0, width0 + width1],
+             showgrid=False,
+             zeroline=False,
+             showticklabels=False,
+         ),
+         yaxis=dict(
+             range=[max(height0, height1), 0],
+             showgrid=False,
+             zeroline=False,
+             showticklabels=False,
+             scaleanchor="x",
+             scaleratio=1,
+         ),
+         margin=dict(l=0, r=0, t=0, b=0),
+         autosize=True,
+         images=[
+             dict(
+                 source=pil_img,
+                 xref="x",
+                 yref="y",
+                 x=0,
+                 y=0,
+                 sizex=width0,
+                 sizey=height0,
+                 sizing="stretch",
+                 opacity=1,
+                 layer="below",
+             ),
+             dict(
+                 source=pil_img2,
+                 xref="x",
+                 yref="y",
+                 x=width0,
+                 y=0,
+                 sizex=width1,
+                 sizey=height1,
+                 sizing="stretch",
+                 opacity=1,
+                 layer="below",
+             ),
+         ],
+     )
+
+     return fig
+
+
+ # Create the Gradio interface
+ with gr.Blocks(title="EfficientLoFTR Matching Demo") as demo:
+     gr.Markdown("# EfficientLoFTR Matching Demo")
+     gr.Markdown(
+         "Upload two images and get a side-by-side matching of your images using EfficientLoFTR."
+     )
+     gr.Markdown("""
+     ## How to use:
+     1. Select an EfficientLoFTR model (Original EfficientLoFTR or MatchAnything)
+     2. Upload two images using the file uploaders below
+     3. Click the 'Match Images' button
+     4. View the matched output image below. Higher scores are green, lower scores are red.
+
+     The app will create a side-by-side matching of your images using EfficientLoFTR.
+     You can also select an example image pair from the dataset below.
+     """)
+
+     with gr.Row():
+         # Detector choice selector
+         detector_choice = gr.Radio(
+             choices=[("Original EfficientLoFTR", "zju-community/efficientloftr"), ("MatchAnything", "zju-community/matchanything_eloftr")],
+             value="zju-community/efficientloftr",
+             label="EfficientLoFTR Model",
+             info="Choose between the original EfficientLoFTR and MatchAnything",
+         )
+
+     with gr.Row():
+         # Input images on the same row
+         image1 = gr.Image(label="First Image", type="pil")
+         image2 = gr.Image(label="Second Image", type="pil")
+
+     # Process button
+     process_btn = gr.Button("Match Images", variant="primary")
+
+     # Output plot
+     output_plot = gr.Plot(label="Matching Results", scale=2)
+
+     # Connect the function
+     process_btn.click(fn=process_images, inputs=[image1, image2, detector_choice], outputs=[output_plot])
+
+     # Add some example usage
+
+     examples = gr.Dataset(
+         components=[image1, image2],
+         label="Example Image Pairs",
+         samples=[
+             [
+                 "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg",
+                 "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg",
+             ],
+             [
+                 "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/DSC_0410.JPG",
+                 "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/DSC_0411.JPG",
+             ],
+             [
+                 "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/sacre_coeur1.jpg",
+                 "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/sacre_coeur2.jpg",
+             ],
+             [
+                 "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg",
+                 "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/piazza_san_marco_58751010_4849458397.jpg",
+             ],
+             [
+                 "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg",
+                 "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg",
+             ],
+             # MatchAnything multi-modality pairs
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_2.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_pair2_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_pair2_2.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/ct_mr_1.png",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/ct_mr_2.png",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/mri_ut_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/mri_ut_2.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/rgb_2.png",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/robot_real_world_2.png",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/robot_render_1.png",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/robot_real_world_2.png",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/thermal_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/thermal_vis_1.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/thermal_vis_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/thermal_vis_2.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_event_1.png",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_event_2.png",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_2.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_pair2_1.jpg",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_pair2_2.jpg",
+             ],
+             [
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_thermal_ground_1.png",
+                 "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_thermal_ground_2.png",
+             ],
+         ],
+     )
+
+     examples.select(lambda x: (x[0], x[1]), [examples], [image1, image2])
+
+ if __name__ == "__main__":
+     demo.launch()
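
For readers who want to exercise the matcher outside the Space, the core inference path that process_images wraps can be run as a short standalone script. This is a minimal sketch, not part of the commit: the checkpoint name, the 0.2 threshold, and the post-processing call mirror the code above, the example image URLs come from the samples list, and fetching the images with requests is an assumption of the sketch (requests is not listed in requirements.txt).

    # Minimal standalone sketch of the matching pipeline used by app.py.
    import requests
    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForKeypointMatching

    url0 = "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/sacre_coeur1.jpg"
    url1 = "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/sacre_coeur2.jpg"
    images = [Image.open(requests.get(u, stream=True).raw).convert("RGB") for u in (url0, url1)]

    processor = AutoImageProcessor.from_pretrained("zju-community/efficientloftr")
    model = AutoModelForKeypointMatching.from_pretrained("zju-community/efficientloftr")

    inputs = processor(images, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Map keypoints back to the original image sizes and keep matches above the
    # same 0.2 confidence threshold used in the app.
    image_sizes = [[(img.height, img.width) for img in images]]
    matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)[0]
    print(f"{len(matches['keypoints0'])} matches above threshold")
    print(f"mean confidence: {matches['matching_scores'].mean().item():.3f}")
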
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio>=5.34.2
+ Pillow>=10.0.0
+ numpy>=1.24.0
+ transformers @ git+https://github.com/huggingface/transformers.git@52aaa3f5004d18ecb148c82534eb9eec8ac20f8f
+ matplotlib
+ torch
+ plotly
+ spaces
+ accelerate
+ kornia
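
Note that transformers is pinned to a specific git revision, presumably because the keypoint-matching API that app.py relies on (AutoModelForKeypointMatching) is newer than the stable release at the time of the commit, and accelerate is required for the device_map="auto" loading used above. A quick sanity check for a local environment, offered as a sketch rather than part of the commit (the checkpoint name comes from the model choices in app.py):

    # Sanity check (a sketch, not part of the commit): the pinned transformers
    # revision must expose the keypoint-matching classes app.py imports, and
    # accelerate enables the device_map="auto" placement used there.
    from transformers import AutoImageProcessor, AutoModelForKeypointMatching

    processor = AutoImageProcessor.from_pretrained("zju-community/efficientloftr")
    model = AutoModelForKeypointMatching.from_pretrained(
        "zju-community/efficientloftr", device_map="auto"
    )
    print(type(model).__name__, "loaded on", model.device)
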