hydraadra112's picture
Changed App Title
91edfd4
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import joblib
@st.cache_data
def load_dataset(path: str) -> pd.DataFrame:
    """
    Opens up a `.csv` file as our main dataset.

    The ``Date`` column is parsed and replaced by ``Year``, ``Month`` and
    ``Day`` columns (appended at the end); all other columns are kept as read.

    Cached with ``st.cache_data`` rather than ``st.cache_resource``: every call
    receives its own copy of the frame, so a caller that mutates the result
    cannot corrupt the cached dataset shared by other reruns/sessions.

    Args:
        path (str): Path to the dataset to be opened. It must contain a
            ``Date`` column parseable by ``pd.to_datetime``.

    Returns:
        pd.DataFrame: The dataset with the date split into Year/Month/Day.
    """
    df = pd.read_csv(path)
    dates = pd.to_datetime(df['Date'])
    df['Year'] = dates.dt.year
    df['Month'] = dates.dt.month
    df['Day'] = dates.dt.day
    return df.drop(columns='Date')
@st.cache_resource
def load_model(model_path: str) -> KMeans:
    """
    Load the persisted KMeans model from disk.

    Wrapped in ``st.cache_resource`` so the file is deserialised once and the
    same model object is reused on later calls with the same path.

    Note: ``joblib.load`` unpickles the file, which can execute arbitrary code;
    only point this at model files you trust.

    Args:
        model_path (str): Path to the ``joblib``-dumped model file.

    Returns:
        KMeans: The deserialised object (expected to be a fitted KMeans).
    """
    kmeans = joblib.load(model_path)
    return kmeans
def process_data(input_X: pd.DataFrame) -> np.ndarray:
    """
    Processes user input data into usable form for the KMeans model to predict.

    The frame may carry either a raw ``Date`` column (as built from the user
    input) or already-split ``Year``/``Month``/``Day`` columns. ``Source`` is
    one-hot encoded into ``Source_GCAG`` and ``Source_GISTEMP``; a source that
    does not appear in the frame yields an all-zero column. Any other columns
    are ignored. The caller's frame is not modified.

    Args:
        input_X (pd.DataFrame): Data with ``Source`` and ``Mean`` columns plus
            either ``Date`` or ``Year``/``Month``/``Day``.

    Returns:
        np.ndarray: Array of shape (n_rows, 6) with columns ordered
        Mean, Year, Month, Day, Source_GCAG, Source_GISTEMP.

    Raises:
        KeyError: If a required column is missing.
    """
    input_X = input_X.copy()

    # Split up the dates
    if 'Date' in input_X.columns:
        dates = pd.to_datetime(input_X['Date'])
        input_X['Year'] = dates.dt.year
        input_X['Month'] = dates.dt.month
        input_X['Day'] = dates.dt.day
        input_X = input_X.drop(columns='Date')

    # Encode only the Source column explicitly: letting get_dummies auto-detect
    # string columns fails (prefix length mismatch) as soon as any other
    # string-typed column is present.
    input_X = pd.get_dummies(input_X, columns=['Source'], prefix='Source', dtype=int)

    # A single-row input only produces a dummy for its own source; add the
    # other one so the feature layout is always the same.
    for col in ('Source_GCAG', 'Source_GISTEMP'):
        if col not in input_X.columns:
            input_X[col] = 0

    # Fixed feature order (presumably the order the model was fitted with).
    input_X = input_X[['Mean', 'Year', 'Month', 'Day', 'Source_GCAG', 'Source_GISTEMP']]
    return input_X.to_numpy()
def plot_clusters(model: KMeans, X: np.ndarray, input_X: np.ndarray) -> None:
    """
    Plots the predicted class to the clusters.

    Draws every cluster of ``X``, the model's centroids and the input point(s)
    using the first two feature columns, and writes the predicted cluster of
    the first input row to the page.

    Args:
        model (KMeans): A fitted KMeans model. ``model.labels_`` is used to
            colour ``X``, so ``X`` must have one row per sample the model was
            fitted on, in the same order.
        X (np.ndarray): The numpy array version of the dataset.
        input_X (np.ndarray): The numpy array of the input (at least one row).

    Returns:
        None
    """
    centroids = model.cluster_centers_
    labels = model.labels_

    fig, ax = plt.subplots(figsize=(10, 6))

    # One series per cluster; the count comes from the model instead of being
    # hard-coded so a model with a different k is plotted completely.
    for cluster in range(len(centroids)):
        cluster_points = X[labels == cluster]
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f'Cluster {cluster}')

    # Plot centroids
    ax.scatter(centroids[:, 0], centroids[:, 1], s=200, c='black', marker='X', label='Centroids')

    predictions = model.predict(input_X)
    st.write(f"Predicted Cluster: {predictions[0]}")

    # Highlight the predicted cluster and point
    ax.scatter(input_X[:, 0], input_X[:, 1], s=300, c='red', marker='P',
               label=f'Predicted Cluster: {predictions[0]}')

    ax.set_title('K-Means Clustering with Predicted Point')
    ax.legend()
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; close the figure after
    # rendering so pyplot does not keep one open figure per rerun.
    plt.close(fig)
def main():
    """Build the page: a KMeans prediction tab and an "About the Dataset" tab."""
    st.title("Global Temperature Time Series")
    df = load_dataset("./monthly_csv.csv")

    tab1, tab2 = st.tabs(["KMeans Prediction", "About the Dataset"])

    with tab1:
        st.header("Input Data")
        source = st.selectbox("Choose your source platform.", tuple(df['Source'].unique()))
        mean_temp = st.slider("Choose your avg. temp", df['Mean'].min(), df['Mean'].max())
        date = st.date_input("Choose date to monitor global temperature", min_value="1980-01-01", max_value=None)

        # Load the model and plot only when a prediction is requested; all the
        # inputs the plot needs are built inside this branch.
        if st.button("Predict Input!"):
            d = pd.DataFrame({"Source": [source],
                              "Mean": [mean_temp],
                              "Date": [date]})
            input_X = process_data(d)
            model = load_model('./models/kmeans_model.pkl')
            processed_df = process_data(df)
            plot_clusters(model, processed_df, input_X)

    with tab2:
        st.caption("Global Temperature Time Series. Data are included from the GISS Surface Temperature (GISTEMP) analysis and the global component of Climate at a Glance (GCAG). Two datasets are provided: 1) global monthly mean and 2) annual mean temperature anomalies in degrees Celsius from 1880 to the present.")
        st.write("Citation:")
        st.caption("GISTEMP: NASA Goddard Institute for Space Studies (GISS) Surface Temperature Analysis, Global Land-Ocean Temperature Index.")
        st.caption("NOAA National Climatic Data Center (NCDC), global component of Climate at a Glance (GCAG).")
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()