"""Helpers for loading the rewriting, full-commit, synthetic, and prediction datasets from the Hugging Face Hub as pandas DataFrames."""

import json
from datetime import datetime, timedelta

import pandas as pd
from datasets import load_dataset
from huggingface_hub import hf_hub_download, list_repo_tree

import config


def load_raw_rewriting_as_pandas():
    """Load the raw commit-message rewriting dataset as a pandas DataFrame."""
    return load_dataset(
        config.HF_RAW_DATASET_NAME,
        split=config.HF_RAW_DATASET_SPLIT,
        token=config.HF_TOKEN,
        cache_dir=config.CACHE_DIR,
    ).to_pandas()


def load_full_commit_as_pandas():
    """Load the full-commit dataset and rename its 'message' column to 'reference'."""
    return (
        load_dataset(
            path=config.HF_FULL_COMMITS_DATASET_NAME,
            name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
            split=config.HF_FULL_COMMITS_DATASET_SPLIT,
            cache_dir=config.CACHE_DIR,
        )
        .to_pandas()
        .rename(columns={"message": "reference"})
    )


def edit_time_from_history(history_str):
    """Return the editing time in milliseconds spanned by a JSON-encoded message-edit history."""
    history = json.loads(history_str)

    if len(history) == 0:
        return 0

    timestamps = [datetime.fromisoformat(entry["ts"]) for entry in history]
    delta = max(timestamps) - min(timestamps)

    return delta // timedelta(milliseconds=1)


def edit_time_from_timestamps(row):
    """Return the editing time in milliseconds between loading and submission, or None if it is negative."""
    loaded_ts = datetime.fromisoformat(row["loaded_ts"])
    submitted_ts = datetime.fromisoformat(row["submitted_ts"])

    delta = submitted_ts - loaded_ts

    result = delta // timedelta(milliseconds=1)

    return result if result >= 0 else None


def load_processed_rewriting_as_pandas():
    """Join the manually rewritten messages with their commit modifications and add edit-time columns."""
    # Keep only the columns needed downstream; .copy() lets new columns be assigned
    # without a SettingWithCopyWarning on the sliced frame.
    manual_rewriting = load_raw_rewriting_as_pandas()[
        [
            "hash",
            "repo",
            "commit_msg_start",
            "commit_msg_end",
            "session",
            "commit_msg_history",
            "loaded_ts",
            "submitted_ts",
        ]
    ].copy()

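    # Two independent editing-time estimates: one from the recorded message-edit history,
    # one from the load/submit timestamps.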
    manual_rewriting["edit_time_hist"] = manual_rewriting["commit_msg_history"].apply(edit_time_from_history)
    manual_rewriting["edit_time"] = manual_rewriting.apply(edit_time_from_timestamps, axis=1)

    manual_rewriting = manual_rewriting.drop(columns=["commit_msg_history", "loaded_ts", "submitted_ts"])

    manual_rewriting.set_index(["hash", "repo"], inplace=True)

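    # Attach the code modifications ("mods") from the full-commit dataset, matched on (hash, repo).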
    mods_dataset = load_full_commit_as_pandas()[["hash", "repo", "mods"]]
    mods_dataset.set_index(["hash", "repo"], inplace=True)

    return manual_rewriting.join(other=mods_dataset, how="left").reset_index()


def load_synthetic_as_pandas():
    """Load the all_pairs_with_metrics configuration of the synthetic dataset as a pandas DataFrame."""
    return load_dataset(
        config.HF_SYNTHETIC_DATASET_NAME,
        "all_pairs_with_metrics",
        split=config.HF_SYNTHETIC_DATASET_SPLIT,
        token=config.HF_TOKEN,
        cache_dir=config.CACHE_DIR,
    ).to_pandas()


def load_full_commit_with_predictions_as_pandas():
    """Join the full-commit dataset with one model prediction per commit, downloaded from the Hub."""
    full_dataset = load_full_commit_as_pandas()

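    # Download every prediction file published for the configured model; Hub paths are
    # POSIX-style, so the folder path is joined with "/" rather than os.path.join.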
    predictions_paths = []
    for prediction_file in list_repo_tree(
        repo_id=config.HF_PREDICTIONS_DATASET_NAME,
        path_in_repo=f"commit_message_generation/predictions/{config.HF_PREDICTIONS_MODEL}",
        repo_type="dataset",
    ):
        predictions_paths.append(
            hf_hub_download(
                repo_id=config.HF_PREDICTIONS_DATASET_NAME,
                filename=prediction_file.path,
                repo_type="dataset",
                cache_dir=config.CACHE_DIR,
            )
        )

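    # Each prediction file is in JSON Lines format; read them all and concatenate into one frame.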
    dfs = []
    for path in predictions_paths:
        dfs.append(pd.read_json(path, orient="records", lines=True))
    predictions_dataset = pd.concat(dfs, axis=0, ignore_index=True)
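    # Shuffle deterministically, then keep a single prediction per (hash, repo) pair,
    # so repeated generations for the same commit collapse to one row.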
    predictions_dataset = predictions_dataset.sample(frac=1, random_state=config.RANDOM_STATE).set_index(
        ["hash", "repo"]
    )[["prediction"]]
    predictions_dataset = predictions_dataset[~predictions_dataset.index.duplicated(keep="first")]

    dataset = full_dataset.join(other=predictions_dataset, on=["hash", "repo"])

    # The index is still full_dataset's default RangeIndex, so drop it rather than
    # materialising it as an extra "index" column.
    return dataset.reset_index(drop=True)