import gradio as gr
import bittensor as bt
import typing
from bittensor.extrinsics.serving import get_metadata
from dataclasses import dataclass
import requests
import wandb
import math
import os
import datetime
import time
import functools
import multiprocessing
from dotenv import load_dotenv
from huggingface_hub import HfApi
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm import tqdm
import concurrent.futures
import sys
import numpy as np

# Populate os.environ from a local .env file (if present) before the
# environment lookups further down this module (H4_TOKEN, SUBTENSOR).
load_dotenv()

# Static HTML fragments rendered at the top of the page and inside each
# competition tab. They are passed verbatim to gr.HTML components.

# Web font used by the ".typewriter" CSS class.
FONT = (
    """<link href="https://fonts.cdnfonts.com/css/jmh-typewriter" rel="stylesheet">"""
)
TITLE = """<h1 align="center" id="space-title" class="typewriter">MyShell TTS Subnet Leaderboard</h1>"""
IMAGE = """<a href="https://discord.gg/myshell" target="_blank"><img src="https://avatars.githubusercontent.com/u/127754094?s=2000&v=4" alt="MyShell" style="margin: auto; width: 20%; border: 0;" /></a>"""
# The element is opened with <h2>, so it must be closed with </h2> (it was
# previously closed with a mismatched </h3>).
HEADER = """<h2 align="center" class="typewriter">MyShell TTS Subnet is a groundbreaking project that leverages the power of decentralized collaboration to advance the state-of-the-art in open-source Text-to-Speech (TTS) technology. By harnessing the Bittensor blockchain and a unique incentive mechanism, we aim to create the most advanced and accessible TTS models. By leveraging MyShell's user base of over one million individuals, we are devoted to pushing cutting-edge technology to every end-user.</h2>"""
# Legend shown under each leaderboard table.
EVALUATION_DETAILS = """<b>Name</b> is the 🤗 Hugging Face model name (click to go to the model card). <b>Rewards / Day</b> are the expected rewards per day for each model. <b>Block</b> is the Bittensor block that the model was submitted in. More stats on <a href="https://taostats.io/subnets/netuid-3/" target="_blank">taostats</a>."""
# Heading of the "Evaluation Stats" accordion.
EVALUATION_HEADER = """<h3 align="center">Shows the latest internal evaluation statistics as calculated by a validator run by MyShell, the results are just for reference.  </h3>"""
# Weights & Biases project whose runs are read by get_scores().
VALIDATOR_WANDB_PROJECT = "myshell_tc/tts_subnet_validator"
# os.environ.get("VALIDATOR_WANDB_PROJECT")

# Hugging Face token (None when the variable is unset) and the API client
# built from it; both are used by restart_space().
H4_TOKEN = os.getenv("H4_TOKEN")
API = HfApi(token=H4_TOKEN)
REPO_ID = "myshell-test/tts-subnet-leaderboard"

# Retry policy shared by get_subtensor_and_metagraph() and get_tao_price().
METAGRAPH_RETRIES = 10
METAGRAPH_DELAY_SECS = 30

# Seconds allowed for a single metadata lookup in run_in_subprocess().
METADATA_TTL = 10

# Subnet id passed to subtensor.metagraph().
NETUID = 3

# Inputs of the "next reward update" estimate (see next_tempo()).
SUBNET_START_BLOCK = 2635801
SECONDS_PER_BLOCK = 12

# Network name handed to bt.subtensor(); defaults to "finney".
SUBTENSOR = os.getenv("SUBTENSOR", "finney")


@dataclass
class Competition:
    """One competition shown as its own tab on the leaderboard."""

    # Identifier compared against ModelData.competition and against the
    # "competition_id" values found in the validator's W&B history.
    id: str
    # Label used as the title of the competition's tab.
    name: str


