import functools
import traceback
import gradio as gr
import bittensor as bt
from typing import Dict, List, Any, Optional, Tuple
from bittensor.extrinsics.serving import get_metadata
from dataclasses import dataclass
import wandb
import math
import os
import datetime
import time
import json
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import HfApi
from apscheduler.schedulers.background import BackgroundScheduler
import pandas as pd

# Load variables from a local .env file (if present) before the
# os.environ lookups below.
load_dotenv()

# --- Static page content -----------------------------------------------------
FONT = (
    """<link href="https://fonts.cdnfonts.com/css/jmh-typewriter" rel="stylesheet">"""
)
TITLE = """<h1 align="center" id="space-title" class="typewriter">Subnet 9 Leaderboard</h1>"""
HEADER = """<h2 align="center" class="typewriter"><a href="https://github.com/RaoFoundation/pretraining" target="_blank">Subnet 9</a> is a <a href="https://bittensor.com/" target="_blank">Bittensor</a> subnet that rewards miners for producing pretrained Foundation-Models on the <a href="https://huggingface.co/datasets/tiiuae/falcon-refinedweb" target="_blank">Falcon Refined Web dataset</a>. It acts like a continuous benchmark whereby miners are rewarded for attaining the best losses on randomly sampled pages of Falcon.<br/>The models with the best head-to-head loss on the evaluation data receive a steady emission of TAO.</h2>"""
# Legend for the evaluation table; one entry per table column, in column order.
EVALUATION_DETAILS = """<ul><li><b>Name:</b> the 🤗 Hugging Face model name (click to go to the model card)</li><li><b>Win Rate:</b> the model's win rate on the evaluation data as calculated by a validator</li><li><b>Average Loss:</b> the last average loss value on the evaluation data for the model as calculated by a validator (lower is better)</li><li><b>Weight:</b> the weight assigned to the model by a validator</li><li><b>UID:</b> the Bittensor UID of the miner</li><li><b>Block:</b> the Bittensor block that the model was submitted in</li></ul><br/>More stats on <a href="https://taostats.io/subnets/netuid-9/" target="_blank">taostats</a>."""
EVALUATION_HEADER = """<h3 align="center">Shows the latest internal evaluation statistics as calculated by the Opentensor validator</h3>"""

# --- Data sources and credentials --------------------------------------------
VALIDATOR_WANDB_PROJECT = "opentensor-dev/pretraining-subnet"
BENCHMARK_WANDB_PROJECT = "raofoundation/pretraining-leaderboard-data"
H4_TOKEN = os.environ.get("H4_TOKEN", None)
API = HfApi(token=H4_TOKEN)
WANDB_TOKEN = os.environ.get("WANDB_API_KEY", None)
# Hugging Face Space that restart_space() restarts.
REPO_ID = "0x9/pretraining-leaderboard"
# NOTE(review): not referenced anywhere in the code shown here.
MAX_AVG_LOSS_POINTS = 1

# --- Retry policy and chain parameters ----------------------------------------
RETRIES = 5  # attempts made by run_with_retries
DELAY_SECS = 3  # pause between retry attempts
NETUID = 9
SECONDS_PER_BLOCK = 12  # used to turn a block countdown into minutes


@dataclass
class ModelData:
    """A miner's model submission, as decoded from its on-chain commitment."""

    uid: int  # Bittensor UID of the miner
    hotkey: str  # hotkey whose commitment was read
    namespace: str  # Hugging Face namespace (user/org) of the model repo
    name: str  # Hugging Face repository name
    commit: Optional[str]  # repo commit; None when the commitment says "None"
    hash: Optional[str]  # hash token; None when the commitment says "None"
    block: int  # block the commitment was submitted in
    incentive: float  # miner incentive taken from the metagraph
    emission: float  # estimated daily TAO emission

    @classmethod
    def from_compressed_str(
        cls,
        uid: int,
        hotkey: str,
        cs: str,
        block: int,
        incentive: float,
        emission: float,
    ):
        """Returns an instance of this class from a compressed string representation.

        ``cs`` has the form ``"namespace:name:commit:hash"``. A literal
        ``"None"`` in the commit or hash position becomes ``None``. Any tokens
        after the fourth are ignored.

        Raises:
            ValueError: if ``cs`` has fewer than four ``:``-separated tokens.
        """
        tokens = cs.split(":")
        if len(tokens) < 4:
            raise ValueError(
                f"Expected 'namespace:name:commit:hash', got {len(tokens)} token(s): {cs!r}"
            )
        namespace, name, commit, hash_ = tokens[:4]
        # Build through ``cls`` so subclasses get instances of their own type.
        return cls(
            uid=uid,
            hotkey=hotkey,
            namespace=namespace,
            name=name,
            commit=commit if commit != "None" else None,
            hash=hash_ if hash_ != "None" else None,
            block=block,
            incentive=incentive,
            emission=emission,
        )


