# Author: KennethTM
# Last commit: cd73030 "Update app.py" (verified)
import gradio as gr
import numpy as np
import onnxruntime as ort
import json
from cryptography.fernet import Fernet
import os
from dotenv import load_dotenv
load_dotenv()

# Model load: the ONNX model is stored Fernet-encrypted on disk and the key is
# read from the ONNX_KEY environment variable (load_dotenv() above also picks
# it up from a local .env file, if present).
key = os.getenv("ONNX_KEY")
if not key:
    # Fail with an actionable message instead of Fernet(None)'s opaque TypeError.
    raise RuntimeError(
        "ONNX_KEY is not set; it is required to decrypt species_bag.onnx.encrypted"
    )
cipher = Fernet(key)
with open("species_bag.onnx.encrypted", "rb") as f:
    encrypted = f.read()
decrypted = cipher.decrypt(encrypted)

# Initialize ONNX session directly from the decrypted model bytes (never
# written back to disk) and cache the first input/output tensor names.
ort_session = ort.InferenceSession(decrypted)
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# Load mappings between species names, model indices and GBIF species keys.
# The encoding is explicit so that non-ASCII characters in species names decode
# identically on every platform (the default is locale-dependent).
with open("idx2spec.json", "r", encoding="utf-8") as f:
    idx2spec = json.load(f)  # str(model output index) -> species name
with open("spec2idx.json", "r", encoding="utf-8") as f:
    spec2idx = json.load(f)  # species name -> model input index (int-convertible)
with open("spec2key.json", "r", encoding="utf-8") as f:
    spec2key = json.load(f)  # species name -> GBIF species key appended to baseurl
baseurl = "https://www.gbif.org/species/"
def predict_species(selected_species, n_hits=10):
    """Suggest species that fit alongside the selected plant community.

    Args:
        selected_species: Species names chosen in the dropdown; each must be a
            key of ``spec2idx``.
        n_hits: Number of suggestions to return. Gradio's ``gr.Number`` may
            deliver a float, or ``None`` when the field is cleared, so the value
            is coerced to an int (``None`` -> 10) and clamped to at least 1.

    Returns:
        Tuple ``(selected_html, predictions_html)`` of ``<br>``-joined GBIF
        links: one for each selected species, and one for each of the top
        ``n_hits`` predictions followed by its score as a percentage. Both are
        ``""`` when nothing is selected.
    """
    if not selected_species:
        return "", ""

    # A float slice bound raises TypeError, and n_hits <= 0 would turn
    # ``[-n_hits:]`` into a slice over (almost) the whole array.
    n_hits = 10 if n_hits is None else max(1, int(n_hits))

    # Convert species names to model indices, as a single batch row.
    input_indices = [int(spec2idx[name]) for name in selected_species]
    input_np = np.array(input_indices, dtype=np.int64).reshape(1, -1)
    # First output tensor, first (only) batch row: scores indexed like idx2spec.
    scores = ort_session.run([output_name], {input_name: input_np})[0][0]

    # Indices of the n_hits highest scores, best first.
    top_indices = scores.argsort()[-n_hits:][::-1]
    top_scores = scores[top_indices]

    def _link(species):
        # GBIF species page link, opened in a new tab.
        return "<a href='{}{}' target='_blank'>{}</a>".format(
            baseurl, spec2key[species], species
        )

    selected_html = [_link(species) for species in selected_species]
    predictions_html = [
        "{} ({:.1f}%)".format(_link(idx2spec[str(idx)]), 100 * score)
        for idx, score in zip(top_indices, top_scores)
    ]
    return "<br>".join(selected_html), "<br>".join(predictions_html)
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Danmarks planter - hvem mangler?")
    gr.Markdown("*Sammensæt et plantesamfund og få forslag til andre arter der passer ind. Vælg mellem 3199 danske terrestriske og akvatiske planter.*")
    with gr.Row():
        species_dropdown = gr.Dropdown(
            choices=sorted(spec2idx.keys()),
            multiselect=True,
            label="Find arter",
        )
    with gr.Row():
        with gr.Column(scale=5, min_width=200):
            selected_output = gr.HTML(
                label="Arter",
                show_label=True,
            )
        with gr.Column(scale=5, min_width=200):
            predictions_output = gr.HTML(
                label="Top hits",
                show_label=True,
            )
        with gr.Column(scale=1, min_width=100):
            # precision=0 makes Gradio pass an int, which predict_species uses
            # as a slice bound (a float there raises TypeError).
            n_hits = gr.Number(
                10, label="Antal hits", minimum=1, maximum=100, precision=0
            )
            add_button = gr.Button("Tilføj top hit", scale=8)
    gr.Markdown("Forslag er baseret på et neuralt netværk trænet til at forudsige de mest sandsynlige arter som mangler i et plantesamfund. Trænet på stort datasæt af plantesamfund registreret i Danmark (**4.3 millioner registreringer af 3199 arter/slægter/varianter i mere end 180.000 undersøgelser**).")
    gr.Markdown("App og model af Kenneth Thorø Martinsen.")

    def add_top_prediction(selected_species):
        """Return the selection extended with the model's highest-scoring species.

        The selection is returned unchanged when it is empty or already
        contains that species.
        """
        if not selected_species:
            return selected_species
        # Score directly rather than parsing predict_species' HTML, which
        # truncated any species name containing " (".
        input_np = np.array(
            [int(spec2idx[name]) for name in selected_species], dtype=np.int64
        ).reshape(1, -1)
        scores = ort_session.run([output_name], {input_name: input_np})[0][0]
        top_species = idx2spec[str(int(np.argmax(scores)))]
        if top_species in selected_species:
            return selected_species
        # Build a new list instead of mutating the value Gradio passed in.
        return [*selected_species, top_species]

    # Refresh the selection and suggestions whenever the chosen species or the
    # requested number of hits changes.
    species_dropdown.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output],
    )
    n_hits.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output],
    )
    # Writing the new selection back into the dropdown also fires its
    # .change listener above, so the suggestions refresh automatically.
    add_button.click(
        add_top_prediction,
        inputs=[species_dropdown],
        outputs=[species_dropdown],
    )
if __name__ == "__main__":
    # Only start the web server when executed as a script, not on import.
    demo.launch()