import dataclasses
from multiprocessing import cpu_count

import pandas as pd
import requests
import streamlit as st
from datasets import Dataset, load_dataset
from paperswithcode import PapersWithCodeClient
from tqdm.auto import tqdm


@dataclasses.dataclass(frozen=True)
class PaperInfo:
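    """Metadata for a single paper from the Hugging Face daily papers feed."""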
    date: str
    arxiv_id: str
    github: str
    title: str
    paper_page: str
    upvotes: int
    num_comments: int


def get_df(start_date: str | None = None, end_date: str | None = None) -> pd.DataFrame:
    """
    Load the initial dataset as a Pandas dataframe.

    Optionally, a start_date and end_date can be provided to include only data between these dates (inclusive).
    """

    # merge the two source datasets on arxiv_id: "daily-papers" provides the paper
    # metadata, "daily-papers-stats" the upvote and comment counts
    df = pd.merge(
        left=load_dataset("hysts-bot-data/daily-papers", split="train").to_pandas(),
        right=load_dataset("hysts-bot-data/daily-papers-stats", split="train").to_pandas(),
        on="arxiv_id",
    )
    df = df[::-1].reset_index(drop=True)

    # wrap each row in a PaperInfo, adding a link to its Hugging Face paper page
    paper_info = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        info = PaperInfo(
            **row,
            paper_page=f"https://huggingface.co/papers/{row.arxiv_id}",
        )
        paper_info.append(info)
    
    # rebuild the dataframe, now including the paper_page column
    df = pd.DataFrame([dataclasses.asdict(info) for info in paper_info])

    # set date as index
    df = df.set_index('date')
    df.index = pd.to_datetime(df.index)
    if start_date is not None and end_date is not None:
        # only include data between start_date and end_date
        df = df[(df.index >= start_date) & (df.index <= end_date)]

    return df


def get_github_url(client: PapersWithCodeClient, paper_title: str) -> str:
    """
    Get the GitHub URL for a paper.
    """

    repo_url = ""
    try:
        # look up the paper by title and take the top match
        results = client.paper_list(q=paper_title).results
        paper_id = results[0].id

        # get paper
        paper = client.paper_get(paper_id=paper_id)

        # get repositories and keep the official one, if any
        repositories = client.paper_repository_list(paper_id=paper.id).results
        for repo in repositories:
            if repo.is_official:
                repo_url = repo.url

    except Exception:
        # no results, missing repositories or API errors fall back to an empty URL
        pass

    return repo_url


def add_metadata_batch(batch, client: PapersWithCodeClient):
    """
    Add metadata to a batch of papers.
    """

    # get GitHub URLs for all papers in the batch
    github_urls = []
    for paper_title in batch["title"]:
        github_url = get_github_url(client, paper_title)
        github_urls.append(github_url)

    # overwrite the GitHub links with the official repository URLs found above
    batch["github"] = github_urls

    return batch


def add_hf_assets(batch):
    """
    Add Hugging Face assets to a batch of papers.
    """
    num_spaces = []
    num_models = []
    num_datasets = []
    for arxiv_id in batch["arxiv_id"]:
        if arxiv_id != "":
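            # the Hub endpoint /api/arxiv/{arxiv_id}/repos lists the Spaces, models and datasets linked to the paper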
            response = requests.get(f"https://huggingface.co/api/arxiv/{arxiv_id}/repos")
            result = response.json() if response.status_code == 200 else {}
            num_spaces_example = len(result.get("spaces", []))
            num_models_example = len(result.get("models", []))
            num_datasets_example = len(result.get("datasets", []))
        else:
            num_spaces_example = 0
            num_models_example = 0
            num_datasets_example = 0

        num_spaces.append(num_spaces_example)
        num_models.append(num_models_example)
        num_datasets.append(num_datasets_example)

    batch["num_models"] = num_models
    batch["num_datasets"] = num_datasets
    batch["num_spaces"] = num_spaces

    return batch


def check_hf_mention(batch):
    """
    Check if a paper mentions Hugging Face in the README.
    """

    hf_mentions = []
    for github_url in batch["github"]:
        hf_mention = 0
        if github_url != "":
            # fetch the raw README from raw.githubusercontent.com
            owner = github_url.rstrip("/").split("/")[-2]
            repo = github_url.rstrip("/").split("/")[-1]
            branch = "main"
            url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/README.md"
            response = requests.get(url)
            
            if response.status_code != 200:
                # try master branch as second attempt
                branch = "master"
                url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/README.md"
                response = requests.get(url)
            
            if response.status_code == 200:
                # get text
                text = response.text
                if "huggingface" in text.lower() or "hugging face" in text.lower():
                    hf_mention = 1
            
        hf_mentions.append(hf_mention)

    # add a binary flag indicating whether the README mentions Hugging Face
    batch["hf_mention"] = hf_mentions

    return batch


def process_data(start_date: str, end_date: str) -> pd.DataFrame:
    """
    Load the dataset and enrich it with metadata.
    """
    # step 1. load as HF dataset
    df = get_df(start_date, end_date)
    dataset = Dataset.from_pandas(df)

    # step 2. enrich using PapersWithCode API
    dataset = dataset.map(
        add_metadata_batch,
        batched=True,
        batch_size=4,
        num_proc=cpu_count(),
        fn_kwargs={"client": PapersWithCodeClient()},
    )

    # step 3. enrich using Hugging Face API
    dataset = dataset.map(add_hf_assets, batched=True, batch_size=4, num_proc=cpu_count())

    # step 4. check if Hugging Face is mentioned in the README
    dataset = dataset.map(check_hf_mention, batched=True, batch_size=4, num_proc=cpu_count())

    # return as Pandas dataframe
    # making sure that the date is set as index
    dataframe = dataset.to_pandas()
    dataframe = dataframe.set_index('date')
    dataframe.index = pd.to_datetime(dataframe.index)

    return dataframe


@st.cache_data
def get_data() -> pd.DataFrame:
    """
    Load the pre-processed dataset and process any days that are not yet included.
    """
    # step 1: load pre-processed data
    df = load_dataset("nielsr/daily-papers-enriched", split="train").to_pandas()
    df = df.set_index('date')
    df = df.sort_index()
    df.index = pd.to_datetime(df.index)

    # step 2: check how much extra data we need to process
    latest_day = df.index[-1].strftime('%Y-%m-%d')
    today = pd.Timestamp.today().strftime('%Y-%m-%d')

    print("Latest day:", latest_day)
    print("Today:", today)

    # step 3: process the missing data
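    # ISO-formatted date strings compare correctly as plain strings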
    if latest_day < today:
        print(f"Processing data from {latest_day} to {today}")
        new_df = process_data(start_date=latest_day, end_date=today)

        print("Original df:", df.head())
        print("New df:", new_df.head())

        df = pd.concat([df, new_df])
        # the latest stored day is processed again above, so drop the stale duplicates
        df = df[~df['arxiv_id'].duplicated(keep='last')]

    df = df.sort_index()

    return df