def run_with_retries(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, retrying when it raises.

    Up to ``RETRIES`` attempts are made, sleeping ``DELAY_SECS`` seconds
    between failed attempts.

    Returns:
        The value returned by the first successful attempt.

    Raises:
        Exception: whatever the final attempt raised, re-raised unchanged.
        RuntimeError: if ``RETRIES`` is not positive, so no attempt was made.
    """
    for attempt in range(RETRIES):
        try:
            return func(*args, **kwargs)
        except Exception:  # RuntimeError is already a subclass of Exception.
            if attempt == RETRIES - 1:
                # Out of attempts: surface the original error to the caller.
                raise
            time.sleep(DELAY_SECS)
    # Only reachable when RETRIES <= 0.
    raise RuntimeError("Should never happen")


def get_subtensor_and_metagraph() -> Tuple[bt.subtensor, bt.metagraph]:
    """Open a chain connection and load the full subnet metagraph.

    Both objects are created inside ``run_with_retries`` so transient
    connection failures are retried.

    Returns:
        ``(subtensor, metagraph)``. The subtensor connects to
        ``ws://finney.cortex.foundation:9944``. The metagraph is for
        ``NETUID`` and is loaded with ``lite=False``.
    """

    def _connect() -> Tuple[bt.subtensor, bt.metagraph]:
        chain = bt.subtensor("ws://finney.cortex.foundation:9944")
        # NOTE(review): the metagraph is constructed independently of `chain`;
        # confirm it syncs against the intended network.
        graph = bt.metagraph(NETUID, lite=False)
        return chain, graph

    return run_with_retries(_connect)


def get_validator_weights(
    metagraph: bt.metagraph,
) -> Dict[int, Tuple[float, int, Dict[int, float]]]:
    """Returns a dictionary of validator UIDs to (vtrust, stake, {uid: weight}).

    A UID counts as a validator when its validator trust is positive. Its
    weight map holds every *other* UID whose weight, rounded to 4 decimals,
    is positive. The validator's own UID is never included.
    """
    all_uids = metagraph.uids.tolist()
    validators = {}
    for vuid in all_uids:
        vtrust = metagraph.validator_trust[vuid].item()
        if not vtrust > 0:
            continue
        weights = {}
        for target in all_uids:
            if target == vuid:
                continue
            w = round(metagraph.weights[vuid][target].item(), 4)
            if w > 0:
                weights[target] = w
        validators[vuid] = (vtrust, metagraph.S[vuid].item(), weights)
    return validators


def _decode_commitment(metadata: Dict[str, Any]) -> Tuple[str, int]:
    """Extract ``(compressed model string, commit block)`` from chain metadata."""
    commitment = metadata["info"]["fields"][0]
    # The field is a single-entry mapping whose value is a hex string; its
    # first two characters (presumably a "0x" prefix) are dropped.
    hex_data = commitment[list(commitment.keys())[0]][2:]
    chain_str = bytes.fromhex(hex_data).decode()
    return chain_str, metadata["block"]


def get_subnet_data(
    subtensor: bt.subtensor, metagraph: bt.metagraph
) -> List[ModelData]:
    """Collect the model each miner on the subnet has committed on chain.

    For every UID in ``metagraph`` the hotkey's commitment is fetched (with
    retries), decoded and parsed into a :class:`ModelData`. UIDs are skipped
    when any of these happens:

    * the metadata cannot be fetched;
    * nothing is returned;
    * the metadata cannot be decoded or parsed.

    Args:
        subtensor: chain connection used to read commitments.
        metagraph: subnet metagraph supplying UIDs, hotkeys, incentive and
            emission.

    Returns:
        One ModelData per UID with a usable commitment, in metagraph UID order.
    """
    result = []
    for uid in metagraph.uids.tolist():
        hotkey = metagraph.hotkeys[uid]
        try:
            metadata = run_with_retries(
                functools.partial(get_metadata, subtensor, metagraph.netuid, hotkey)
            )
        except Exception:
            print(f"Failed to get metadata for UID {uid}: {traceback.format_exc()}")
            continue

        # Nothing usable returned for this hotkey.
        if not metadata:
            continue

        try:
            chain_str, block = _decode_commitment(metadata)
        except Exception:
            # Skip a malformed commitment rather than letting the exception
            # reach main(), whose retry loop would keep hitting the same data.
            print(f"Failed to decode metadata for UID {uid}: {traceback.format_exc()}")
            continue

        incentive = metagraph.incentive[uid].nan_to_num().item()
        emission = (
            metagraph.emission[uid].nan_to_num().item() * 20
        )  # convert to daily TAO

        try:
            model_data = ModelData.from_compressed_str(
                uid, hotkey, chain_str, block, incentive, emission
            )
        except Exception:
            # Commitment is not in the expected "namespace:name:commit:hash" form.
            continue

        result.append(model_data)
    return result


def is_floatable(x) -> bool:
    """Return True if ``x`` is an int (including bool) or a finite float.

    NaN, +/-inf and every non-numeric value (None, strings, ...) give False.
    """
    if isinstance(x, int):
        return True
    return isinstance(x, float) and math.isfinite(x)


def get_wandb_runs(project: str, filters: Dict[str, Any]) -> List:
    """Fetch the runs of ``project`` that match ``filters`` from Wandb.

    Polls until the query returns at least one run, sleeping 60 seconds after
    each empty response. Exceptions raised by the Wandb client are not caught.
    """
    while True:
        # A new client per attempt. NOTE(review): presumably so a retry does
        # not reuse results cached by an earlier client; confirm.
        client = wandb.Api(api_key=WANDB_TOKEN)
        found = list(client.runs(project, filters=filters))
        if found:
            return found
        # The Wandb API is quite unreliable. Wait another minute and try again.
        print("Failed to get runs from Wandb. Trying again in 60 seconds.")
        time.sleep(60)


def get_scores(
    uids: List[int],
    wandb_runs: List,
) -> Dict[int, Dict[str, Optional[float]]]:
    """Collect the most recent validator-reported stats for each UID.

    Walks ``wandb_runs`` (expected newest first). For each UID in ``uids``,
    stats come from the first run whose ``uid_data`` contains it. Runs whose
    summary lacks ``original_format_json`` are skipped. The walk stops early
    once every UID has been found.

    Returns:
        A dict mapping UID to ``avg_loss``, ``win_rate``, ``win_total`` and
        ``weight``. Each of these is None when the run did not report it.
        Each entry also has ``fresh``, which is True only when the stats came
        from the first run in ``wandb_runs``. UIDs found in no run are absent.

    Raises:
        ValueError: if the processed runs' timestamps are not strictly
            decreasing.
    """
    result = {}
    previous_timestamp = None
    # Iterate through the runs until we've processed all the uids.
    for i, run in enumerate(wandb_runs):
        if "original_format_json" not in run.summary:
            continue
        data = json.loads(run.summary["original_format_json"])
        all_uid_data = data["uid_data"]
        timestamp = data["timestamp"]

        # Freshness below depends on runs being newest-first. Use an explicit
        # check rather than `assert`, which is stripped under `python -O`.
        if previous_timestamp is not None and not (timestamp < previous_timestamp):
            raise ValueError(
                f"Timestamps are not in descending order: {timestamp} >= {previous_timestamp}"
            )
        previous_timestamp = timestamp

        # Only the most recent run is fresh.
        is_fresh = i == 0
        for uid in uids:
            key = str(uid)
            if uid in result or key not in all_uid_data:
                continue
            uid_data = all_uid_data[key]
            result[uid] = {
                "avg_loss": uid_data.get("average_loss", None),
                "win_rate": uid_data.get("win_rate", None),
                "win_total": uid_data.get("win_total", None),
                "weight": uid_data.get("weight", None),
                "fresh": is_fresh,
            }
        if len(result) == len(uids):
            break
    return result


def get_losses_over_time(wandb_runs: List) -> pd.DataFrame:
    """Returns a dataframe of the best average model loss over time.

    For each run with an ``original_format_json`` summary, the lowest
    numeric ``average_loss`` across all UIDs is recorded at the run's
    timestamp (converted with ``datetime.fromtimestamp``, i.e. local time).
    Missing or non-numeric losses are ignored. Runs with no usable loss
    contribute no row.

    Returns:
        A DataFrame with columns ``timestamp`` and ``best_loss``, one row per
        contributing run, in input order.
    """
    # Losses at or below this value before this date come from an exploit
    # and are filtered out.
    exploit_cutoff = datetime.datetime(2024, 2, 8)
    exploit_loss_floor = 2.5

    timestamps = []
    best_losses = []
    for run in wandb_runs:
        if "original_format_json" not in run.summary:
            continue
        data = json.loads(run.summary["original_format_json"])
        timestamp = datetime.datetime.fromtimestamp(data["timestamp"])
        best_loss = math.inf
        for uid_data in data["uid_data"].values():
            loss = uid_data.get("average_loss", math.inf)
            # Skip null/non-numeric losses instead of failing on the comparison.
            if not isinstance(loss, (int, float)):
                continue
            if loss < best_loss and (
                loss > exploit_loss_floor or timestamp > exploit_cutoff
            ):
                best_loss = loss
        if best_loss != math.inf:
            timestamps.append(timestamp)
            best_losses.append(best_loss)

    return pd.DataFrame({"timestamp": timestamps, "best_loss": best_losses})


def format_score(uid: int, scores, key) -> Optional[float]:
    """Return ``scores[uid][key]`` rounded to 4 decimals for display.

    Returns None when the UID or key is absent, or when the value is not a
    finite number according to ``is_floatable``.
    """
    if uid not in scores or key not in scores[uid]:
        return None
    point = scores[uid][key]
    if not is_floatable(point):
        return None
    return round(point, 4)


def next_epoch(subtensor: bt.subtensor, block: int) -> int:
    """Return the block at which the subnet's next epoch is expected.

    The estimate uses the subnet's ``tempo`` hyperparameter and the schedule
    ``blocks_until_next_epoch = tempo - (block + NETUID + 1) % (tempo + 1)``.
    NOTE(review): this mirrors subtensor's ``blocks_until_next_epoch``;
    confirm it matches the chain version in use.

    If the tempo cannot be read (or is 0/None), ``block`` itself is returned,
    so callers show "0 blocks to go" rather than a bogus negative countdown.
    """
    try:
        tempo = subtensor.get_subnet_hyperparameters(NETUID).tempo
    except Exception:
        print(f"Failed to get tempo for subnet {NETUID}: {traceback.format_exc()}")
        return block
    if not tempo:
        return block
    blocks_until_next_epoch = tempo - (block + NETUID + 1) % (tempo + 1)
    return block + blocks_until_next_epoch


def get_next_update_div(current_block: int, next_update_block: int) -> str:
    """Return an HTML ``<div>`` announcing when the next reward update is due.

    The countdown is shown in blocks and as an approximate, floored number of
    minutes at ``SECONDS_PER_BLOCK`` seconds per block. If
    ``next_update_block`` is not after ``current_block``, the countdown is
    reported as 0 instead of a negative number.
    """
    blocks_to_go = max(next_update_block - current_block, 0)
    # Plain arithmetic is enough; no need to round-trip through datetime.now().
    minutes_to_go = int(blocks_to_go * SECONDS_PER_BLOCK // 60)
    return f"""<div align="center" style="font-size: larger;">Next reward update: <b>{blocks_to_go}</b> blocks (~{minutes_to_go} minutes)</div>"""


def get_last_updated_div() -> str:
    """Return an HTML ``<div>`` stamped with the current UTC time."""
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC time.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    return f"""<div>Last Updated: {now_utc.strftime("%Y-%m-%d %H:%M:%S")} (UTC)</div>"""


def _model_link(model: ModelData) -> str:
    """Return a Markdown link to ``model`` on Hugging Face.

    With a known commit, the label shows its first 8 characters and the link
    targets that commit. When ``commit`` is None (a "None" commitment), the
    link targets the repository itself.
    """
    repo = f"{model.namespace}/{model.name}"
    if model.commit is None:
        return f"[{repo}](https://huggingface.co/{repo})"
    return f"[{repo} ({model.commit[0:8]})](https://huggingface.co/{repo}/commit/{model.commit})"


def leaderboard_data(
    leaderboard: List[ModelData],
    scores: Dict[int, Dict[str, Optional[float]]],
    show_stale: bool,
) -> List[List[Any]]:
    """Returns the leaderboard data, based on models data and UID scores.

    Each row is ``[name link, win rate, average loss, weight, UID, block]``.
    Score cells come from ``format_score``. Models with no scores, or whose
    scores are not marked ``fresh``, are included only when ``show_stale``
    is truthy.
    """
    return [
        [
            _model_link(c),
            format_score(c.uid, scores, "win_rate"),
            format_score(c.uid, scores, "avg_loss"),
            format_score(c.uid, scores, "weight"),
            c.uid,
            c.block,
        ]
        for c in leaderboard
        if (c.uid in scores and scores[c.uid]["fresh"]) or show_stale
    ]
    
def get_benchmarks(
    *, enabled: bool = False
) -> Tuple[Optional[pd.DataFrame], Optional[datetime.datetime]]:
    """Returns the latest benchmarks and the time they were run.

    Benchmark fetching is disabled by default. Unless ``enabled`` is True,
    this returns ``(None, None)`` immediately without contacting Wandb.
    When enabled, the runs of ``BENCHMARK_WANDB_PROJECT`` are scanned in the
    order Wandb returns them. The first run whose last logged artifact holds
    a ``benchmarks`` table is used.

    Args:
        enabled: keyword-only switch to actually fetch from Wandb.

    Returns:
        ``(dataframe, started_at)`` for the first matching run, or
        ``(None, None)`` when disabled or when no run has benchmarks.
    """
    if not enabled:
        return None, None

    runs = get_wandb_runs(project=BENCHMARK_WANDB_PROJECT, filters=None)
    for run in runs:
        artifacts = list(run.logged_artifacts())
        if not artifacts:
            continue
        table = artifacts[-1].get("benchmarks")
        if table:
            # NOTE(review): startedAt is parsed as a naive datetime; the UI
            # labels it UTC — confirm that is what Wandb reports.
            return (
                table.get_dataframe(),
                datetime.datetime.strptime(
                    run.metadata["startedAt"], "%Y-%m-%dT%H:%M:%S.%f"
                ),
            )
    bt.logging.error("Failed to get benchmarks from Wandb.")
    return None, None


def restart_space():
    """Ask the Hugging Face Hub to restart the Space identified by ``REPO_ID``."""
    API.restart_space(token=H4_TOKEN, repo_id=REPO_ID)


def main():
    """Fetch all leaderboard data, build the Gradio dashboard and serve it.

    Data collection is retried every 30 seconds until it succeeds, so the
    dashboard is only built once everything it shows is available. A
    background job restarts the Space every 30 minutes, presumably so that a
    fresh process rebuilds the dashboard with up-to-date data.
    """

    def display_name(c: ModelData) -> str:
        # from_compressed_str turns a "None" commit into None; only append
        # the short commit hash when there is one.
        if c.commit is None:
            return f"{c.namespace}/{c.name}"
        return f"{c.namespace}/{c.name} ({c.commit[0:8]})"

    # To avoid leaderboard failures, infinitely try until we get all data
    # needed to populate the dashboard
    while True:
        try:
            subtensor, metagraph = get_subtensor_and_metagraph()

            model_data: List[ModelData] = get_subnet_data(subtensor, metagraph)
            model_data.sort(key=lambda x: x.incentive, reverse=True)

            vali_runs = get_wandb_runs(
                project=VALIDATOR_WANDB_PROJECT,
                filters={"config.type": "validator", "config.uid": 238},
            )

            scores = get_scores([x.uid for x in model_data], vali_runs)

            current_block = metagraph.block.item()
            next_epoch_block = next_epoch(subtensor, current_block)

            validator_df = get_validator_weights(metagraph)

            benchmarks, benchmark_timestamp = get_benchmarks()
            break
        except Exception:
            # Log the full traceback so failures can actually be diagnosed.
            print(f"Failed to get data: {traceback.format_exc()}")
            time.sleep(30)

    # Only models that currently earn incentive appear in the rewards label
    # and as per-model columns of the validator table.
    incentivized = [c for c in model_data if c.incentive]

    # Validator UIDs ordered by stake, largest first (stable for ties).
    validator_uids = sorted(
        validator_df, key=lambda uid: validator_df[uid][1], reverse=True
    )

    demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
    with demo:
        gr.HTML(FONT)
        gr.HTML(TITLE)
        gr.HTML(HEADER)

        gr.HTML(value=get_next_update_div(current_block, next_epoch_block))

        gr.Label(
            value={
                f"{display_name(c)} · (τ{round(c.emission, 2):,})": c.incentive
                for c in incentivized
            },
            num_top_classes=10,
        )

        if benchmarks is not None:
            with gr.Accordion("Top Model Benchmarks"):
                gr.components.Dataframe(benchmarks)
                gr.HTML("""<div>PPL computed using a stride of 512. See <a href='https://github.com/RaoFoundation/pretraining/blob/dev/scripts/run_benchmarks.py'>here</a> for the full code.</div>""")
                gr.HTML(f"""<div>Last Updated: {benchmark_timestamp.strftime("%Y-%m-%d %H:%M:%S")} (UTC)</div>""")

        with gr.Accordion("Evaluation Stats"):
            gr.HTML(EVALUATION_HEADER)
            show_stale = gr.Checkbox(label="Show Stale", interactive=True)
            leaderboard_table = gr.components.Dataframe(
                value=leaderboard_data(model_data, scores, show_stale.value),
                headers=["Name", "Win Rate", "Average Loss", "Weight", "UID", "Block"],
                datatype=["markdown", "number", "number", "number", "number", "number"],
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )
            gr.HTML(EVALUATION_DETAILS)
            # Rebuild the table rows whenever "Show Stale" is toggled.
            show_stale.change(
                lambda stale: leaderboard_data(model_data, scores, stale),
                inputs=[show_stale],
                outputs=leaderboard_table,
            )

            gr.LinePlot(
                get_losses_over_time(vali_runs),
                x="timestamp",
                x_title="Date",
                y="best_loss",
                y_title="Average Loss",
                tooltip="best_loss",
                interactive=True,
                visible=True,
                width=1024,
                title="Best Average Loss Over Time",
            )

        with gr.Accordion("Validator Stats"):
            gr.components.Dataframe(
                value=[
                    [uid, int(validator_df[uid][1]), round(validator_df[uid][0], 4)]
                    + [validator_df[uid][-1].get(c.uid) for c in incentivized]
                    for uid in validator_uids
                ],
                headers=["UID", "Stake (τ)", "V-Trust"]
                + [display_name(c) for c in incentivized],
                datatype=["number", "number", "number"]
                + ["number" for _ in incentivized],
                interactive=False,
                visible=True,
            )
        gr.HTML(value=get_last_updated_div())

    scheduler = BackgroundScheduler()
    # Restart the Space every 30 minutes.
    scheduler.add_job(restart_space, "interval", seconds=60 * 30)
    scheduler.start()

    demo.launch()


# Only launch the dashboard when run as a script, not when imported.
if __name__ == "__main__":
    main()