Spaces:
Runtime error
Runtime error
File size: 10,446 Bytes
b165a4b 74e3b17 0811d37 74e3b17 de52ad3 74e3b17 400662c 74e3b17 c67a861 5174522 95c19d6 c67a861 2a5f9fb c67a861 b165a4b 2a73469 69cf5b3 c67a861 0660028 69cf5b3 0ef2585 69cf5b3 3922a8b 103ee13 3922a8b 103ee13 3922a8b 69cf5b3 3922a8b 69cf5b3 3922a8b 69cf5b3 103ee13 69cf5b3 3922a8b 103ee13 69cf5b3 74c08c9 69cf5b3 74e3b17 6d58c89 400662c 6d58c89 74e3b17 0660028 c67a861 74e3b17 69cf5b3 6d58c89 69cf5b3 3922a8b 74e3b17 3922a8b 69cf5b3 3922a8b 69cf5b3 041b899 6d58c89 69cf5b3 95c19d6 0660028 c67a861 0660028 a3eda6f c67a861 a3eda6f 0660028 400662c 0660028 400662c 0660028 400662c 0660028 48ae20c 02fb3fc 0660028 48ae20c 54bed10 4f47d86 54bed10 c67a861 a3eda6f 48ae20c 0660028 a3eda6f 0660028 a3eda6f 0660028 a3eda6f 0660028 48ae20c 263af70 a3eda6f 74e3b17 69cf5b3 0660028 69cf5b3 0660028 0ef2585 400662c c67a861 69cf5b3 48ae20c 6d58c89 48ae20c 0660028 48ae20c 0660028 48ae20c 0ef2585 0660028 74e3b17 0660028 400662c 263af70 8071283 263af70 95c19d6 a3eda6f 8c49cb6 74e3b17 0660028 74e3b17 95c19d6 0811d37 69cf5b3 |
|
import logging
import os
import gradio as gr
import numpy as np
import pandas as pd
import scipy.stats
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi
# Application logger: INFO and above, timestamped, sent to a stream handler
logger = logging.getLogger("app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
logger.addHandler(ch)

# The absl logger is noisy: only let warnings and errors through
logging.getLogger("absl").setLevel(logging.WARNING)

# Hub client; the token is read from the TOKEN environment variable (None when unset)
API = HfApi(token=os.getenv("TOKEN"))

# Dataset repo holding the evaluation results
RESULTS_REPO = "open-rl-leaderboard/results_v2"

# Interval between two result refreshes, in seconds (5 minutes)
REFRESH_RATE = 5 * 60
# Environment ids grouped by domain. Each key becomes a domain tab in the UI and
# each id in its list becomes a sub-tab with its own leaderboard.
ALL_ENV_IDS = {
    # Atari ids all carry the "NoFrameskip-v4" suffix (stripped from the tab labels)
    "Atari": [
        "AdventureNoFrameskip-v4",
        "AirRaidNoFrameskip-v4",
        "AlienNoFrameskip-v4",
        "AmidarNoFrameskip-v4",
        "AssaultNoFrameskip-v4",
        "AsterixNoFrameskip-v4",
        "AsteroidsNoFrameskip-v4",
        "AtlantisNoFrameskip-v4",
        "BankHeistNoFrameskip-v4",
        "BattleZoneNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "BerzerkNoFrameskip-v4",
        "BowlingNoFrameskip-v4",
        "BoxingNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
        "CarnivalNoFrameskip-v4",
        "CentipedeNoFrameskip-v4",
        "ChopperCommandNoFrameskip-v4",
        "CrazyClimberNoFrameskip-v4",
        "DefenderNoFrameskip-v4",
        "DemonAttackNoFrameskip-v4",
        "DoubleDunkNoFrameskip-v4",
        "ElevatorActionNoFrameskip-v4",
        "EnduroNoFrameskip-v4",
        "FishingDerbyNoFrameskip-v4",
        "FreewayNoFrameskip-v4",
        "FrostbiteNoFrameskip-v4",
        "GopherNoFrameskip-v4",
        "GravitarNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "IceHockeyNoFrameskip-v4",
        "JamesbondNoFrameskip-v4",
        "JourneyEscapeNoFrameskip-v4",
        "KangarooNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "KungFuMasterNoFrameskip-v4",
        "MontezumaRevengeNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
        "NameThisGameNoFrameskip-v4",
        "PhoenixNoFrameskip-v4",
        "PitfallNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "PooyanNoFrameskip-v4",
        "PrivateEyeNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "RiverraidNoFrameskip-v4",
        "RoadRunnerNoFrameskip-v4",
        "RobotankNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "SkiingNoFrameskip-v4",
        "SolarisNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "TennisNoFrameskip-v4",
        "TimePilotNoFrameskip-v4",
        "TutankhamNoFrameskip-v4",
        "UpNDownNoFrameskip-v4",
        "VentureNoFrameskip-v4",
        "VideoPinballNoFrameskip-v4",
        "WizardOfWorNoFrameskip-v4",
        "YarsRevengeNoFrameskip-v4",
        "ZaxxonNoFrameskip-v4",
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
    "PyBullet": [
        "AntBulletEnv-v0",
        "HalfCheetahBulletEnv-v0",
        "HopperBulletEnv-v0",
        "HumanoidBulletEnv-v0",
        "InvertedDoublePendulumBulletEnv-v0",
        "InvertedPendulumSwingupBulletEnv-v0",
        "MinitaurBulletEnv-v0",
        "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0",
    ],
}
def iqm(x):
    """Return the interquartile mean of `x`.

    All values of `x` are pooled (``axis=None``), the lowest 25% and the
    highest 25% are discarded, and the mean of what remains is returned.
    """
    cut_per_tail = 0.25
    return scipy.stats.trim_mean(x, cut_per_tail, axis=None)
def get_leaderboard_df() -> pd.DataFrame:
    """Download the results dataset and return the finished evaluations.

    Returns:
        A DataFrame with one row per result whose ``status`` is ``"DONE"``,
        with an extra ``iqm_episodic_return`` column holding the interquartile
        mean of that row's ``episodic_returns``.
    """
    logger.info("Downloading results")
    # The split name carries no meaning here, but one has to be given: "train"
    dataset = load_dataset(RESULTS_REPO, split="train")
    results = dataset.to_pandas()
    # Copy the filtered rows so the new column is written to an independent
    # frame rather than to a slice of `results` (avoids SettingWithCopyWarning)
    done = results[results["status"] == "DONE"].copy()
    done["iqm_episodic_return"] = done["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return done
def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of `df` for `env_id`, best first, with a 1-based ``ranking`` column.

    Rows are ordered by decreasing ``iqm_episodic_return``. The input frame is
    left untouched; a new frame is returned.
    """
    is_env = df["env_id"] == env_id
    ranked = df.loc[is_env].sort_values("iqm_episodic_return", ascending=False)
    return ranked.assign(ranking=np.arange(1, len(ranked) + 1))
def format_df(df: pd.DataFrame):
    """Turn a ranked results frame into rows for the leaderboard table.

    Args:
        df: Frame with at least the columns ``ranking``, ``user_id``,
            ``model_id`` and ``iqm_episodic_return``. It is not modified.

    Returns:
        A list of ``[ranking, user link, model link, iqm_episodic_return]``
        rows, where both links are markdown hyperlinks to huggingface.co.
    """
    # Work on a copy of the displayed columns only, so the caller's frame is untouched
    table = df[["ranking", "user_id", "model_id", "iqm_episodic_return"]].copy()
    users = table["user_id"].tolist()
    models = table["model_id"].tolist()
    # Build each link column in one pass instead of writing cell by cell with
    # iterrows()/.loc, which is slow and mis-assigns when index labels repeat
    table["user_id"] = [f"[{user}](https://huggingface.co/{user})" for user in users]
    table["model_id"] = [
        f"[{model}](https://huggingface.co/{user}/{model})" for user, model in zip(users, models)
    ]
    return table.values.tolist()
def refresh_video(df, env_id):
    """Download the replay video of the top-ranked model for `env_id`.

    Returns the local path given by ``API.hf_hub_download`` for ``replay.mp4``
    at the revision recorded in the results, or None when the environment has
    no entry or the download fails (the failure is logged).
    """
    ranked = select_env(df, env_id)
    if ranked.empty:
        return None

    best = ranked.iloc[0]
    user_id, model_id, sha = best["user_id"], best["model_id"], best["sha"]
    repo_id = f"{user_id}/{model_id}"
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=sha, repo_type="model")
    except Exception as e:
        # Best effort: a missing or unreachable video must not break the page
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None
def refresh_one_video(df, env_id):
    """Return a zero-argument callable that runs ``refresh_video(df, env_id)``.

    `df` and `env_id` are bound when this function is called, so the returned
    callable always queries the frame it was created with.
    """

    def inner():
        video_path = refresh_video(df, env_id)
        return video_path

    return inner
def refresh_winner(df, env_id):
    """Return the markdown shown next to the leaderboard of `env_id`.

    When the environment has at least one entry, the text is a heading with a
    link to the top-ranked model; otherwise it invites the reader to submit one.
    """
    # NOTE(review): the emoji below were restored from mis-decoded text (they
    # showed up as a lone "π"); confirm them against the intended UI.
    env_df = select_env(df, env_id)
    if env_df.empty:
        return f"""## {env_id}
This leaderboard is quite empty... 😢
Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""

    best = env_df.iloc[0]
    user_id = best["user_id"]
    model_id = best["model_id"]
    url = f"https://huggingface.co/{user_id}/{model_id}"
    return f"""## {env_id}
### 🏆 [Best model]({url}) 🏆"""
def refresh_num_models(df):
    """Return the sentence stating how many entries `df` holds, with thousands separators."""
    count = len(df)
    return "The leaderboard currently contains {:,} models.".format(count)
# Extra CSS handed to gr.Blocks: no border on `.generating` elements and
# centered h2/h3 headings
css = """
.generating {
border: none;
}
h2 {
text-align: center;
}
h3 {
text-align: center;
}
"""
def update_globals():
    """Re-download the results and rebuild every cached, per-environment view.

    Rebinds the module-level ``df``, ``dataframes``, ``winner_texts``,
    ``video_pathes`` and ``num_models_str``. Everything is computed into local
    variables first and published only at the end, so a failure part-way
    through (e.g. the download raising) leaves the previous data in place
    rather than a mix of old and new values.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df

    new_df = get_leaderboard_df()
    all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
    new_dataframes = {env_id: format_df(select_env(new_df, env_id)) for env_id in all_env_ids}
    new_winner_texts = {env_id: refresh_winner(new_df, env_id) for env_id in all_env_ids}
    new_video_pathes = {env_id: refresh_video(new_df, env_id) for env_id in all_env_ids}
    new_num_models_str = refresh_num_models(new_df)

    # Publish: nothing below can raise, so readers never see a half-built state
    # caused by an error
    df = new_df
    dataframes = new_dataframes
    winner_texts = new_winner_texts
    video_pathes = new_video_pathes
    num_models_str = new_num_models_str


# Populate the globals once at import time so the UI has data to show
update_globals()
def refresh():
    """Return the cached values for the UI: all tables, then all winner texts, then the model count."""
    # The globals are only read here, so no `global` declaration is needed
    outputs = list(dataframes.values())
    outputs.extend(winner_texts.values())
    outputs.append(num_models_str)
    return outputs
def _cached_video(env_id):
    """Return a zero-argument callback giving the cached replay path for `env_id`.

    The path is looked up in the module-level ``video_pathes`` at call time, so
    the callback follows every refresh done by ``update_globals`` instead of
    being tied to the results frame that existed when the UI was built.
    """

    def inner():
        return video_pathes.get(env_id)

    return inner


# NOTE(review): the emoji in the labels below were restored from mis-decoded
# text (each showed up as a lone "π"); confirm them against the intended UI.
with gr.Blocks(css=css) as demo:
    with open("texts/heading.md", encoding="utf-8") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # Components per env_id, filled below and refreshed on page load
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}

            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
                                        gr_video = gr.PlayableVideo(
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )

                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video

                            # Show the video when the tab is opened
                            tab.select(_cached_video(env_id), outputs=[gr_video])

                    # The first environment of a domain is shown without a select
                    # event, so its video is loaded together with the page
                    demo.load(_cached_video(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])

        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

        with gr.TabItem("📝 About"):
            with open("texts/about.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

    # On every page load, fill all tables, winner texts and the model count
    # from the cached globals (same order as the list returned by `refresh`)
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])
# Refresh the cached results in the background every REFRESH_RATE seconds;
# max_instances=1 asks the scheduler not to run two refreshes at the same time
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()

if __name__ == "__main__":
    demo.queue().launch()
|