gtano commited on
Commit
db9dca3
·
1 Parent(s): 3ed2067

First commit

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ __pycache__
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ import folium
6
+ import streamlit as st
7
+ from loguru import logger as log
8
+ from matplotlib import colors
9
+ from streamlit_folium import st_folium
10
+ from utils.folium import (
11
+ draw_gdf,
12
+ get_clean_rendering_container,
13
+ get_map_crop,
14
+ get_parquet,
15
+ )
16
+
17
+ COLOR_MAP = {
18
+ "Grassland": "lime",
19
+ "Cereals": "saddlebrown",
20
+ "Other": "white",
21
+ "Flowers": "cyan",
22
+ "Vegetables": "purple",
23
+ "Fruit trees": "darkgreen",
24
+ }
25
+
26
+
27
+ id_class = {(i + 1): k for i, k in enumerate(COLOR_MAP.keys())}
28
+ color_map_hex = {k: colors.to_hex(v) for k, v in COLOR_MAP.items()}
29
+ id_map_hex = {(i + 1): v for i, v in enumerate(color_map_hex.values())}
30
+ color_map_hex_r = {v: k for k, v in color_map_hex.items()}
31
+ merged_list = list(itertools.chain(*list(color_map_hex_r.items())))
32
+
33
+
34
+ def style_call(feat):
35
+ di = {
36
+ "fillColor": id_map_hex[feat["properties"]["class"]],
37
+ "fillOpacity": 0.7,
38
+ "color": id_map_hex[feat["properties"]["class"]],
39
+ "border-width": "thin",
40
+ "border-color": "#ffffff",
41
+ "weight": 1.2,
42
+ }
43
+ return di
44
+
45
+
46
+ # Page configs
47
+ st.set_page_config(page_title="AIDA", page_icon="🌐", layout="wide")
48
+
49
+ hide_button_style = """
50
+ <style>
51
+ div.stButton > button:first-child {
52
+ display: none;
53
+ }
54
+ </style>
55
+ """
56
+ st.markdown(hide_button_style, unsafe_allow_html=True)
57
+
58
+ # base paths
59
+ base_path = Path("data")
60
+ inference_path = base_path / "inference"
61
+
62
+ # list examples and predictions
63
+ inference_paths = sorted(list(inference_path.iterdir()))
64
+
65
+ example_names = [p.stem for p in inference_paths]
66
+ inference_dict = {p.stem: p for p in inference_paths}
67
+
68
+ # Create selection menu
69
+ container_predictions = st.sidebar.container(border=True)
70
+ with container_predictions:
71
+ prediction_selectbox = st.selectbox(
72
+ "Select an example",
73
+ options=example_names,
74
+ index=None,
75
+ key="selectbox_pred",
76
+ )
77
+ is_prediction_selected = prediction_selectbox is not None
78
+ if is_prediction_selected:
79
+ try:
80
+ # add loading of the selected parquet prediction
81
+ chosen_tile_path = inference_dict[prediction_selectbox]
82
+ with open(chosen_tile_path, "rb") as f:
83
+ prediction_file = f.read()
84
+ except:
85
+ st.warning("File not found")
86
+ chosen_tile_path = None
87
+ prediction_file = None
88
+ else:
89
+ prediction_file = None
90
+ chosen_tile_path = None
91
+
92
+ container = get_clean_rendering_container(prediction_selectbox)
93
+
94
+
95
+ # Stange Hack to always update the map height I guess
96
+ def change_key():
97
+ st.session_state["key_map"] = str(datetime.now())
98
+
99
+
100
+ height_value = st.sidebar.radio(
101
+ "Map height",
102
+ options=[800, 500],
103
+ horizontal=True,
104
+ key="height_map",
105
+ on_change=change_key,
106
+ )
107
+ container = get_clean_rendering_container(height_value)
108
+
109
+ # draw map
110
+ interactive_map = get_map_crop()
111
+
112
+ if prediction_selectbox is not None:
113
+ # draw prediction
114
+ draw_gdf(interactive_map, chosen_tile_path, "Prediction", style_call, id_class)
115
+
116
+
117
+ with container.form(key="form1", clear_on_submit=True):
118
+ folium.LayerControl().add_to(interactive_map)
119
+ output_map = st_folium(
120
+ interactive_map,
121
+ width=None,
122
+ height=height_value,
123
+ returned_objects=["all_drawings"],
124
+ key=st.session_state.get("key_map", "key_map"),
125
+ )
126
+ submit = st.form_submit_button("Recenter map")
127
+
128
+ # Update messages
129
+
130
+
131
+ for name, path in inference_dict.items():
132
+ if path not in st.session_state.keys():
133
+ log.info(f"Loading parquet {name}")
134
+ st.session_state[path] = get_parquet(path, id_class)
data/inference/France_2022.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce8babda3c49b070f2770d9d08257f5d7f3499fe959e3a31cb17ebb6d996604e
3
+ size 4754525
data/inference/Germany_2021.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6328e97907549ddca90a3b11430fabfed5dbfb16c2febdb96d22f8c367c59ad8
3
+ size 2679827
data/inference/Netherlands_2023.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92fba1cc19cd9b73adabc4a4c15be15a8fa266d240083cfbc07c6aafaa518344
3
+ size 5718056
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.33.0
2
+ pathlib==1.0.1
3
+ geopandas==0.14.2
4
+ matplotlib==3.8.2
5
+ folium==0.15.1
6
+ streamlit-folium==0.18.0
7
+ loguru==0.7.2
8
+ shapely==2.0.3
utils/folium.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import branca
2
+ import folium
3
+ import geopandas as gpd
4
+ import streamlit as st
5
+ from folium.plugins import Draw
6
+
7
+
8
+ def get_map_crop():
9
+ m = folium.Map(
10
+ location=[50, 32],
11
+ zoom_start=4,
12
+ tiles="Esri.WorldImagery",
13
+ attributionControl=False,
14
+ prefer_canvas=True,
15
+ )
16
+ Draw(
17
+ export=False,
18
+ show_geometry_on_click=False,
19
+ draw_options={
20
+ "polyline": False,
21
+ "polygon": False,
22
+ "circle": False,
23
+ "circlemarker": False,
24
+ "marker": False,
25
+ },
26
+ ).add_to(m)
27
+ legend_html = """
28
+ {% macro html(this, kwargs) %}
29
+ <div style="
30
+ position: fixed;
31
+ bottom: 35px;
32
+ left: 30px;
33
+ width: 130px;
34
+ line-height: 3px;
35
+ height: 120px;
36
+ z-index:9999;
37
+ font-size:14px;
38
+ ">
39
+ <p><a style="color:#00ff00;font-size:150%;margin-left:20px;">&FilledSmallSquare;</a>&emsp;Grassland</p>
40
+ <p><a style="color:#8b4513;font-size:150%;margin-left:20px;">&FilledSmallSquare;</a>&emsp;Cereals</p>
41
+ <p><a style="color:#ffffff;font-size:150%;margin-left:20px;">&FilledSmallSquare;</a>&emsp;Other</p>
42
+ <p><a style="color:#00ffff;font-size:150%;margin-left:20px;">&FilledSmallSquare;</a>&emsp;Flowers</p>
43
+ <p><a style="color:#800080;font-size:150%;margin-left:20px;">&FilledSmallSquare;</a>&emsp;Vegetables</p>
44
+ <p><a style="color:#006400;font-size:150%;margin-left:20px;">&FilledSmallSquare;</a>&emsp;Fruit trees</p>
45
+ </div>
46
+ <div style="
47
+ position: fixed;
48
+ bottom: 50px;
49
+ left: 30px;
50
+ width: 130px;
51
+ line-height: 3px;
52
+ height: 120px;
53
+ z-index:9998;
54
+ font-size:14px;
55
+ background-color: #ffffff;
56
+
57
+ opacity: 0.85;
58
+ ">
59
+ </div>
60
+ {% endmacro %}
61
+ """
62
+ legend = branca.element.MacroElement()
63
+ legend._template = branca.element.Template(legend_html)
64
+ m.get_root().add_child(legend)
65
+ return m
66
+
67
+
68
+ def get_parquet(tile_path, id_classes):
69
+ data = gpd.read_parquet(tile_path).to_crs(epsg="4326")
70
+ data["class_name"] = [id_classes[i] for i in data["class"]]
71
+ data["geometry"] = data["geometry"].simplify(0.0001)
72
+ return data
73
+
74
+
75
+ def draw_gdf(
76
+ _map,
77
+ tile_path,
78
+ name,
79
+ _style_call,
80
+ id_classes,
81
+ ):
82
+ if tile_path not in st.session_state.keys():
83
+ tile_gdf = get_parquet(tile_path, id_classes)
84
+ else:
85
+ tile_gdf = st.session_state[tile_path]
86
+ feature_group = folium.FeatureGroup(f"{name} layer")
87
+ tooltip = folium.GeoJsonTooltip(
88
+ fields=["class_name", "area"],
89
+ aliases=["Crop type: \t", "Area (m²): \t"],
90
+ localize=True,
91
+ sticky=False,
92
+ labels=True,
93
+ style="""
94
+ background-color: #F0EFEF;
95
+ border: 2px solid black;
96
+ border-radius: 3px;
97
+ box-shadow: 3px;
98
+ """,
99
+ max_width=800,
100
+ )
101
+
102
+ folium.GeoJson(tile_gdf, style_function=_style_call, tooltip=tooltip).add_to(
103
+ feature_group
104
+ )
105
+ feature_group.add_to(_map)
106
+ bound = feature_group.get_bounds()
107
+ _map.fit_bounds(bound)
108
+
109
+
110
+ def get_clean_rendering_container(app_state: str):
111
+ """Makes sure we can render from a clean slate on state changes."""
112
+ slot_in_use = st.session_state.slot_in_use = st.session_state.get(
113
+ "slot_in_use", "a"
114
+ )
115
+ if app_state != st.session_state.get("previous_state", app_state):
116
+ if slot_in_use == "a":
117
+ slot_in_use = st.session_state.slot_in_use = "b"
118
+ else:
119
+ slot_in_use = st.session_state.slot_in_use = "a"
120
+
121
+ st.session_state.previous_state = app_state
122
+
123
+ slot = {
124
+ "a": st.empty(),
125
+ "b": st.empty(),
126
+ }[slot_in_use]
127
+
128
+ return slot.container()