Spaces:
Running
Running
from shiny import App, ui, render, reactive | |
from shiny.ui import HTML, tags | |
import shinyswatch | |
import torch | |
import pandas as pd | |
import numpy as np | |
import io | |
import torch.nn.functional as F | |
from utils import load_training_data, load_models | |
# Zero-based month number -> abbreviated month name (0 is January).
_MONTH_ABBREVIATIONS = (
    "Jan",
    "Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov",
    "Dec",
)
MONTHS = dict(enumerate(_MONTH_ABBREVIATIONS))
# Years offered in the sidebar: 2000 through 2014 inclusive.
YEARS = [2000 + offset for offset in range(15)]

# Resolution key -> label shown in the sidebar. "0" returns the raw local
# covariates; every other key h is labelled as 32 * h km.
RESOLUTIONS = {"0": "Local"}
RESOLUTIONS.update({str(h): f"{32 * h} km" for h in (1, 3, 5, 7, 9)})
# Source file name -> human-readable name of each weather variable in use.
# Variables that exist but are currently disabled:
#   "air.sfc.mon.mean.nc": "surface temperature"
#   "acpcp.mon.mean.nc": "acc. convective precip"
#   "tcdc.mon.mean.nc": "total cloud cover"
#   "dswrf.mon.mean.nc": "down short rads flux"
#   "hpbl.mon.mean.nc": "planet boundary layer height"
WCOLS = dict(
    [
        ("air.2m.mon.mean.nc", "temperature at 2m"),
        ("apcp.mon.mean.nc", "total precipitation"),
        ("rhum.2m.mon.mean.nc", "relative humidity"),
        ("vwnd.10m.mon.mean.nc", "(north-south) wind component"),
        ("uwnd.10m.mon.mean.nc", "(east-west) wind component"),
    ]
)
# Raster grid constants: grid size, bounding box in degrees, and the
# per-cell step along each axis.
NROW, NCOL = 128, 256
XMIN, XMAX = -135.0, -60.0
YMIN, YMAX = 20.0, 52.0
# Longitude step per column (positive).
DLON = (XMAX - XMIN) / NCOL
# Latitude step per row; negative because it is measured from YMAX downwards.
DLAT = (YMIN - YMAX) / NROW
# Load the training data and the models once at import time (non-reactively).
_PROCESSED_DATA_OPTIONS = {
    "path": "data/training_data.pkl",
    "standardize_so4": True,
    "log_so4": True,
    "year_averages": True,
}
C, NAMES, Y, M = load_training_data(**_PROCESSED_DATA_OPTIONS)
# Size of the second axis of C; passed to load_models as `nd`.
ND = C.shape[1]

# Second load with default options; only the last two outputs are kept.
_, _, YRAW, MRAW = load_training_data(path="data/training_data.pkl")