# Competitions rendered on the page, in tab order.
COMPETITIONS = [
    Competition("p255", "anispeech-speaker-new"),
    Competition("p257", "anispeech-speaker-old"),
]
# Competition assumed for commitments that do not name one
# (see ModelData.from_compressed_str).
DEFAULT_COMPETITION_ID = "p255"
# Timestamp of the most recent get_subnet_data() call; None until it has run.
last_refresh: typing.Optional[datetime.datetime] = None


def run_in_subprocess(func: functools.partial, ttl: int) -> typing.Any:
    """Execute ``func`` in a forked child process, waiting at most ``ttl`` seconds.

    Args:
        func (functools.partial): Zero-argument callable to run. Its wrapped
            function's ``__name__`` is used in the timeout message.
        ttl (int): Maximum number of seconds to wait for the child.

    Returns:
        Any: Whatever ``func`` returned in the child process.

    Raises:
        TimeoutError: The child was still running after ``ttl`` seconds (it
            is terminated first).
        queue.Empty: The child exited without putting anything on the queue.
        Exception: Whatever ``func`` raised; non-``Exception`` errors
            (``BaseException`` subclasses) are wrapped in a plain ``Exception``.
    """

    def _child_entry(target: functools.partial, results: multiprocessing.Queue):
        # Send either the return value or the raised error back to the parent.
        try:
            results.put(target())
        except BaseException as err:
            results.put(err)

    # "fork" is requested explicitly (it is the default on all POSIX except
    # macOS) because pickling doesn't seem to work on "spawn".
    fork_ctx = multiprocessing.get_context("fork")
    results = fork_ctx.Queue()
    worker = fork_ctx.Process(target=_child_entry, args=[func, results])
    worker.start()
    worker.join(timeout=ttl)

    if worker.is_alive():
        # Still running after the deadline: kill it and report a timeout.
        worker.terminate()
        worker.join()
        raise TimeoutError(f"Failed to {func.func.__name__} after {ttl} seconds")

    # Non-blocking read: an empty queue here raises instead of hanging.
    outcome = results.get(block=False)

    if not isinstance(outcome, BaseException):
        return outcome
    # Errors captured in the child are re-raised in the parent.
    if isinstance(outcome, Exception):
        raise outcome
    raise Exception(f"BaseException raised in subprocess: {str(outcome)}")


def get_subtensor_and_metagraph() -> typing.Tuple[bt.subtensor, bt.metagraph]:
    """Connect to the chain and download the full metagraph, with retries.

    Up to METAGRAPH_RETRIES attempts are made, sleeping METAGRAPH_DELAY_SECS
    seconds between them. The error of the final attempt is re-raised.

    Returns:
        The connected subtensor and the (non-lite) metagraph for NETUID.
    """
    for attempt in range(1, METAGRAPH_RETRIES + 1):
        try:
            print("Connecting to subtensor...")
            subtensor: bt.subtensor = bt.subtensor(SUBTENSOR)
            print("Pulling metagraph...")
            metagraph: bt.metagraph = subtensor.metagraph(NETUID, lite=False)
            return subtensor, metagraph
        # NOTE(review): this deliberately mirrors a bare `except:` and so also
        # traps KeyboardInterrupt/SystemExit. Presumably that keeps retries
        # working if the client signals a connection failure by exiting the
        # process -- confirm before narrowing it.
        except BaseException:
            if attempt == METAGRAPH_RETRIES:
                raise
            print(
                f"Error connecting to subtensor or pulling metagraph, retry {attempt} of {METAGRAPH_RETRIES} in {METAGRAPH_DELAY_SECS} seconds..."
            )
            time.sleep(METAGRAPH_DELAY_SECS)
    # Only reachable when METAGRAPH_RETRIES is not positive.
    raise RuntimeError()


