|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from sklearn.cluster import KMeans |
|
import joblib |
|
|
|
@st.cache_data
def load_dataset(path: str) -> pd.DataFrame:
    """
    Opens up a `.csv` file as our main dataset.

    The `Date` column is parsed as datetimes and replaced by three integer
    columns, `Year`, `Month` and `Day`. The result is cached with
    `st.cache_data`, the Streamlit cache intended for dataframes: every
    caller receives its own copy, so mutating the returned frame cannot
    leak into other reruns or sessions.

    Args:
        path (str): Path to the dataset to be opened

    Returns:
        pd.DataFrame: The dataset without `Date` and with `Year`, `Month`
            and `Day` appended as the last columns.
    """
    df = pd.read_csv(path)

    # Split the timestamp into plain numeric parts.
    dates = pd.to_datetime(df['Date'])
    df['Year'] = dates.dt.year
    df['Month'] = dates.dt.month
    df['Day'] = dates.dt.day

    # The raw date is no longer needed once it has been decomposed.
    return df.drop(columns='Date')
|
|
|
@st.cache_resource
def load_model(model_path: str) -> KMeans:
    """
    Loads a serialized model from disk with joblib.

    The loaded object is cached by Streamlit as a shared resource, so the
    file is only deserialized once per distinct `model_path`.

    Args:
        model_path (str): Path to the joblib file to load.

    Returns:
        KMeans: Whatever object joblib deserializes from `model_path`
            (annotated as a KMeans model).
    """
    kmeans_model = joblib.load(model_path)
    return kmeans_model
|
|
|
def process_data(input_X: pd.DataFrame) -> np.ndarray:
    """
    Processes input data into usable form for the KMeans model to predict.

    The caller's dataframe is never modified; all work happens on a copy.

    Args:
        input_X (pd.DataFrame): Input data with `Source` and `Mean` columns
            plus either a `Date` column or already-split `Year`, `Month`
            and `Day` columns. May hold a single instance or many rows.

    Returns:
        np.ndarray: A 2-D array with one row per instance and the columns
            `Mean`, `Year`, `Month`, `Day`, `Source_GCAG`, `Source_GISTEMP`,
            in that order.
    """
    feature_columns = [
        'Mean', 'Year', 'Month', 'Day', 'Source_GCAG', 'Source_GISTEMP',
    ]

    frame = input_X.copy()

    # Frames that still carry a raw date get it split into numeric parts;
    # frames that were already split are left as they are.
    if 'Date' in frame.columns:
        dates = pd.to_datetime(frame['Date'])
        frame['Year'] = dates.dt.year
        frame['Month'] = dates.dt.month
        frame['Day'] = dates.dt.day
        frame = frame.drop(columns='Date')

    # Encode ONLY `Source`. Letting get_dummies pick columns by dtype would
    # also pick up any other text column and fail on the prefix length.
    frame = pd.get_dummies(frame, columns=['Source'], prefix='Source', dtype=int)

    # A single instance only produces the indicator for its own source, so
    # any missing indicator is added as an all-zero column.
    for col in ('Source_GCAG', 'Source_GISTEMP'):
        if col not in frame.columns:
            frame[col] = 0

    # Fix the column order and drop anything that is not a model feature.
    return frame[feature_columns].to_numpy()
|
|
|
def plot_clusters(model: KMeans, X: np.ndarray, input_X: np.ndarray) -> None:
    """
    Plots the predicted class to the clusters.

    Draws the first two feature columns of `X` coloured by the model's
    stored labels, overlays the cluster centroids and the user's input
    point, writes the predicted cluster to the page and renders the figure
    with Streamlit.

    Args:
        model (KMeans): A KMeans model trained on X input
        X (np.ndarray): The numpy array version of the dataset. It is
            indexed with `model.labels_`, so it is assumed to hold the same
            rows, in the same order, the model was fitted on (TODO confirm).
        input_X (np.ndarray): The numpy array of the input; only the first
            row's prediction is reported.

    Returns:
        None
    """
    centroids = model.cluster_centers_
    labels = model.labels_

    fig = plt.figure(figsize=(10, 6))
    ax = fig.subplots()

    # One series per fitted cluster. The count comes from the model itself
    # so models with a different number of clusters plot correctly.
    for cluster in range(len(centroids)):
        cluster_points = X[labels == cluster]
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f'Cluster {cluster}')

    ax.scatter(centroids[:, 0], centroids[:, 1], s=200, c='black', marker='X', label='Centroids')

    predictions = model.predict(input_X)

    st.write(f"Predicted Cluster: {predictions[0]}")

    # Highlight the user's point on top of the existing clusters.
    ax.scatter(input_X[:, 0], input_X[:, 1], s=300, c='red', marker='P', label=f'Predicted Cluster: {predictions[0]}')

    ax.set_title('K-Means Clustering with Predicted Point')
    ax.legend()
    st.pyplot(fig)

    # pyplot keeps every figure alive until it is closed; without this each
    # rerun of the app would leak one figure.
    plt.close(fig)
|
|
|
def main():
    """
    Builds the Streamlit page.

    The first tab collects a source, a mean temperature and a date, and on
    button press runs the input through `process_data`, loads the KMeans
    model and plots the prediction against the dataset's clusters. The
    second tab shows a description of the dataset and its citations.
    """
    # Local import: only this function needs it.
    import datetime

    st.title("Global Temperature Time Series")
    df = load_dataset("./monthly_csv.csv")

    tab1, tab2 = st.tabs(["KMeans Prediction", "About the Dataset"])

    with tab1:
        st.header("Input Data")

        source = st.selectbox("Choose your source platform.", tuple(df['Source'].unique()))
        # Plain Python floats keep the slider bounds independent of numpy
        # scalar types.
        mean_temp = st.slider(
            "Choose your avg. temp",
            float(df['Mean'].min()),
            float(df['Mean'].max()),
        )
        # A date object is accepted by every Streamlit version, whereas
        # string bounds are only understood by recent releases.
        date = st.date_input(
            "Choose the date of the observation",
            min_value=datetime.date(1980, 1, 1),
            max_value=None,
        )

        # Streamlit reruns the script on every interaction; the prediction
        # is only computed on the run triggered by the button.
        if st.button("Predict Input!"):
            d = pd.DataFrame({"Source": [source],
                              "Mean": [mean_temp],
                              "Date": [date]
                              })
            input_X = process_data(d)

            model = load_model('./models/kmeans_model.pkl')
            processed_df = process_data(df)
            plot_clusters(model, processed_df, input_X)

    with tab2:
        st.caption("Global Temperature Time Series. Data are included from the GISS Surface Temperature (GISTEMP) analysis and the global component of Climate at a Glance (GCAG). Two datasets are provided: 1) global monthly mean and 2) annual mean temperature anomalies in degrees Celsius from 1880 to the present.")

        st.write("Citation:")
        st.caption("GISTEMP: NASA Goddard Institute for Space Studies (GISS) Surface Temperature Analysis, Global Land-Ocean Temperature Index.")
        st.caption("NOAA National Climatic Data Center (NCDC), global component of Climate at a Glance (GCAG).)")
|
|
|
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()