import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import metadata_load

import pandas as pd

import requests

from utils import *

api = HfApi()

def get_user_models(hf_username, env_tag, lib_tag):
    """
    Return the ids of a user's Reinforcement Learning models on the Hub
    that carry both the given environment tag and the given library tag.
    :param hf_username: User HF username (used as the `author` filter)
    :param env_tag: Environment tag
    :param lib_tag: Library tag
    :return: list with the `modelId` of every model returned by the Hub
    """
    hub = HfApi()
    # Every listed model must match all three tags.
    tags = ["reinforcement-learning", env_tag, lib_tag]
    return [entry.modelId for entry in hub.list_models(author=hf_username, filter=tags)]


def get_user_sf_models(hf_username, env_tag, lib_tag):
    """
    List the user's Reinforcement Learning models for a library when the
    environment cannot be filtered by tag: the environment is matched
    against the dataset name stored in each model card's metadata
    (model-index -> results -> dataset -> name).
    :param hf_username: User HF username
    :param env_tag: Environment name expected as the dataset name
    :param lib_tag: Library tag
    :return: list of model ids whose first evaluation entry is on `env_tag`
    """
    api = HfApi()
    models = api.list_models(author=hf_username, filter=["reinforcement-learning", lib_tag])

    user_model_ids = [x.modelId for x in models]

    models_sf = []
    for model_id in user_model_ids:
        meta = get_metadata(model_id)
        if meta is None:
            continue
        try:
            dataset_name = meta["model-index"][0]["results"][0]["dataset"]["name"]
        except (KeyError, IndexError, TypeError):
            # The model card has no (or an incomplete) evaluation entry, so it
            # cannot be matched to an environment: skip it instead of failing
            # the whole lookup.
            continue
        if dataset_name == env_tag:
            models_sf.append(model_id)

    return models_sf


def get_metadata(model_id):
  """
  Download the README.md of a model and load its metadata
  (the metadata contains the evaluation data).
  :param model_id: id of the model repository on the Hub
  :return: the value returned by `metadata_load` for the README, or None
           when an HTTP error is raised (e.g. 404, README.md not found)
  """
  try:
    card_path = hf_hub_download(model_id, filename="README.md")
    card_metadata = metadata_load(card_path)
  except requests.exceptions.HTTPError:
    # 404 README.md not found
    return None
  return card_metadata


def parse_metrics_accuracy(meta):
  """
  Get the model result out of the metadata: the "value" of the first
  metric of the first result of the first "model-index" entry.
  :param meta: model metadata (mapping loaded from the model card)
  :return: the metric value as stored in the metadata, or None when the
           metadata has no such entry (missing key, empty list, or a
           level that is not a mapping/list)
  """
  try:
    return meta["model-index"][0]["results"][0]["metrics"][0]["value"]
  except (KeyError, IndexError, TypeError):
    # Incomplete or malformed model card: report "no result" rather than
    # crashing, so one bad card does not break the whole progress table.
    return None


def parse_rewards(accuracy):
  """
  Parse mean_reward and std_reward out of a model result.
  Accepted forms are "<mean> +/- <std>" and a lone "<mean>" (std is then 0).
  :param accuracy: model results (string or number), or None when the
                   model has no result
  :return: (mean_reward, std_reward) as floats; (-1000.0, -1000.0) when
           there is no result or it cannot be parsed as numbers
  """
  default_reward = -1000.0
  default_std = -1000.0
  if accuracy is None:
    return default_reward, default_std

  parsed = str(accuracy).split(' +/- ')
  try:
    mean_reward = float(parsed[0])
    # str.split always yields at least one item, so the result is either
    # "mean +/- std" or only the mean reward.
    std_reward = float(parsed[1]) if len(parsed) > 1 else 0.0
  except ValueError:
    # Non-numeric result (e.g. free text in the model card): treat it like
    # a missing result instead of raising.
    return default_reward, default_std

  return mean_reward, std_reward

def calculate_best_result(user_model_ids):
  """
  Find the best result among the given models, where the result of a
  model is mean_reward - std_reward.
  Models whose metadata cannot be loaded are ignored.
  :param user_model_ids: RL models of a user
  :return: (best_result, best_model_id); (-1000, "") when no model has a
           result above -1000
  """
  top_score = -1000
  top_model_id = ""
  for candidate_id in user_model_ids:
    card = get_metadata(candidate_id)
    if card is not None:
      mean_reward, std_reward = parse_rewards(parse_metrics_accuracy(card))
      score = mean_reward - std_reward
      # Strict comparison: on a tie the earliest model is kept
      if score > top_score:
        top_score, top_model_id = score, candidate_id

  return top_score, top_model_id

