Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import onnxruntime as ort | |
import json | |
from cryptography.fernet import Fernet | |
import os | |
from dotenv import load_dotenv | |
load_dotenv()

# Model load: the ONNX model is stored Fernet-encrypted on disk and the key is
# read from the ONNX_KEY environment variable (load_dotenv() above may populate
# it from a .env file).
key = os.getenv("ONNX_KEY")
if not key:
    # Fail fast with an actionable message; Fernet(None) would otherwise raise
    # an opaque TypeError that does not mention the missing variable.
    raise RuntimeError(
        "ONNX_KEY environment variable is not set; cannot decrypt "
        "species_bag.onnx.encrypted"
    )
cipher = Fernet(key)
with open("species_bag.onnx.encrypted", "rb") as f:
    encrypted = f.read()
decrypted = cipher.decrypt(encrypted)

# Initialize ONNX session from the in-memory model bytes (the decrypted model
# is not written to disk here).
ort_session = ort.InferenceSession(decrypted)
# The app feeds exactly one input tensor and reads exactly one output tensor.
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# Load mappings between species names, model indices and GBIF keys.
def _load_json(path):
    """Return the parsed contents of the UTF-8 encoded JSON file at *path*."""
    # Explicit encoding: open() otherwise uses the platform's locale encoding,
    # which may mis-decode non-ASCII characters in species names.
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


# Model output index (looked up as a string key) -> species name.
idx2spec = _load_json("idx2spec.json")
# Species name -> model input index (converted with int() before inference).
spec2idx = _load_json("spec2idx.json")
# Species name -> GBIF species key, appended to baseurl to build links.
spec2key = _load_json("spec2key.json")
baseurl = "https://www.gbif.org/species/"
def predict_species(selected_species, n_hits=10):
    """Suggest species likely to be missing from a selected plant community.

    Args:
        selected_species: Species names chosen in the dropdown; each must be a
            key of ``spec2idx``. An empty or None selection yields empty output.
        n_hits: Number of top-scoring predictions to list. The Gradio Number
            component may deliver a float (or None when the field is cleared),
            so the value is coerced to int, defaulted to 10 when None, and
            clamped to at least 1.

    Returns:
        A ``(selected_html, predictions_html)`` tuple of ``<br>``-joined HTML
        links to the species' GBIF pages; each prediction also shows the model
        score multiplied by 100 and formatted as a percentage.
    """
    if not selected_species:
        return "", ""

    # Slice bounds must be ints: a float such as 10.0 raises TypeError, and
    # 0 would turn `[-0:]` into "every species".
    n_hits = 10 if n_hits is None else max(1, int(n_hits))

    # Encode the community as a (1, k) batch of int64 species indices.
    input_np = np.array(
        [int(spec2idx[name]) for name in selected_species], dtype=np.int64
    ).reshape(1, -1)
    # First output tensor, first (only) batch row; presumably one score per
    # species index, as the idx2spec lookup below assumes.
    output = ort_session.run([output_name], {input_name: input_np})[0][0]

    # Indices of the n_hits highest scores, best first.
    top_indices = output.argsort()[-n_hits:][::-1]

    def _link(species):
        # HTML anchor pointing at the species' GBIF page, opened in a new tab.
        return f"<a href='{baseurl}{spec2key[species]}' target='_blank'>{species}</a>"

    selected_html = [_link(species) for species in selected_species]
    predictions_html = []
    for idx in top_indices:
        species = idx2spec[str(idx)]
        predictions_html.append(f"{_link(species)} ({100 * output[idx]:.1f}%)")
    return "<br>".join(selected_html), "<br>".join(predictions_html)
# Gradio interface
# NOTE(review): the copy of this file under review had lost its indentation,
# so the layout nesting below is a reconstruction - confirm against the
# deployed Space.
with gr.Blocks() as demo:
    gr.Markdown("## Danmarks planter - hvem mangler?")
    gr.Markdown("*Sammensæt et plantesamfund og få forslag til andre arter der passer ind. Vælg mellem 3199 danske terrestriske og akvatiske planter.*")
    with gr.Row():
        species_dropdown = gr.Dropdown(
            choices=sorted(spec2idx.keys()),
            multiselect=True,
            label="Find arter",
        )
    with gr.Row():
        with gr.Column(scale=5, min_width=200):
            selected_output = gr.HTML(
                label="Arter",
                show_label=True
            )
        with gr.Column(scale=5, min_width=200):
            predictions_output = gr.HTML(
                label="Top hits",
                show_label=True
            )
        with gr.Column(scale=1, min_width=100):
            # precision=0 makes Gradio deliver an int; with the default
            # precision=None the value may arrive as a float, which cannot be
            # used as a slice bound in predict_species.
            n_hits = gr.Number(10, label="Antal hits", minimum=1, maximum=100, precision=0)
            add_button = gr.Button("Tilføj top hit", scale=8)
    gr.Markdown("Forslag er baseret på et neuralt netværk trænet til at forudsige de mest sandsynlige arter som mangler i et plantesamfund. Trænet på stort datasæt af plantesamfund registreret i Danmark (**4.3 millioner registreringer af 3199 arter/slægter/varianter i mere end 180.000 undersøgelser**).")
    gr.Markdown("App og model af Kenneth Thorø Martinsen.")

    def add_top_prediction(selected_species):
        """Return the selection extended with its best-scoring missing species.

        Ranks species directly from the model output instead of parsing the
        HTML built by predict_species. Parsing that HTML truncated names
        containing " (". It also added nothing when the top hit was already
        selected. A new list is returned; the input list is not mutated.
        An empty/None selection is returned unchanged.
        """
        if not selected_species:
            return selected_species
        input_np = np.array(
            [int(spec2idx[name]) for name in selected_species], dtype=np.int64
        ).reshape(1, -1)
        scores = ort_session.run([output_name], {input_name: input_np})[0][0]
        already_chosen = set(selected_species)
        # Walk indices from highest to lowest score; stop at the first species
        # not already in the community.
        for idx in scores.argsort()[::-1]:
            candidate = idx2spec[str(idx)]
            if candidate not in already_chosen:
                return [*selected_species, candidate]
        # Every ranked species is already selected: nothing to add.
        return selected_species

    # Recompute suggestions whenever the selection or the hit count changes.
    species_dropdown.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output]
    )
    n_hits.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output]
    )
    # Writing the new selection back into the dropdown fires its .change
    # listener above, which refreshes the suggestions.
    add_button.click(
        add_top_prediction,
        inputs=[species_dropdown],
        outputs=[species_dropdown]
    )

if __name__ == "__main__":
    demo.launch()