@dataclass
class ModelData:
    """A miner's committed model together with its on-chain statistics."""

    # Miner slot and hotkey the commitment was read for.
    uid: int
    hotkey: str
    # First four ':'-separated tokens of the commitment string; "None"
    # placeholders for commit/hash are stored as "".
    namespace: str
    name: str
    commit: str
    hash: str
    # Block recorded in the commitment metadata.
    block: int
    incentive: float
    emission: float
    # Competition id (fifth token), or DEFAULT_COMPETITION_ID when absent.
    competition: str

    @classmethod
    def from_compressed_str(
        cls,
        uid: int,
        hotkey: str,
        cs: str,
        block: int,
        incentive: float,
        emission: float,
    ):
        """Returns an instance of this class from a compressed string representation.

        Args:
            uid: Miner uid.
            hotkey: Miner hotkey.
            cs: ``namespace:name:commit:hash[:competition]``; the literal
                ``"None"`` marks a missing commit, hash or competition.
            block: Block of the commitment.
            incentive: Incentive of the miner.
            emission: Emission of the miner.

        Returns:
            A new instance of ``cls`` (so subclasses get their own type).

        Raises:
            IndexError: ``cs`` has fewer than four ':'-separated tokens.
        """
        tokens = cs.split(":")
        has_competition = len(tokens) > 4 and tokens[4] != "None"
        return cls(
            uid=uid,
            hotkey=hotkey,
            namespace=tokens[0],
            name=tokens[1],
            commit=tokens[2] if tokens[2] != "None" else "",
            hash=tokens[3] if tokens[3] != "None" else "",
            competition=tokens[4] if has_competition else DEFAULT_COMPETITION_ID,
            block=block,
            incentive=incentive,
            emission=emission,
        )


def get_tao_price() -> float:
    """Fetch the average TAO/USDT price from the MEXC public API, with retries.

    Up to METAGRAPH_RETRIES attempts are made, sleeping METAGRAPH_DELAY_SECS
    seconds after each failure. The error of the final attempt is re-raised.

    Returns:
        The value of the ``price`` field of the response, as a float.

    Raises:
        RuntimeError: No attempt was made (METAGRAPH_RETRIES is not positive).
        Exception: Whatever the last attempt raised (network error, timeout,
            unexpected payload).
    """
    # Without a timeout, requests waits indefinitely on a stalled connection.
    request_timeout_secs = 10
    for i in range(0, METAGRAPH_RETRIES):
        try:
            response = requests.get(
                "https://api.mexc.com/api/v3/avgPrice?symbol=TAOUSDT",
                timeout=request_timeout_secs,
            )
            return float(response.json()["price"])
        except Exception:
            if i == METAGRAPH_RETRIES - 1:
                raise
            time.sleep(METAGRAPH_DELAY_SECS)
    # This raise used to sit inside the loop body, which aborted after the
    # first failed attempt instead of retrying.
    raise RuntimeError("TAO price was not fetched: no attempts were made")


def get_validator_weights(
    metagraph: bt.metagraph,
) -> typing.Dict[int, typing.Tuple[float, int, typing.Dict[int, float]]]:
    """Collect, for every uid with positive validator trust, its weights.

    Args:
        metagraph: Metagraph providing ``uids``, ``validator_trust``, ``S``
            and ``weights``.

    Returns:
        Mapping ``uid -> (validator_trust, S[uid], {other_uid: weight})``.
        Weights are rounded to 4 decimals; self-weights and weights that
        round to zero or below are left out.
    """
    all_uids = metagraph.uids.tolist()
    table = {}
    for validator_uid in all_uids:
        vtrust = metagraph.validator_trust[validator_uid].item()
        if not vtrust > 0:
            continue
        stake = metagraph.S[validator_uid].item()
        outgoing = {}
        for target_uid in all_uids:
            if target_uid == validator_uid:
                continue
            weight = round(metagraph.weights[validator_uid][target_uid].item(), 4)
            if weight > 0:
                outgoing[target_uid] = weight
        table[validator_uid] = (vtrust, stake, outgoing)
    return table


