File size: 4,033 Bytes
2b81e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bcb316
2b81e88
 
8bcb316
2b81e88
 
8bcb316
2b81e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd73030
2b81e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd73030
69571a6
2b81e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import numpy as np
import onnxruntime as ort
import json
from cryptography.fernet import Fernet
import os
from dotenv import load_dotenv

load_dotenv()

# Model load
# The ONNX model is stored Fernet-encrypted on disk. The decryption key is read
# from the ONNX_KEY environment variable (load_dotenv() above allows it to be
# supplied through a local .env file).
key = os.getenv("ONNX_KEY")
if not key:
    # Fail early with an actionable message instead of the opaque error that
    # Fernet raises when it is handed None or an empty string.
    raise RuntimeError(
        "Environment variable ONNX_KEY is not set; it is required to decrypt "
        "'species_bag.onnx.encrypted'."
    )
cipher = Fernet(key)

with open("species_bag.onnx.encrypted", "rb") as f:
    encrypted = f.read()

# Decrypt in memory and hand the raw model bytes to ONNX Runtime; the
# plaintext model is never written to disk.
decrypted = cipher.decrypt(encrypted)
ort_session = ort.InferenceSession(decrypted)

# Names of the model's first input and first output, used when running inference
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Load mappings
# JSON text is UTF-8 by specification, so the encoding is stated explicitly
# rather than relying on the platform default.
# idx2spec: class index (as a string key) -> species name
with open("idx2spec.json", "r", encoding="utf-8") as f:
    idx2spec = json.load(f)

# spec2idx: species name -> class index
with open("spec2idx.json", "r", encoding="utf-8") as f:
    spec2idx = json.load(f)

# spec2key: species name -> identifier appended to `baseurl` to build a link
with open("spec2key.json", "r", encoding="utf-8") as f:
    spec2key = json.load(f)

baseurl = "https://www.gbif.org/species/"

def predict_species(selected_species, n_hits=10):
    """Suggest species that are likely missing from the selected community.

    Args:
        selected_species: Species names chosen by the user; every name must be
            a key of ``spec2idx`` and ``spec2key``.
        n_hits: Number of top-scoring predictions to return. The value may
            arrive from the UI as ``None`` (cleared field) or as a float, so it
            is coerced to ``int`` (falling back to 10 when that is impossible)
            and clamped to the range ``1 .. number of model outputs``.

    Returns:
        A tuple ``(selected_html, predictions_html)`` of HTML strings. Each
        holds one link per species, joined with ``<br>``; prediction entries
        additionally carry the score as a percentage. Both strings are empty
        when nothing is selected.

    Raises:
        KeyError: If a selected name is missing from the lookup tables.
    """
    if not selected_species:
        return "", ""

    # A float or None cannot be used as a slice bound, so normalise first.
    try:
        n_hits = int(n_hits)
    except (TypeError, ValueError, OverflowError):
        n_hits = 10

    # Convert species names to indices using spec2idx
    input_indices = [int(spec2idx[name]) for name in selected_species]

    # Model inference on a single batch row holding all selected indices
    input_np = np.array(input_indices, dtype=np.int64).reshape(1, -1)
    output = ort_session.run([output_name], {input_name: input_np})[0][0]

    # Clamp to at least one hit: with 0 the slice [-0:] would return *every*
    # class, and a negative value would drop the lowest scores instead.
    n_hits = max(1, min(n_hits, len(output)))

    # Get top predictions, highest score first
    top_indices = output.argsort()[-n_hits:][::-1]
    top_scores = output[top_indices]

    # Format selected species with links
    selected_html = [
        f"<a href='{baseurl}{spec2key[species]}' target='_blank'>{species}</a>"
        for species in selected_species
    ]

    # Format predictions with species names, links and scores
    predictions_html = []
    for idx, score in zip(top_indices, top_scores):
        name = idx2spec[str(idx)]  # idx2spec is keyed by the index as a string
        predictions_html.append(
            f"<a href='{baseurl}{spec2key[name]}' target='_blank'>{name}</a>"
            f" ({100 * score:.1f}%)"
        )

    return "<br>".join(selected_html), "<br>".join(predictions_html)

# Gradio interface
with gr.Blocks() as demo:
    # Page title and short introduction
    gr.Markdown("## Danmarks planter - hvem mangler?")
    gr.Markdown("*Sammensæt et plantesamfund og få forslag til andre arter der passer ind. Vælg mellem 3199 danske terrestriske og akvatiske planter.*")

    # Multi-select search box listing every known species name alphabetically
    with gr.Row():
        species_dropdown = gr.Dropdown(
            choices=sorted(spec2idx.keys()),
            multiselect=True,
            label="Find arter",
        )

    # Three columns: current selection, model suggestions, and controls
    with gr.Row():
        with gr.Column(scale=5, min_width=200):
            selected_output = gr.HTML(
                label="Arter",
                show_label=True
            )
        with gr.Column(scale=5, min_width=200):
            predictions_output = gr.HTML(
                label="Top hits",
                show_label=True
            )
        with gr.Column(scale=1, min_width=100):
            # precision=0 makes the component deliver an int; a fractional
            # value would otherwise reach the slicing code as a float.
            n_hits = gr.Number(
                10,
                label="Antal hits",
                minimum=1,
                maximum=100,
                precision=0,
            )
            add_button = gr.Button("Tilføj top hit", scale=8)

    # Background on the model and author credit
    gr.Markdown("Forslag er baseret på et neuralt netværk trænet til at forudsige de mest sandsynlige arter som mangler i et plantesamfund. Trænet på stort datasæt af plantesamfund registreret i Danmark (**4.3 millioner registreringer af 3199 arter/slægter/varianter i mere end 180.000 undersøgelser**).")
    gr.Markdown("App og model af Kenneth Thorø Martinsen (kenneth2810@gmail.com).")

    def add_top_prediction(selected_species):
        """Add the best-ranked suggestion that is not already selected.

        Args:
            selected_species: Species names currently chosen in the dropdown.

        Returns:
            A new list with the suggestion appended, or ``selected_species``
            unchanged when it is empty or no new suggestion could be found.
            The input list itself is never mutated.
        """
        if not selected_species:
            return selected_species

        # Requesting one hit more than there are selected species means at
        # least one returned prediction should not be selected yet, so the
        # button still adds something when the top hit is already chosen.
        predictions_html = predict_species(
            selected_species, n_hits=len(selected_species) + 1
        )[1]

        for entry in predictions_html.split("<br>"):
            # Each entry has the form
            #   <a href='...' target='_blank'>NAME</a> (SCORE%)
            # Cutting at the closing tag (instead of at the first " (") keeps
            # names that themselves contain parentheses intact.
            anchor, closing_tag, _ = entry.rpartition("</a>")
            if not closing_tag:
                continue
            name = anchor.partition(">")[2]  # text after the opening <a ...> tag
            if name and name not in selected_species:
                return [*selected_species, name]

        return selected_species

    # Recompute both HTML panels whenever the species selection changes.
    species_dropdown.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output],
    )

    # Also recompute when the requested number of hits changes; without this
    # listener the new value only took effect after the next selection change.
    n_hits.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output],
    )

    # The button writes an extended selection back into the dropdown.
    add_button.click(
        add_top_prediction,
        inputs=[species_dropdown],
        outputs=[species_dropdown],
    )

# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()