Spaces:
Running
Running
import itertools | |
from datetime import datetime | |
from pathlib import Path | |
import folium | |
import streamlit as st | |
from loguru import logger as log | |
from matplotlib import colors | |
from streamlit_folium import st_folium | |
from utils.folium import ( | |
draw_gdf, | |
get_clean_rendering_container, | |
get_map_crop, | |
get_parquet, | |
) | |
# Display colour (any matplotlib colour spec) for each crop class.
# Insertion order matters: the 1-based class ids below are derived from it.
COLOR_MAP = {
    "Grassland": "lime",
    "Cereals": "saddlebrown",
    "Other": "white",
    "Flowers": "cyan",
    "Vegetables": "purple",
    "Fruit trees": "darkgreen",
}
# 1-based class id -> class name.
id_class = dict(enumerate(COLOR_MAP, start=1))
# Class name -> "#rrggbb" hex colour.
color_map_hex = {label: colors.to_hex(spec) for label, spec in COLOR_MAP.items()}
# 1-based class id -> hex colour (same ordering as id_class).
id_map_hex = dict(enumerate(color_map_hex.values(), start=1))
# Hex colour -> class name. If two classes ever shared a colour, only the
# last one would survive this inversion.
color_map_hex_r = {hex_code: label for label, hex_code in color_map_hex.items()}
# Flattened [hex, name, hex, name, ...] sequence built from color_map_hex_r.
merged_list = list(itertools.chain.from_iterable(color_map_hex_r.items()))
def style_call(feat):
    """Build the style dictionary for one GeoJSON feature.

    Passed to ``draw_gdf`` as the style callback. Both the fill and the
    outline colour come from ``id_map_hex``, looked up with the feature's
    ``properties["class"]`` id (so that id must be one of its keys, 1..6).

    Raises:
        KeyError: if the feature has no ``class`` property or an unknown id.
    """
    class_colour = id_map_hex[feat["properties"]["class"]]
    return {
        "fillColor": class_colour,
        "fillOpacity": 0.7,
        "color": class_colour,
        # NOTE(review): "border-width"/"border-color" are not Leaflet Path
        # options and are probably ignored by the renderer - confirm.
        "border-width": "thin",
        "border-color": "#ffffff",
        "weight": 1.2,
    }
# Page configs
# NOTE(review): the page_icon character looks like a mis-encoded/truncated
# emoji from a copy-paste - confirm the intended icon.
st.set_page_config(page_title="AIDA", page_icon="๐", layout="wide")

# Base paths: every entry directly under data/inference is an example.
base_path = Path("data")
inference_path = base_path / "inference"

# Enumerate the examples in sorted path order and index them by file stem
# (a later entry with the same stem replaces an earlier one in the dict).
inference_paths = sorted(inference_path.iterdir())
example_names = [entry.stem for entry in inference_paths]
inference_dict = dict(zip(example_names, inference_paths))
def change_key():
    """Store a fresh timestamp string under ``st.session_state["key_map"]``.

    Registered as the map-height radio's ``on_change`` callback. The stored
    value is later used as the ``st_folium`` widget key, so changing it
    presumably makes Streamlit build a new map widget at the new height.
    """
    # str(datetime) is defined as isoformat(sep=" ").
    st.session_state["key_map"] = datetime.now().isoformat(sep=" ")
# Create selection menu: example picker, map-height toggle and methodology notes.
container_predictions = st.container(border=True)
with container_predictions:
    col1, col2, col3 = st.columns([0.2, 0.1, 0.7])
    with col1:
        prediction_selectbox = st.selectbox(
            "Select an example",
            options=example_names,
            index=None,
            key="selectbox_pred",
        )
        is_prediction_selected = prediction_selectbox is not None
        # Both stay None when nothing is selected or the selected file
        # cannot be read.
        chosen_tile_path = None
        prediction_file = None
        if is_prediction_selected:
            try:
                # Load the raw bytes of the selected parquet prediction.
                chosen_tile_path = inference_dict[prediction_selectbox]
                prediction_file = chosen_tile_path.read_bytes()
            except (KeyError, OSError) as err:
                # Only lookup / filesystem failures are expected here. Anything
                # else is a genuine bug and must surface rather than be hidden
                # (a bare except also swallowed BaseException subclasses).
                log.warning(f"Could not read example {prediction_selectbox!r}: {err}")
                st.warning("File not found")
                chosen_tile_path = None
                prediction_file = None
    with col2:
        height_value = st.radio(
            "Map height",
            options=[800, 500],
            horizontal=True,
            key="height_map",
            # Swapping the map's widget key forces it to re-render at the new height.
            on_change=change_key,
        )
    with col3:
        with st.container():
            # Empty level-6 heading, presumably used as a vertical spacer so
            # the expander lines up with the widgets in the other columns.
            st.write("######")
            with st.expander("See information about the methodology"):
                st.write(
                    "The model uses a modified UTAE architecture, a U-Net variant with attention,\n"
                    "for segmenting crop types from satellite imagery.\n"
                    "It was trained using Sentinel-2 time series data from several European countries,\n"
                    "labeled with aggregated Eurocrops data to identify 6 broad crop classes.\n"
                    "More information about the methodology [HERE](https://huggingface.co/links-ads/aida-cropland-models).\n"
                    "Learn more about how this model was used in practice [HERE](https://business.esa.int/projects/aida)"
                )
# Prepare the output slot for the map.
# HACK: get_clean_rendering_container is called once for the selected example
# and once for the map height, and only the second call's container is used.
# The first return value is discarded; the call is presumably kept for a side
# effect inside utils.folium that refreshes the map when either value changes
# (confirm before removing).
get_clean_rendering_container(prediction_selectbox)
container = get_clean_rendering_container(height_value)

# Draw the base map, then the selected prediction layer if it was loaded.
interactive_map = get_map_crop()
# Guard on chosen_tile_path rather than the selectbox value. It is reset to
# None when the selected file could not be read, and passing None to draw_gdf
# would presumably fail.
if chosen_tile_path is not None:
    draw_gdf(interactive_map, chosen_tile_path, "Prediction", style_call, id_class)

with container.form(key="form1", clear_on_submit=True):
    folium.LayerControl().add_to(interactive_map)
    output_map = st_folium(
        interactive_map,
        width=None,
        height=height_value,
        returned_objects=["all_drawings"],
        # Timestamp key written by change_key; a fixed key is used until the
        # height is first changed.
        key=st.session_state.get("key_map", "key_map"),
    )
    # Submitting the form triggers a rerun; presumably that is what recentres
    # the freshly rebuilt map.
    submit = st.form_submit_button("Recenter map")
# Preload every example prediction into session state, keyed by its Path,
# the first time this session sees it, so each parquet is parsed only once
# per session (presumably read back by draw_gdf - confirm).
for name, path in inference_dict.items():
    if path in st.session_state.keys():
        continue
    log.info(f"Loading parquet {name}")
    st.session_state[path] = get_parquet(path, id_class)