def get_subnet_data(
    subtensor: bt.subtensor, metagraph: bt.metagraph
) -> typing.List[ModelData]:
    """Read every miner's model commitment and pair it with metagraph stats.

    Lookups run concurrently on a thread pool; each individual metadata call
    is executed through run_in_subprocess() with a METADATA_TTL deadline.
    Miners whose metadata is missing, cannot be fetched, or cannot be parsed
    are skipped. Also sets the module-level ``last_refresh`` timestamp.

    Args:
        subtensor: Connected subtensor used for the metadata lookups.
        metagraph: Metagraph providing uids, hotkeys, incentive and emission.

    Returns:
        One ModelData per miner with a usable commitment, in completion order.
    """
    global last_refresh

    # Function to be executed in a thread
    def fetch_data(uid):
        hotkey = metagraph.hotkeys[uid]
        try:
            partial = functools.partial(
                get_metadata, subtensor, metagraph.netuid, hotkey
            )
            metadata = run_in_subprocess(partial, METADATA_TTL)
        except Exception as e:
            print(f"Skipping UID {uid}: could not fetch metadata ({e!r})")
            return None

        if not metadata:
            return None

        incentive = np.nan_to_num(metagraph.incentive[uid]).item()
        emission = (
            np.nan_to_num(metagraph.emission[uid]).item() * 20
        )  # convert to daily TAO

        # The commitment is arbitrary miner-supplied data: a payload that is
        # not valid hex / UTF-8 / "a:b:c:d" must only drop that miner, not
        # propagate through future.result() and abort the whole refresh.
        try:
            commitment = metadata["info"]["fields"][0]
            # Value of the first key, minus its leading two characters
            # (presumably a "0x" prefix).
            hex_data = commitment[list(commitment.keys())[0]][2:]
            chain_str = bytes.fromhex(hex_data).decode()
            block = metadata["block"]
            return ModelData.from_compressed_str(
                uid, hotkey, chain_str, block, incentive, emission
            )
        except Exception as e:
            print(f"Skipping UID {uid}: could not parse commitment ({e!r})")
            return None

    # Use ThreadPoolExecutor to fetch data in parallel
    results = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(fetch_data, uid) for uid in metagraph.uids.tolist()]
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            desc="Metadata for hotkeys",
            total=len(futures),
        ):
            result = future.result()
            if result:
                results.append(result)

    last_refresh = datetime.datetime.now()
    return results


def floatable(x) -> bool:
    """Tell whether ``x`` is an int, or a float that is neither NaN nor infinite."""
    if isinstance(x, int):
        return True
    if not isinstance(x, float):
        return False
    return math.isfinite(x)


def get_float_score(
    key: str, history, competition_id: str
) -> typing.Tuple[typing.Optional[float], bool]:
    """Return the most recent value of ``key`` for a competition.

    Args:
        key: Column to read from ``history``.
        history: Mapping-like object supporting ``in`` and ``[]`` that holds
            ``key`` and ``"competition_id"`` columns, aligned row by row.
        competition_id: Competition whose latest row is wanted.

    Returns:
        ``(value, fresh)``. ``fresh`` is True when the value comes from the
        latest row of the competition. Otherwise the latest usable value at or
        before that row is returned with ``fresh`` False. ``(None, False)``
        means a column is missing, no row belongs to the competition, or no
        usable value exists.
    """
    if key not in history or "competition_id" not in history:
        return None, False

    data = list(history[key])
    competitions = list(history["competition_id"])

    # Walk back from the newest row to the latest one of this competition.
    # Checking for exhaustion first avoids the IndexError that popping an
    # empty list raised when no row matched.
    while data and competitions and competitions[-1] != competition_id:
        competitions.pop()
        data.pop()
    if not data or not competitions:
        return None, False

    if floatable(data[-1]):
        return float(data[-1]), True

    # The latest row has no usable value: fall back to the newest usable one
    # among the remaining (older) rows and mark it as stale.
    usable = [float(x) for x in data if floatable(x)]
    if usable:
        return usable[-1], False
    return None, False