def check_if_passed(model):
  """
  Mark the unit as passed when its best result reaches the baseline
  (best_result >= min_result). The record is updated in place; a record
  that does not reach the baseline is left untouched.
  :param model: unit record with "best_result" and "min_result" keys
  """
  reached_baseline = model["best_result"] >= model["min_result"]
  if reached_baseline:
    model.update(passed_=True)

def _unit_entry(unit, env, library, min_result):
  """Build the fresh, mutable progress record of one course unit."""
  return {
      "unit": unit,
      "env": env,
      "library": library,
      "min_result": min_result,
      "best_result": 0,
      "best_model_id": "",
      "passed_": False,
  }


def _models_for_unit(hf_username, unit):
  """Return the ids of the user's models that are candidates for a unit."""
  name, library = unit["unit"], unit["library"]
  if name == "Unit 8 PII":
    # For sample factory vizdoom we don't have env tag for now
    return get_user_sf_models(hf_username, unit["env"], library)
  if name != "Unit 6":
    return get_user_models(hf_username, unit["env"], library)
  # Since Unit 6 can use PandaReachDense-v2 or v3: try v3, fall back to v2
  found = get_user_models(hf_username, "PandaReachDense-v3", library)
  if not found:
    print("Empty")
    found = get_user_models(hf_username, "PandaReachDense-v2", library)
  return found


def certification(hf_username):
  """
  Build the progress table of a user for every unit of the course.
  For each unit the user's models are looked up, the best result is
  computed and compared with the unit baseline (min_result).
  :param hf_username: User HF username
  :return: pandas DataFrame with the columns passed, unit, env,
           min_result, best_result and best_model_id (one row per unit)
  """
  # (unit, env, library, min_result)
  course_units = (
      ("Unit 1", "LunarLander-v2", "stable-baselines3", 200),
      ("Unit 2", "Taxi-v3", "q-learning", 4),
      ("Unit 3", "SpaceInvadersNoFrameskip-v4", "stable-baselines3", 200),
      ("Unit 4", "CartPole-v1", "reinforce", 350),
      ("Unit 4", "Pixelcopter-PLE-v0", "reinforce", 5),
      ("Unit 5", "ML-Agents-SnowballTarget", "ml-agents", -100),
      ("Unit 5", "ML-Agents-Pyramids", "ml-agents", -100),
      ("Unit 6", "PandaReachDense", "stable-baselines3", -3.5),
      ("Unit 7", "ML-Agents-SoccerTwos", "ml-agents", -100),
      ("Unit 8 PI", "LunarLander-v2", "deep-rl-course", -500),
      ("Unit 8 PII", "doom_health_gathering_supreme", "sample-factory", 5),
  )
  results_certification = [_unit_entry(*row) for row in course_units]

  for entry in results_certification:
    candidates = _models_for_unit(hf_username, entry)

    # Best result of the unit and the model that achieved it
    top_score, top_model = calculate_best_result(candidates)
    entry["best_result"] = top_score
    entry["best_model_id"] = make_clickable_model(top_model)

    # Compare with the baseline, then derive the displayed pass marker
    check_if_passed(entry)
    entry["passed"] = pass_emoji(entry["passed_"])

  print(results_certification)

  # Keep only the displayed columns, in display order
  columns = ['passed', 'unit', 'env', 'min_result', 'best_result', 'best_model_id']
  return pd.DataFrame(results_certification)[columns]


# Username whose progress is shown; it is fed to the hidden Textbox below.
DEFAULT_HF_USERNAME = "yesbut"

with gr.Blocks() as demo:
    gr.Markdown("""
    # 🏆 Progress in the Deep Reinforcement Learning Course 🏆
    (From Thomassimonini)
    
    """)
    # Hidden Textbox to pass the username
    hf_username = gr.Textbox(value=DEFAULT_HF_USERNAME, visible=False)
    output = gr.components.Dataframe(
        # certification() expects the username string, not the Textbox
        # component, when it is called directly at build time.
        value=certification(DEFAULT_HF_USERNAME),
        headers=["Pass?", "Unit", "Environment", "Baseline", "Your best result", "Your best model id"],
        # One datatype per displayed column (six columns)
        datatype=["markdown", "markdown", "markdown", "number", "number", "markdown"],
    )

    # Automatically load progress
    demo.load(fn=certification, inputs=[hf_username], outputs=output)


demo.launch()