gtano's picture
Update app.py
c4e96c8 verified
import itertools
from datetime import datetime
from pathlib import Path
import folium
import streamlit as st
from loguru import logger as log
from matplotlib import colors
from streamlit_folium import st_folium
from utils.folium import (
draw_gdf,
get_clean_rendering_container,
get_map_crop,
get_parquet,
)
# Display colour for each crop class. Dict insertion order defines the numeric
# class ids below (first entry -> 1, second -> 2, ...).
COLOR_MAP = {
    "Grassland": "lime",
    "Cereals": "saddlebrown",
    "Other": "white",
    "Flowers": "cyan",
    "Vegetables": "purple",
    "Fruit trees": "darkgreen",
}

# class id (1-based) -> class name
id_class = dict(enumerate(COLOR_MAP, start=1))
# class name -> hex colour string (matplotlib resolves the named colours)
color_map_hex = {name: colors.to_hex(color) for name, color in COLOR_MAP.items()}
# class id (1-based) -> hex colour string; presumably matches the integer
# "class" property of the prediction features (see style_call) — confirm.
id_map_hex = dict(enumerate(color_map_hex.values(), start=1))
# hex colour string -> class name (inverse of color_map_hex)
color_map_hex_r = {hex_color: name for name, hex_color in color_map_hex.items()}
# Flattened [hex_0, name_0, hex_1, name_1, ...] built from color_map_hex_r.
merged_list = [item for pair in color_map_hex_r.items() for item in pair]
def style_call(feat):
    """Build the map style for a single GeoJSON feature.

    The feature's ``properties["class"]`` value is used as a key into the
    module-level ``id_map_hex`` table; the resulting hex colour is used for
    both the fill and the outline. A class id missing from the table raises
    ``KeyError``.

    Args:
        feat: GeoJSON-like feature mapping with a ``"properties"`` mapping
            that contains a ``"class"`` entry.

    Returns:
        dict: style options passed back to the map renderer.
    """
    class_color = id_map_hex[feat["properties"]["class"]]
    return {
        "fillColor": class_color,
        "fillOpacity": 0.7,
        "color": class_color,
        # NOTE(review): "border-width"/"border-color" look like CSS property
        # names rather than Leaflet path options, so the renderer may ignore
        # them (the outline colour comes from "color") — confirm.
        "border-width": "thin",
        "border-color": "#ffffff",
        "weight": 1.2,
    }
# Page configs: browser-tab title and icon, wide layout.
# The icon is written as an escape sequence (globe emoji U+1F30D) because the
# previous literal had been mangled into mis-encoded text ("๐ŸŒ"), which is
# not an emoji and therefore not a usable page icon.
st.set_page_config(page_title="AIDA", page_icon="\U0001F30D", layout="wide")
# base paths (relative to the working directory the app is launched from)
base_path = Path("data")
inference_path = base_path / "inference"

# list examples and predictions
# Every entry of data/inference is an example; its stem is the display name.
inference_paths = sorted(inference_path.iterdir())
inference_dict = {p.stem: p for p in inference_paths}
# Derive the option list from the dict so each option is unique and maps to
# exactly one path: two files sharing a stem (e.g. "a.parquet" and "a.json")
# previously produced duplicate options that both resolved to the later file.
example_names = list(inference_dict)
def change_key():
    """Replace ``st.session_state["key_map"]`` with a fresh timestamp string.

    Used as the ``on_change`` callback of the map-height radio. The map widget
    reads its ``key`` from this entry, so a new value makes Streamlit build a
    new map widget (presumably so the new height is applied).
    """
    fresh_key = datetime.now().isoformat(sep=" ")  # identical to str(datetime)
    st.session_state["key_map"] = fresh_key
# Create selection menu
container_predictions = st.container(border=True)
with container_predictions:
    col1, col2, col3 = st.columns([0.2, 0.1, 0.7])

    with col1:
        # Example picker; index=None means nothing is selected initially.
        prediction_selectbox = st.selectbox(
            "Select an example",
            options=example_names,
            index=None,
            key="selectbox_pred",
        )
        is_prediction_selected = prediction_selectbox is not None
        if is_prediction_selected:
            try:
                # add loading of the selected parquet prediction
                chosen_tile_path = inference_dict[prediction_selectbox]
                with open(chosen_tile_path, "rb") as f:
                    prediction_file = f.read()
            except (KeyError, OSError):
                # Only an unknown name or a failed read means "file not found".
                # The former bare ``except:`` also swallowed KeyboardInterrupt,
                # SystemExit and genuine bugs under the same message.
                st.warning("File not found")
                chosen_tile_path = None
                prediction_file = None
        else:
            prediction_file = None
            chosen_tile_path = None

    with col2:
        # Changing the height calls change_key(), which re-keys the map widget.
        height_value = st.radio(
            "Map height",
            options=[800, 500],
            horizontal=True,
            key="height_map",
            on_change=change_key,
        )

    with col3:
        with st.container():
            # An empty level-6 markdown heading, presumably used as a spacer.
            st.write("######")
            with st.expander("See information about the methodology"):
                st.write(
                    "The model uses a modified UTAE architecture, a U-Net variant with attention,\n"
                    "for segmenting crop types from satellite imagery.\n"
                    "It was trained using Sentinel-2 time series data from several European countries,\n"
                    "labeled with aggregated Eurocrops data to identify 6 broad crop classes.\n"
                    "More information about the methodology "
                    "[HERE](https://huggingface.co/links-ads/aida-cropland-models).\n"
                    "Learn more about how this model was used in practice "
                    "[HERE](https://business.esa.int/projects/aida)"
                )
# Container the map form is rendered into. NOTE(review): the first call's
# return value is overwritten immediately, so only its side effect (if any,
# inside utils.folium.get_clean_rendering_container) matters — confirm before
# removing it.
container = get_clean_rendering_container(prediction_selectbox)
# Strange hack to always update the map height, I guess
container = get_clean_rendering_container(height_value)

# draw map
interactive_map = get_map_crop()
# Guard on the resolved path rather than the selectbox value: when loading
# fails, the selection block resets chosen_tile_path to None while the
# selectbox still holds a name. Previously draw_gdf was then called with a
# None path, which it presumably cannot handle.
if chosen_tile_path is not None:
    # draw prediction
    draw_gdf(interactive_map, chosen_tile_path, "Prediction", style_call, id_class)

with container.form(key="form1", clear_on_submit=True):
    folium.LayerControl().add_to(interactive_map)
    output_map = st_folium(
        interactive_map,
        width=None,
        height=height_value,
        returned_objects=["all_drawings"],
        # Replaced by change_key() when the height radio changes, which forces
        # a fresh map widget.
        key=st.session_state.get("key_map", "key_map"),
    )
    # Submitting the form reruns the script, rebuilding the map from scratch.
    submit = st.form_submit_button("Recenter map")
# Cache every available prediction in session_state, keyed by its Path, so
# each parquet is loaded at most once per browser session.
# NOTE(review): this runs after the map is drawn; if draw_gdf relies on this
# cache, a selection made before the first full run would miss it — confirm.
for name, path in inference_dict.items():
    if path not in st.session_state:
        log.info(f"Loading parquet {name}")
        st.session_state[path] = get_parquet(path, id_class)