def get_sample(
    uid, history, competition_id: str
) -> typing.Optional[typing.Tuple[str, str, str]]:
    """Return the latest logged (prompt, response, truth) sample for a uid.

    Args:
        uid: Miner uid; selects the ``sample_*_data.{uid}`` columns.
        history: Mapping-like object supporting ``in`` and ``[]`` that holds
            those columns plus ``"competition_id"``, aligned row by row.
        competition_id: Competition whose latest row is wanted.

    Returns:
        The three strings from the latest row of the competition, or None
        when a column is missing, no row belongs to the competition, or that
        row does not hold three strings.
    """
    prompt_key = f"sample_prompt_data.{uid}"
    response_key = f"sample_response_data.{uid}"
    truth_key = f"sample_truth_data.{uid}"
    required = (prompt_key, response_key, truth_key, "competition_id")
    if any(column not in history for column in required):
        return None

    # Iterate newest-first. Running out of rows simply ends the loop; the
    # previous pop()-based version raised IndexError when no row matched.
    newest_first = zip(
        reversed(list(history[prompt_key])),
        reversed(list(history[response_key])),
        reversed(list(history[truth_key])),
        reversed(list(history["competition_id"])),
    )
    for prompt, response, truth, competition in newest_first:
        if competition != competition_id:
            continue
        # Only the latest row of the competition is considered.
        if (
            isinstance(prompt, str)
            and isinstance(response, str)
            and isinstance(truth, str)
        ):
            return prompt, response, truth
        return None
    return None


def get_scores(
    uids: typing.List[int], competition_id: str
) -> typing.Dict[int, typing.Dict[str, typing.Optional[float | str]]]:
    """Read per-uid evaluation statistics from the validator's W&B runs.

    Args:
        uids: Miner uids to look up.
        competition_id: Competition the statistics must belong to.

    Returns:
        ``uid -> {"win_rate", "win_total", "weight", "sample", "fresh"}``.
        ``fresh`` is True only when both win_rate and win_total come from the
        latest row of the competition.
    """
    wandb_api = wandb.Api()
    validator_runs = list(wandb_api.runs(VALIDATOR_WANDB_PROJECT))

    collected = {}
    for run in validator_runs:
        history = run.history()
        for uid in uids:
            if uid not in collected:
                win_rate, win_rate_fresh = get_float_score(
                    f"win_rate_data.{uid}", history, competition_id
                )
                win_total, win_total_fresh = get_float_score(
                    f"win_total_data.{uid}", history, competition_id
                )
                # The freshness flag of the weight is not part of the result.
                weight, _ = get_float_score(
                    f"weight_data.{uid}", history, competition_id
                )
                collected[uid] = {
                    "win_rate": win_rate,
                    "win_total": win_total,
                    "weight": weight,
                    "sample": get_sample(uid, history, competition_id),
                    "fresh": win_rate_fresh and win_total_fresh,
                }
        # NOTE(review): every uid gets an entry on the first run processed
        # (even when nothing was found), so with distinct uids this exits
        # after a single run.
        if len(collected) == len(uids):
            break
    return collected


def format_score(uid, scores, key) -> typing.Optional[float]:
    """Return ``scores[uid][key]`` rounded to 4 decimals.

    None is returned when the uid or key is absent, or when the stored value
    does not pass floatable().
    """
    if uid not in scores or key not in scores[uid]:
        return None
    point = scores[uid][key]
    return round(point, 4) if floatable(point) else None


def next_tempo(start_block, tempo, block):
    """Return the first tempo boundary strictly after ``block``.

    Boundaries lie at ``start_block + k * tempo``; a ``block`` that sits
    exactly on a boundary yields the following one.
    """
    first_boundary = start_block + tempo
    elapsed_intervals, _ = divmod(block - first_boundary, tempo)
    return first_boundary + (elapsed_intervals + 1) * tempo


# ---------------------------------------------------------------------------
# One-time data load. Everything below runs at import time, in this order,
# and the results are module-level globals read by the UI code further down.
# ---------------------------------------------------------------------------
subtensor, metagraph = get_subtensor_and_metagraph()

