Spaces:
Running
Running
File size: 4,401 Bytes
c4e96c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import itertools
from datetime import datetime
from pathlib import Path
import folium
import streamlit as st
from loguru import logger as log
from matplotlib import colors
from streamlit_folium import st_folium
from utils.folium import (
draw_gdf,
get_clean_rendering_container,
get_map_crop,
get_parquet,
)
# Display color for every predicted crop class. Insertion order matters:
# integer class ids are assigned 1..N following this order (see id_class and
# id_map_hex below).
COLOR_MAP = {
    "Grassland": "lime",
    "Cereals": "saddlebrown",
    "Other": "white",
    "Flowers": "cyan",
    "Vegetables": "purple",
    "Fruit trees": "darkgreen",
}
# Integer class id (1-based) -> class name.
id_class = dict(enumerate(COLOR_MAP, start=1))
# Class name -> hex color string, normalized through matplotlib.
color_map_hex = {
    class_name: colors.to_hex(color_name)
    for class_name, color_name in COLOR_MAP.items()
}
# Integer class id (1-based) -> hex color; same ordering as id_class.
id_map_hex = dict(enumerate(color_map_hex.values(), start=1))
# Hex color -> class name (inverse of color_map_hex).
color_map_hex_r = {
    hex_color: class_name for class_name, hex_color in color_map_hex.items()
}
# Flat [hex, name, hex, name, ...] list built from color_map_hex_r's pairs.
merged_list = list(itertools.chain.from_iterable(color_map_hex_r.items()))
def style_call(feat):
    """Build the folium style dict for one prediction feature.

    The feature's integer ``properties["class"]`` value is looked up in the
    module-level ``id_map_hex`` table, and that hex color is used for both the
    fill and the outline.

    Args:
        feat: A GeoJSON-like feature mapping with a ``properties["class"]`` id.

    Returns:
        A dict of style options for the feature.

    Raises:
        KeyError: If the feature has no ``class`` property or the id is not
            present in ``id_map_hex``.
    """
    class_color = id_map_hex[feat["properties"]["class"]]
    # NOTE(review): "border-width" / "border-color" are not standard Leaflet
    # path options and are probably ignored by the renderer; confirm before
    # relying on them.
    return {
        "fillColor": class_color,
        "fillOpacity": 0.7,
        "color": class_color,
        "border-width": "thin",
        "border-color": "#ffffff",
        "weight": 1.2,
    }
# Page configs
st.set_page_config(page_title="AIDA", page_icon="🌐", layout="wide")
# Base paths: predictions live under data/inference, relative to the working
# directory the app is started from.
base_path = Path("data")
inference_path = base_path.joinpath("inference")
# Every entry of the inference folder, in sorted order, is offered as an
# example; its file stem is the label shown in the UI.
# NOTE(review): entries are not filtered (hidden files or sub-directories would
# also be listed), and files sharing a stem collide in inference_dict, where
# the last one wins.
inference_paths = sorted(inference_path.iterdir())
inference_dict = {entry.stem: entry for entry in inference_paths}
example_names = [entry.stem for entry in inference_paths]
def change_key():
    """Store a new timestamp-derived value under ``st.session_state["key_map"]``.

    Registered as the ``on_change`` callback of the map-height radio. The
    stored value is later passed as the ``key`` of ``st_folium``, presumably so
    that Streamlit treats the map as a new widget and redraws it at the new
    height.
    """
    fresh_key = str(datetime.now())
    st.session_state["key_map"] = fresh_key
# Create selection menu: example picker, map-height toggle and methodology notes.
container_predictions = st.container(border=True)
with container_predictions:
    col1, col2, col3 = st.columns([0.2, 0.1, 0.7])
    with col1:
        # index=None makes the selectbox start empty, so it returns None until
        # the user picks an example.
        prediction_selectbox = st.selectbox(
            "Select an example",
            options=example_names,
            index=None,
            key="selectbox_pred",
        )
        is_prediction_selected = prediction_selectbox is not None
        # Defaults used both when nothing is selected and when the selected
        # file cannot be read.
        chosen_tile_path = None
        prediction_file = None
        if is_prediction_selected:
            try:
                # Load the raw bytes of the selected parquet prediction.
                # NOTE(review): prediction_file is not referenced elsewhere in
                # this script; the read effectively acts as a readability check.
                selected_path = inference_dict[prediction_selectbox]
                with open(selected_path, "rb") as f:
                    prediction_file = f.read()
            except (KeyError, OSError) as err:
                # Catch only lookup / filesystem failures; a bare `except:`
                # would also swallow KeyboardInterrupt, SystemExit and bugs.
                log.warning(
                    f"Could not load prediction {prediction_selectbox!r}: {err}"
                )
                st.warning("File not found")
            else:
                chosen_tile_path = selected_path
    with col2:
        # Changing the height triggers change_key, which stores a new map key.
        height_value = st.radio(
            "Map height",
            options=[800, 500],
            horizontal=True,
            key="height_map",
            on_change=change_key,
        )
    with col3:
        with st.container():
            # Empty level-6 Markdown heading, presumably used as a vertical
            # spacer to line the expander up with the neighbouring widgets.
            st.write("######")
            with st.expander("See information about the methodology"):
                st.write(
                    """The model uses a modified UTAE architecture, a U-Net variant with attention,
for segmenting crop types from satellite imagery.
It was trained using Sentinel-2 time series data from several European countries,
labeled with aggregated Eurocrops data to identify 6 broad crop classes.
More information about the methodology [HERE](https://huggingface.co/links-ads/aida-cropland-models).
Learn more about how this model was used in practice [HERE](https://business.esa.int/projects/aida)"""
                )
# NOTE(review): the return value of this first call is overwritten on the next
# line; the call is kept in case get_clean_rendering_container has side effects
# tied to the selection. Confirm in utils.folium before removing it.
container = get_clean_rendering_container(prediction_selectbox)
# Strange hack, presumably to always update the map height: request the
# rendering container again, keyed on the selected height.
container = get_clean_rendering_container(height_value)
# draw map
interactive_map = get_map_crop()
# Draw only when a prediction path was actually resolved. chosen_tile_path is
# None when nothing is selected AND when reading the selected file failed;
# checking prediction_selectbox alone would pass None to draw_gdf in the
# latter case.
if chosen_tile_path is not None:
    # draw prediction
    draw_gdf(interactive_map, chosen_tile_path, "Prediction", style_call, id_class)
with container.form(key="form1", clear_on_submit=True):
    folium.LayerControl().add_to(interactive_map)
    # The widget key comes from session state (set by change_key when the
    # height changes), falling back to a constant key on the first run.
    output_map = st_folium(
        interactive_map,
        width=None,
        height=height_value,
        returned_objects=["all_drawings"],
        key=st.session_state.get("key_map", "key_map"),
    )
    # Submitting the form triggers a script rerun, which rebuilds the map.
    submit = st.form_submit_button("Recenter map")
# Preload every available prediction into st.session_state, keyed by its Path.
# Paths already present are skipped, so each file is loaded at most once per
# session.
# NOTE(review): Streamlit may normalize session-state keys to str; confirm that
# the membership test with a Path key matches what the assignment stores.
for name, path in inference_dict.items():
    if path in st.session_state:
        continue
    log.info(f"Loading parquet {name}")
    st.session_state[path] = get_parquet(path, id_class)
|