import gradio as gr
import matplotlib.cm as cm
import numpy as np
import plotly.graph_objects as go
import spaces
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForKeypointMatching
from transformers.image_utils import to_numpy_array
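
# ZeroGPU Spaces allocate a GPU per call through the `spaces.GPU` decorator;
# without it, `device_map="auto"` below would fall back to CPU
@spaces.GPU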
def process_images(image1, image2, model_name):
    """
    Process two images and return a plot of the matching keypoints.
    """
    if image1 is None or image2 is None:
        return None

    images = [image1, image2]

    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForKeypointMatching.from_pretrained(model_name, device_map="auto")

    inputs = processor(images, return_tensors="pt")
    inputs = inputs.to(model.device)
    print(
        f"Model {model_name} is on device: {model.device} and inputs are on device: {inputs['pixel_values'].device}"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    image_sizes = [[(image.height, image.width) for image in images]]
    outputs = processor.post_process_keypoint_matching(
        outputs, image_sizes, threshold=0.2
    )
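    # Each element of `outputs` is a dict with "keypoints0", "keypoints1" and
    # "matching_scores" for the matches that cleared the threshold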
    output = outputs[0]

    image1 = to_numpy_array(image1)
    image2 = to_numpy_array(image2)
    height0, width0 = image1.shape[:2]
    height1, width1 = image2.shape[:2]

    # Create PIL images from the (already uint8-range) numpy arrays
    pil_img = Image.fromarray(image1.astype(np.uint8))
    pil_img2 = Image.fromarray(image2.astype(np.uint8))

    fig = go.Figure()

    # Create colormap (red-yellow-green: red for low scores, green for high scores)
    colormap = cm.RdYlGn

    # Get keypoints
    keypoints0_x, keypoints0_y = output["keypoints0"].unbind(1)
    keypoints1_x, keypoints1_y = output["keypoints1"].unbind(1)
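    # The post-processed keypoints are (num_matches, 2) tensors in the pixel
    # coordinates of the original images; unbind(1) splits them into x and y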

    # Add a separate trace for each match (line + markers) to enable highlighting
    for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
        keypoints0_x,
        keypoints0_y,
        keypoints1_x,
        keypoints1_y,
        output["matching_scores"],
    ):
        color_val = matching_score.item()
        rgba_color = colormap(color_val)
        # Convert to rgba string with transparency
        color = f"rgba({int(rgba_color[0] * 255)}, {int(rgba_color[1] * 255)}, {int(rgba_color[2] * 255)}, 0.8)"
        hover_text = (
            f"Score: {matching_score.item():.3f}<br>"
            f"Point 1: ({keypoint0_x.item():.1f}, {keypoint0_y.item():.1f})<br>"
            f"Point 2: ({keypoint1_x.item():.1f}, {keypoint1_y.item():.1f})"
        )
        fig.add_trace(
            go.Scatter(
                x=[keypoint0_x.item(), keypoint1_x.item() + width0],
                y=[keypoint0_y.item(), keypoint1_y.item()],
                mode="lines+markers",
                line=dict(color=color, width=2),
                marker=dict(color=color, size=5, opacity=0.8),
                hoverinfo="text",
                hovertext=hover_text,
                showlegend=False,
            )
        )
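
    # The second image is drawn at an x offset of width0 (see the background
    # images in the layout below), so its keypoints are shifted right by
    # width0 in the traces above; the y axis runs from max height down to 0
    # so pixel coordinates (origin top left) display upright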
    # Update layout to use images as background
    fig.update_layout(
        xaxis=dict(
            range=[0, width0 + width1],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
        ),
        yaxis=dict(
            range=[max(height0, height1), 0],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor="x",
            scaleratio=1,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
        images=[
            dict(
                source=pil_img,
                xref="x",
                yref="y",
                x=0,
                y=0,
                sizex=width0,
                sizey=height0,
                sizing="stretch",
                opacity=1,
                layer="below",
            ),
            dict(
                source=pil_img2,
                xref="x",
                yref="y",
                x=width0,
                y=0,
                sizex=width1,
                sizey=height1,
                sizing="stretch",
                opacity=1,
                layer="below",
            ),
        ],
    )

    return fig
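
# Example standalone call (hypothetical file names), e.g. for quick testing:
#   fig = process_images(
#       Image.open("img_a.jpg"), Image.open("img_b.jpg"),
#       "zju-community/efficientloftr",
#   )
#   fig.show()
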
# Create the Gradio interface
with gr.Blocks(title="EfficientLoFTR Matching Demo") as demo:
    gr.Markdown("# EfficientLoFTR Matching Demo")
    gr.Markdown(
        "Upload two images and get a side-by-side matching of your images using EfficientLoFTR."
    )
    gr.Markdown("""
    ## How to use:
    1. Select an EfficientLoFTR model (original EfficientLoFTR or MatchAnything)
    2. Upload two images using the file uploaders below
    3. Click the 'Match Images' button
    4. View the matched output below; higher scores are green, lower scores are red

    You can also select an example image pair from the dataset below.
    """)
    with gr.Row():
        # Detector choice selector: (label, value) tuples show the label in
        # the UI while the underlying model id is passed to process_images
        detector_choice = gr.Radio(
            choices=[
                ("Original EfficientLoFTR", "zju-community/efficientloftr"),
                ("MatchAnything", "zju-community/matchanything_eloftr"),
            ],
            value="zju-community/efficientloftr",
            label="EfficientLoFTR Model",
            info="Choose between the original EfficientLoFTR and MatchAnything",
        )
    with gr.Row():
        # Input images on the same row
        image1 = gr.Image(label="First Image", type="pil")
        image2 = gr.Image(label="Second Image", type="pil")

    # Process button
    process_btn = gr.Button("Match Images", variant="primary")

    # Output plot
    output_plot = gr.Plot(label="Matching Results", scale=2)

    # Connect the function
    process_btn.click(
        fn=process_images,
        inputs=[image1, image2, detector_choice],
        outputs=[output_plot],
    )
    # Add some example usage
    examples = gr.Dataset(
        components=[image1, image2],
        label="Example Image Pairs",
        samples_per_page=100,
        samples=[
            [
                "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg",
                "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg",
            ],
            [
                "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/DSC_0410.JPG",
                "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/DSC_0411.JPG",
            ],
            [
                "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/sacre_coeur1.jpg",
                "https://raw.githubusercontent.com/cvg/LightGlue/refs/heads/main/assets/sacre_coeur2.jpg",
            ],
            [
                "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg",
                "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/piazza_san_marco_58751010_4849458397.jpg",
            ],
            [
                "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg",
                "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg",
            ],
            # MatchAnything multi-modality pairs
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_1.jpg",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_2.jpg",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_pair2_1.jpg",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/MTV_thermal_vis_pair2_2.jpg",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/ct_mr_1.png",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/ct_mr_2.png",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/mri_ut_1.jpg",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/mri_ut_2.jpg",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/robot_render_1.png",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/robot_real_world_2.png",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/thermal_vis_1.jpg",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/thermal_vis_2.jpg",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_event_1.png",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_event_2.png",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_1.jpg",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_2.jpg",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_pair2_1.jpg",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_map_pair2_2.jpg",
            ],
            [
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_thermal_ground_1.png",
                "https://huggingface.co/spaces/LittleFrog/MatchAnything/resolve/main/imcui/datasets/multi_modality_pairs/vis_thermal_ground_2.png",
            ],
        ],
    )
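    # Selecting a sample passes the chosen row to the callback; unpack it
    # into the two image inputs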
    examples.select(lambda x: (x[0], x[1]), [examples], [image1, image2])


if __name__ == "__main__":
    demo.launch()