tao_price = get_tao_price()

# Despite the name this is a plain list of ModelData, ordered by incentive
# (highest first).
leaderboard_df = get_subnet_data(subtensor, metagraph)
leaderboard_df.sort(key=lambda x: x.incentive, reverse=True)

print(leaderboard_df)

# competition id -> result of get_scores() for the miners entered in it.
competition_scores = {
    y.id: get_scores([x.uid for x in leaderboard_df if x.competition == y.id], y.id)
    for y in COMPETITIONS
}

current_block = metagraph.block.item()

# Block of the next tempo boundary, using a fixed tempo of 360 blocks.
next_update = next_tempo(
    SUBNET_START_BLOCK,
    360,
    current_block,
)
        
# Both values are computed once here; get_next_update() reuses them later.
blocks_to_go = next_update - current_block
current_time = datetime.datetime.now()
next_update_time = current_time + datetime.timedelta(
    seconds=blocks_to_go * SECONDS_PER_BLOCK
)

# Also a dict, not a dataframe: uid -> (validator_trust, S[uid], weights).
validator_df = get_validator_weights(metagraph)
# Union of all uids that receive weight from at least one validator.
weight_keys = set()
for uid, stats in validator_df.items():
    weight_keys.update(stats[-1].keys())


def get_next_update():
    """Build the HTML banner announcing the next reward update.

    The block count is the module-level ``blocks_to_go`` computed at startup;
    only the minutes estimate is recomputed from the current time.
    """
    remaining = next_update_time - datetime.datetime.now()
    minutes = int(remaining.total_seconds() // 60)
    return f"""<div align="center" style="font-size: larger;">Next reward update: <b>{blocks_to_go}</b> blocks (~{minutes} minutes)</div>"""


def leaderboard_data(
    show_stale: bool,
    scores: typing.Dict[int, typing.Dict[str, typing.Optional[float | str]]],
    competition_id: str,
):
    """Build the rows of one competition's leaderboard table.

    Args:
        show_stale: When True, also include miners without fresh scores.
        scores: Per-uid statistics as returned by get_scores().
        competition_id: Competition whose miners are listed.

    Returns:
        A list of ``[name_markdown, win_rate, weight, uid, block]`` rows in
        the order of the module-level ``leaderboard_df``.
    """
    rows = []
    for c in leaderboard_df:
        if c.competition != competition_id:
            continue
        # A uid without an entry in `scores` counts as not fresh rather than
        # raising KeyError, matching how format_score() treats missing uids.
        is_fresh = scores.get(c.uid, {}).get("fresh")
        if not (is_fresh or show_stale):
            continue
        rows.append(
            [
                f"[{c.namespace}/{c.name} ({c.commit[0:8]}, UID={c.uid})](https://huggingface.co/{c.namespace}/{c.name}/commit/{c.commit})",
                format_score(c.uid, scores, "win_rate"),
                format_score(c.uid, scores, "weight"),
                c.uid,
                c.block,
            ]
        )
    return rows


def _make_stale_toggle_handler(tab_scores, tab_competition_id):
    """Return a "Show Stale" callback bound to one specific competition tab.

    A lambda created inside the tab loop would look up the loop variables
    when the checkbox is clicked, i.e. after the loop has finished, so every
    tab would refresh with the data of the last competition. Passing the
    values through this factory freezes them per tab.
    """

    def _on_toggle(show_stale_value):
        return leaderboard_data(show_stale_value, tab_scores, tab_competition_id)

    return _on_toggle


# Page layout: header, one tab per competition, then the validator table.
# All data shown here was loaded once at import time.
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
with demo:
    gr.HTML(FONT)
    gr.HTML(TITLE)
    gr.HTML(IMAGE)
    gr.HTML(HEADER)

    # Evaluated once, while the page is being built.
    gr.HTML(value=get_next_update())

    with gr.Tabs():
        for competition in COMPETITIONS:
            with gr.Tab(competition.name):
                scores = competition_scores[competition.id]
                print(scores)

                # Entries of this competition among the ten highest-incentive
                # miners overall that have a non-zero incentive.
                top_entries = [
                    c
                    for c in leaderboard_df[:10]
                    if c.incentive and c.competition == competition.id
                ]
                class_denominator = sum(c.incentive for c in top_entries)

                # Label -> share of the incentive held by the entries above.
                # An empty top_entries means no division takes place.
                class_values = {
                    f"{c.namespace}/{c.name} ({c.commit[0:8]}, UID={c.uid}) · ${round(c.emission * tao_price, 2):,} (τ{round(c.emission, 2):,})": c.incentive
                    / class_denominator
                    for c in top_entries
                }

                gr.Label(
                    value=class_values,
                    num_top_classes=10,
                )

                with gr.Accordion("Evaluation Stats"):
                    # NOTE(review): the replace only has an effect if
                    # EVALUATION_HEADER contains a literal "{date}" placeholder.
                    gr.HTML(
                        EVALUATION_HEADER.replace(
                            "{date}",
                            last_refresh.strftime("refreshed at %H:%M on %Y-%m-%d"),
                        )
                    )
                    # One sub-tab per miner that has a logged sample.
                    with gr.Tabs():
                        for entry in leaderboard_df:
                            if entry.competition != competition.id:
                                continue
                            # Miners missing from `scores` are treated as
                            # having no sample instead of raising KeyError.
                            sample = scores.get(entry.uid, {}).get("sample")
                            if sample is None:
                                continue
                            name = f"{entry.namespace}/{entry.name} ({entry.commit[0:8]}, UID={entry.uid})"
                            with gr.Tab(name):
                                gr.Chatbot([(sample[0], sample[1])])

                    show_stale = gr.Checkbox(label="Show Stale", interactive=True)
                    leaderboard_table = gr.components.Dataframe(
                        value=leaderboard_data(
                            show_stale.value, scores, competition.id
                        ),
                        headers=[
                            "Name",
                            "Win Rate",
                            "Weight",
                            "UID",
                            "Block",
                        ],
                        datatype=[
                            "markdown",
                            "number",
                            "number",
                            "number",
                            "number",
                        ],
                        elem_id="leaderboard-table",
                        interactive=False,
                        visible=True,
                    )
                    gr.HTML(EVALUATION_DETAILS)
                    show_stale.change(
                        _make_stale_toggle_handler(scores, competition.id),
                        [show_stale],
                        leaderboard_table,
                    )

    with gr.Accordion("Validator Stats"):
        # One column per miner with non-zero incentive.
        incentivized = [c for c in leaderboard_df if c.incentive]
        # Validator uids ordered by their second tuple element (S), highest first.
        validators_by_stake = sorted(
            validator_df,
            key=lambda v: validator_df[v][1],
            reverse=True,
        )
        validator_table = gr.components.Dataframe(
            value=[
                [v, int(validator_df[v][1]), round(validator_df[v][0], 4)]
                + [validator_df[v][-1].get(c.uid) for c in incentivized]
                for v in validators_by_stake
            ],
            # "τ" was previously stored as the mojibake sequence "Ï„".
            headers=["UID", "Stake (τ)", "V-Trust"]
            + [
                f"{c.namespace}/{c.name} ({c.commit[0:8]}, UID={c.uid})"
                for c in incentivized
            ],
            datatype=["number", "number", "number"]
            + ["number" for _ in incentivized],
            interactive=False,
            visible=True,
        )


def restart_space():
    """Request a restart of the Space REPO_ID through the module-level HfApi client."""
    API.restart_space(
        token=H4_TOKEN,
        repo_id=REPO_ID,
    )

# # Switch to independent restarter now
# scheduler = BackgroundScheduler()
# scheduler.add_job(restart_space, "interval", seconds=60 * 5)  # restart every 5 minutes
# scheduler.start()

# Start serving the page built above.
demo.launch()