import datetime
from huggingface_hub import Repository
import os
import pandas as pd
import streamlit as st
import altair as alt
import numpy as np
import plotly.graph_objects as go

# Data files are named after the ISO week number and ISO year, so resolve
# both for the current date.
today = datetime.date.today()
year, week = today.isocalendar()[:2]

# Hugging Face dataset repository that is cloned into the local "data" folder.
DATASET_REPO_URL = "https://huggingface.co/datasets/huggingface/transformers-stats-space-data"

# Location of the current week's CSV inside the local clone.
DATA_FILENAME = "data_{}_{}.csv".format(week, year)
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Models that get their own trace in the download plot.
MODELS_TO_TRACK = ["wav2vec2", "whisper"]

# Bind the local "data" directory to the dataset repository and pull the
# latest revision before any CSV is read.
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL)
repo.git_pull()

# Parallel series for the plot: entry k of each list belongs to valid_weeks[k].
valid_weeks = []
download_results = []
model_download_results = {model_name: [] for model_name in MODELS_TO_TRACK}


def _audio_downloads_by_model(csv_path):
    """Return ``{model name: download count}`` for the audio rows of a CSV.

    Reads ``csv_path`` with pandas, keeps the rows whose ``modality`` column
    equals ``"audio"`` and converts every ``num_downloads`` entry to ``int``
    after removing the ``,`` thousands separators.
    """
    frame = pd.read_csv(csv_path)
    audio_rows = frame[frame["modality"] == "audio"]
    return {
        model: int(count.replace(",", ""))
        for model, count in zip(audio_rows["model_names"], audio_rows["num_downloads"].values)
    }


def _record_week(week_label, csv_path):
    """Append the download figures stored in ``csv_path`` to the plot series.

    Does nothing when the file does not exist. Otherwise ``week_label`` is
    appended to ``valid_weeks``, the sum over all audio models is appended to
    ``download_results`` and, for each tracked model, its count (``None`` when
    the model is absent from the file) is appended to
    ``model_download_results``.
    """
    if not os.path.exists(csv_path):
        return

    valid_weeks.append(week_label)

    downloads = _audio_downloads_by_model(csv_path)
    download_results.append(sum(downloads.values()))
    for tracked_model in MODELS_TO_TRACK:
        model_download_results[tracked_model].append(downloads.get(tracked_model))


# loop over past data, finding where we have data saved (valid weeks) and
# tracking monthly downloads for each week, newest week first
for i in range(week, 0, -1):
    _record_week(i, os.path.join("data", f"data_{i}_{year}.csv"))

# The final file of the previous year is shown at x-position 0, directly
# before week 1 of the current year.
# NOTE(review): week 52 is hard-coded; ISO years can have 53 weeks - confirm
# which file should be used for those years.
last_year = year - 1
last_week = 52
_record_week(0, os.path.join("data", f"data_{last_week}_{last_year}.csv"))

# Line chart of the download counts for every week that has a data file.
fig = go.Figure()
fig.update_layout(title="Monthly downloads", xaxis_title="Week", yaxis_title="Downloads")

# First the summed audio downloads, then one trace per tracked model.
traces = [("Total", download_results)]
traces += [(model_name, model_download_results[model_name]) for model_name in MODELS_TO_TRACK]

for trace_name, trace_values in traces:
    fig.add_trace(
        go.Scatter(x=valid_weeks, y=trace_values, mode="lines+markers", name=trace_name)
    )

st.title("Audio Stats")
st.plotly_chart(fig)


# Without a single data file there is nothing to select or display; stop the
# script here instead of failing on a missing CSV below.
if not valid_weeks:
    st.warning("No download data available.")
    st.stop()

week = st.selectbox(
    "Week",
    valid_weeks,
    index=0,
    help="Filter the download results by week"
)

# Entry 0 in `valid_weeks` is a placeholder for the previous year's final
# file (`last_week` / `last_year`); there is no "data_0_<year>.csv", so map
# it back to the real week and year before building the file name.
if week == 0:
    selected_week, selected_year = last_week, last_year
else:
    selected_week, selected_year = week, year

DATA_FILENAME = f"data_{selected_week}_{selected_year}.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)

# pandas opens and closes the file itself, so no explicit open() is needed.
dataframe = pd.read_csv(DATA_FILE)

st.header(f"Stats for year {selected_year} and week {selected_week}")

# Bar chart of the audio downloads found in the selected data file.
is_audio_row = dataframe["modality"] == "audio"
df_audio = dataframe[is_audio_row]

# Each download count is a string with "," separators; strip them before
# converting to int.
raw_counts = df_audio["num_downloads"].values
audio_int_downloads = np.array([int(raw_count.replace(",", "")) for raw_count in raw_counts])

# Altair reads its data from a frame whose column names double as axis titles.
chart_columns = {
    "Number of total downloads": audio_int_downloads,
    "Model architecture name": df_audio["model_names"].values,
}
source = pd.DataFrame(chart_columns)

# sort=None keeps the models in the order in which they appear in the data.
model_axis = alt.X("Model architecture name", sort=None)
bar_chart = alt.Chart(source).mark_bar().encode(
    x=model_axis,
    y="Number of total downloads",
)

st.subheader("Top audio downloads last 30 days")
st.altair_chart(bar_chart, use_container_width=True)

st.subheader("Audio stats last 30 days")

# Keep only the audio rows; the modality column is dropped from the table.
audio_mask = dataframe["modality"] == "audio"
dataframe = dataframe[audio_mask].drop(columns="modality")

# Add a row labelled "Total" holding the sums of the numeric columns.
dataframe.loc["Total"] = dataframe.sum(numeric_only=True)
total_audio_downloads = sum(audio_int_downloads)

# nice formatting: show the download total with thousands separators and
# blank out the remaining cells of the "Total" row
total_row_cells = {
    "num_downloads": f"{total_audio_downloads:,}",
    "model_names": "",
    "download_per_model": "",
}
for column_name, cell_text in total_row_cells.items():
    dataframe.at["Total", column_name] = cell_text

st.table(dataframe)