# Resolution key -> directory holding the weights for that neighborhood size.
DIRS = {str(h): f"./data/weights/h{h}_w2vec" for h in (1, 3, 5, 7, 9)}
MODELS = load_models(DIRS, prefix="h", nd=ND)
# <head> fragment defining the `.multicol` CSS class, which lays the sidebar
# checkbox/radio groups out in three columns.
multicol_html = tags.head(
    tags.style(
        HTML(
            ".multicol {"
            "-webkit-column-count: 3;"  # chrome, safari, opera
            "-moz-column-count: 3;"  # firefox
            "column-count: 3;"
            "-moz-column-fill: auto;"
            # The standard property is `column-fill`; `-column-fill` is not
            # a valid CSS property and was ignored by browsers.
            "column-fill: auto;"
            # Close the rule so the stylesheet is well-formed.
            "}"
        )
    )
)
# Markdown shown at the top of the main panel. The valid coordinate ranges
# are interpolated from the grid constants. Blank lines separate the Markdown
# blocks so the paragraph after the bullet list is not folded into the list.
instructions = f"""
### Instructions

Upload a CSV file with columns (id, lat, lon) using the `Browse` button on the sidebar.
Below is an example of the contents of the file:

```
id,lat,lon
0,47.5,-122.5
1,47.5,-122.25
2,47.5,-122.0
3,47.5,-121.75
4,47.5,-121.5
```

The id column can be any identifier, or the column can be omitted, in which case the row number will be used as the id.
Make sure that the latitude is before the longitude column in the CSV file. The valid range for latitude is
{YMIN} to {YMAX} and longitude is {XMIN} to {XMAX}, which cover the contiguous United States.

The resolution corresponds to how much neighboring information is captured by the embedding. If `Local` is selected,
the original weather covariates will be returned. Currently, all the embeddings correspond to the variables:

* `air.2m.mon.mean.nc`: temperature at 2m
* `apcp.mon.mean.nc`: total precipitation
* `rhum.2m.mon.mean.nc`: relative humidity
* `vwnd.10m.mon.mean.nc`: (north-south) wind component
* `uwnd.10m.mon.mean.nc`: (east-west) wind component

The radius corresponds to the number of neighboring raster cells to include in the weather2vec representation. A resolution of 96 km means that the embedding encodes information from all nearby raster cells whose centers are less than 96 km away. All embeddings have 10 hidden dimensions.

The embeddings also record information of the 12-month moving average. For this reason, the 'Local' embeddings also have dimension 10: the first 5 dimensions correspond to the 5 meteorological variables in a given month, and the last 5 dimensions correspond to their 12-month moving average. For the non-local embeddings, the order of the variables is not interpretable.

### Download
"""
# Markdown shown at the bottom of the main panel: the reference for the
# Weather2vec paper followed by its BibTeX entry.
citation = "\n".join(
    (
        "",
        "### Citation",
        'Tec, M., Scott, J.G. and Zigler, C.M., 2023. "Weather2vec: Representation learning for causal inference with non-local confounding in air pollution and climate studies". In: *Proceedings of the AAAI Conference on Artificial Intelligence*.',
        "```",
        "@inproceedings{tec2023weather2vec,",
        "title={Weather2vec: Representation learning for causal inference with non-local confounding in air pollution and climate studies},",
        "author={Tec, Mauricio and Scott, James G and Zigler, Corwin M},",
        "booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},",
        "volume={37},",
        "number={12},",
        "pages={14504--14513},",
        "year={2023}",
        "}",
        "```",
        "",
    )
)
# After uploading the file, the app will generate a CSV, a download link will appear here. | |
# The CSV will contain the following columns: | |
# Part 1: ui ---- | |
# Page layout: a sidebar with the file upload and the month / year /
# resolution selectors, and a main panel with the instructions, the
# server-rendered download area ("download_ui") and the citation.
app_ui = ui.page_fluid(
    shinyswatch.theme.minty(),
    multicol_html,
    ui.panel_title("Welcome to the Weather2vec Embedding Generator!"),
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file("df", "Upload CSV File", accept=".csv"),
            tags.div(
                ui.input_checkbox_group(
                    "months",
                    HTML("<b>Months</b>"),
                    MONTHS,
                    selected=list(MONTHS.keys()),
                ),
                class_="multicol",
                align="left",
                inline=False,
            ),
            HTML(
                "<b>Note:</b> Embedding of multiple months will be added.<br>True multi-temporal embeddings will be supported in the future.<br><br>"
            ),
            tags.div(
                ui.input_radio_buttons("year", HTML("<b>Year</b>"), YEARS),
                class_="multicol",
                align="left",
                inline=False,
            ),
            HTML("<br>"),
            tags.div(
                ui.input_radio_buttons(
                    "resolution", HTML("<b>Resolution</b>"), RESOLUTIONS, selected="9"
                ),
                class_="multicol",
                align="left",
                inline=False,
            ),
            HTML("<br>"),
            ui.download_link("download_test", "Download an example input file here."),
            # A colon and a space follow "Note" so the text does not render
            # as "NoteThere are ...".
            HTML(
                "<br><b>Note:</b> There are some issues with scrolling using Safari, try a different browser please."
            ),
            width=4,
        ),
        ui.panel_main(
            ui.markdown(instructions),
            ui.output_ui("download_ui"),
            ui.markdown(citation),
        ),
    ),
)
# Part 2: server ---- | |
def _grid_indices(lat, lon, interp_factor):
    """Map latitude/longitude arrays to integer cells of the refined grid.

    Args:
        lat: 1-D float array of latitudes.
        lon: 1-D float array of longitudes.
        interp_factor: Number of refined cells per original raster cell
            along each axis.

    Returns:
        A ``(row, col)`` pair of ``int64`` arrays. Row 0 is at ``YMAX`` and
        column 0 is at ``XMIN``.

    Raises:
        ValueError: If any coordinate is NaN or lies outside the bounding
            box ``[YMIN, YMAX] x [XMIN, XMAX]``.
    """
    # Written as a conjunction of "inside" tests so NaN also fails the check.
    inside = (lat >= YMIN) & (lat <= YMAX) & (lon >= XMIN) & (lon <= XMAX)
    if not np.all(inside):
        raise ValueError(
            f"All coordinates must satisfy {YMIN} <= lat <= {YMAX} "
            f"and {XMIN} <= lon <= {XMAX}."
        )
    # Floor division returns floats; cast so the result can index a tensor.
    col = ((lon - XMIN) // (DLON / interp_factor)).astype(np.int64)
    row = ((lat - YMAX) // (DLAT / interp_factor)).astype(np.int64)
    return row, col


def server(input, output, session):
    """Register the reactive outputs and download handlers of the app.

    Args:
        input: Shiny inputs; reads ``df``, ``months``, ``year`` and
            ``resolution``.
        output: Shiny output registry for ``download_ui`` and
            ``embs_preview``.
        session: Shiny session used to register the two download handlers.
    """

    @output
    @render.ui
    def download_ui():
        """Show a hint until a file is uploaded, then the button and preview."""
        if input.df() is None:
            return HTML("<font color=red>Upload a CSV file first. A download button will appear here.</font>")
        return ui.div(
            ui.download_button("download", "Download Embeddings"),
            ui.output_data_frame("embs_preview"),
        )

    @output
    @render.data_frame
    def embs_preview():
        """Return the first rows of the embeddings, with the id as a column."""
        embs = df_embs()
        if embs is None:
            return None
        return embs.reset_index().head()

    @reactive.Calc
    def df_embs():
        """Compute the embeddings at the uploaded locations.

        Returns:
            ``None`` when no file has been uploaded; otherwise a DataFrame
            indexed like the input, holding any extra input columns followed
            by the embedding columns ``Z00``, ``Z01``, ...

        Raises:
            ValueError: If the CSV has fewer than two columns, no month is
                selected, the selected period is not covered by the data, or
                a coordinate lies outside the supported bounding box.
        """
        uploads = input.df()
        if uploads is None:
            return None

        # Read the most recently uploaded file.
        df = pd.read_csv(uploads[-1]["datapath"])
        if df.shape[1] < 2:
            raise ValueError("The CSV file must contain at least the columns (lat, lon).")
        if df.shape[1] > 2:
            # The first column is the identifier; lat/lon are the last two.
            df = df.set_index(df.columns[0])

        months = np.array(input.months(), dtype=int)
        year = int(input.year())
        if len(months) == 0:
            raise ValueError("Must select at least one month.")

        # MONTHS keys are already zero-based, so no extra "- 1" is applied;
        # subtracting one mapped January 2000 to index -1 (the last step).
        # NOTE(review): assumes the first axis of C is monthly and starts in
        # January 2000 -- confirm against load_training_data.
        idxs = (year - 2000) * 12 + months
        if idxs.min() < 0 or idxs.max() >= C.shape[0]:
            raise ValueError("The selected year and months are outside the available data.")
        Ct = torch.FloatTensor(C)[idxs]

        # Validate the coordinates before the expensive model evaluation.
        lat = df.iloc[:, -2].to_numpy(dtype=float)
        lon = df.iloc[:, -1].to_numpy(dtype=float)
        interp_factor = 32
        row, col = _grid_indices(lat, lon, interp_factor)

        # Average over the selected months; "0" returns the raw covariates.
        resolution = input.resolution()
        if resolution == "0":
            Z = Ct.mean(0)
        else:
            mod = MODELS[DIRS[resolution]]["mod"]
            with torch.no_grad():
                Z = mod["enc"](Ct).mean(0)

        # Bilinear interpolation to refine the grid resolution.
        Z = F.interpolate(
            Z[None],
            scale_factor=interp_factor,
            mode="bilinear",
            align_corners=False,
        )[0]

        # Points exactly on the lower/right edge of the box fall one cell
        # past the grid, so clamp them to the last valid row/column.
        row_idx = torch.as_tensor(row).clamp(0, Z.shape[-2] - 1)
        col_idx = torch.as_tensor(col).clamp(0, Z.shape[-1] - 1)
        # (channels, points) -> (points, channels)
        Z = Z[:, row_idx, col_idx].numpy().T

        embs = pd.DataFrame(
            Z,
            columns=[f"Z{i:02d}" for i in range(Z.shape[1])],
            index=df.index,
        )
        if df.shape[1] > 2:
            # Keep any extra input columns located before lat/lon.
            embs = pd.concat([df.iloc[:, :-2], embs], axis=1)
        return embs

    # NOTE(review): the download file names are a choice made here; adjust
    # them if specific names are expected.
    @session.download(filename="embeddings.csv")
    def download():
        """Stream the embeddings as CSV, including the id column."""
        if input.df() is None:
            raise ValueError("Upload a CSV file first.")
        with io.BytesIO() as f:
            # reset_index() turns the id index into a column, matching the
            # preview; index=False alone dropped the ids from the file.
            df_embs().reset_index().to_csv(f, index=False)
            yield f.getvalue()

    @session.download(filename="test-data.csv")
    def download_test():
        """Stream the bundled example input file as CSV."""
        with io.BytesIO() as f:
            df = pd.read_csv("data/test-data.csv")
            df.to_csv(f, index=False)
            yield f.getvalue()
# Bind the UI definition and the server function into the Shiny application.
# The variable has to be called "app".
app = App(ui=app_ui, server=server)