import os
import json

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from datasets import load_dataset
from plotly.subplots import make_subplots

# Benchmark scores are downloaded from an HF dataset repo; HF_TOKEN and
# DATA_SET_REPO_PATH must be provided as environment variables / Space secrets.
FORCE_DOWNLOAD = bool(int(os.environ.get("FORCE_DOWNLOAD", "0")))
HF_TOKEN = str(os.environ.get("HF_TOKEN", ""))
DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
PERFORMANCE_FILENAME = str(os.environ.get("PERFORMANCE_FILENAME", "gpt4_single_json.csv"))
# Map from internal model identifiers (as stored in the results file) to the
# display names used in the plots.
rename_map = {
    "seallm13b10L6k_a_5a1R1_seaall_sft4x_1_5a1_r2_0_dpo_8_40000s": "SeaLLM-13b",
    "polylm": "PolyLM-13b",
    "qwen": "Qwen-14b",
    "gpt-3.5-turbo": "GPT-3.5-turbo",
}

CATEGORIES = [
    "task-solving",
    "math-reasoning",
    "general-instruction",
    "natural-question",
    "safety",
]
CATEGORIES_NAMES = {
    "task-solving": 'Task-solving',
    "math-reasoning": 'Math',
    "general-instruction": 'General-instruction',
    "natural-question": 'NaturalQA',
    "safety": 'Safety',
}

# LANGS = ['en', 'vi', 'th', 'id', 'km', 'lo', 'ms', 'my', 'tl']
LANGS = ['en', 'vi', 'id', 'ms', 'tl', 'th', 'km', 'lo', 'my']
# Three-letter labels shown on the plots for each language code.
LANG_NAMES = {
    'en': 'eng',
    'vi': 'vie',
    'th': 'tha',
    'id': 'ind',
    'km': 'khm',
    'lo': 'lao',
    'ms': 'msa',
    'my': 'mya',
    'tl': 'tgl',
}
MODEL_DFRAME = None


def get_model_df():
    """Download the GPT-4 rating CSV from the dataset repo and load it as a DataFrame."""
    # global MODEL_DFRAME
    # if isinstance(MODEL_DFRAME, pd.DataFrame):
    #     print(f'Load cache data frame')
    #     return MODEL_DFRAME
    from huggingface_hub import hf_hub_download
    assert DATA_SET_REPO_PATH != ''
    assert HF_TOKEN != ''
    repo_id = DATA_SET_REPO_PATH
    filename = PERFORMANCE_FILENAME
    # data_path = f"{DATA_SET_REPO_PATH}/{PERFORMANCE_FILENAME}"
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_download=FORCE_DOWNLOAD,
        local_dir='./hf_cache',
        repo_type="dataset",
        token=HF_TOKEN,
    )
    print(f'Downloaded file at {file_path} from {DATA_SET_REPO_PATH} / {PERFORMANCE_FILENAME}')
    df = pd.read_csv(file_path)
    return df
def aggregate_df(df, model_dict, category_name, categories):
    """Mean score per (model, category), restricted to the models in model_dict
    and renamed to their display names."""
    scores_all = []
    all_models = df["model"].unique()
    for model in all_models:
        for i, cat in enumerate(categories):
            # filter category/model, and score format error (<1% case)
            res = df[(df[category_name] == cat) & (df["model"] == model) & (df["score"] >= 0)]
            score = res["score"].mean()
            cat_name = cat
            scores_all.append({"model": model, category_name: cat_name, "score": score})
    target_models = list(model_dict.keys())
    scores_target = [s for s in scores_all if s["model"] in target_models]
    scores_target = sorted(scores_target, key=lambda x: target_models.index(x["model"]), reverse=True)
    df_score = pd.DataFrame(scores_target)
    df_score = df_score[df_score["model"].isin(target_models)]
    rename_map = model_dict
    for k, v in rename_map.items():
        df_score.replace(k, v, inplace=True)
    return df_score
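
# A minimal, hedged illustration (not part of the original Space file) of the
# frame layout aggregate_df expects. Column names are inferred from the code
# above; the toy scores are invented purely for the example.
def _aggregate_df_example():
    toy = pd.DataFrame([
        {"model": "gpt-3.5-turbo", "category": "safety", "score": 7.0},
        {"model": "gpt-3.5-turbo", "category": "safety", "score": 9.0},
        {"model": "qwen", "category": "safety", "score": 6.0},
    ])
    # Returns one row per (model, category) with the mean score; models are
    # renamed via rename_map and restricted to its keys.
    return aggregate_df(toy, rename_map, "category", ["safety"])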
def polar_subplot(fig, dframe, model_names, category_label, category_names, row, col, showlegend=True):
    """Add one radar (Scatterpolar) trace per model to the subplot at (row, col)."""
    colors = px.colors.qualitative.Plotly
    for i, (model, model_name) in enumerate(model_names):
        cat_list = dframe[dframe['model'] == model_name][category_label].tolist()
        score_list = dframe[dframe['model'] == model_name]['score'].tolist()
        # repeat the first point to close the polygon
        cat_list += [cat_list[0]]
        cat_list = [category_names[x] for x in cat_list]
        score_list += [score_list[0]]
        polar = go.Scatterpolar(
            name=model_name,
            r=score_list,
            theta=cat_list,
            legendgroup=f'{i}',
            marker=dict(color=colors[i]),
            hovertemplate="""Score: %{r:.2f}""",
            showlegend=showlegend,
        )
        fig.add_trace(polar, row, col)
def plot_agg_fn():
    """Radar charts of the aggregated Sea-Bench scores, by category and by language."""
    df = get_model_df()
    all_models = df["model"].unique()
    model_names = list(rename_map.items())
    colors = px.colors.qualitative.Plotly
    cat_df = aggregate_df(df, rename_map, "category", CATEGORIES)
    lang_df = aggregate_df(df, rename_map, "lang", LANGS)
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'polar'}] * 2],
        subplot_titles=("By Category", "By Language"),
    )
    fig.layout.annotations[0].y = 1.05
    fig.layout.annotations[1].y = 1.05
    # by category
    for i, (model, model_name) in enumerate(model_names):
        cat_list = cat_df[cat_df['model'] == model_name]['category'].tolist()
        score_list = cat_df[cat_df['model'] == model_name]['score'].tolist()
        cat_list += [cat_list[0]]
        cat_list = [CATEGORIES_NAMES[x] for x in cat_list]
        score_list += [score_list[0]]
        polar = go.Scatterpolar(
            name=model_name,
            r=score_list,
            theta=cat_list,
            legendgroup=f'{i}',
            marker=dict(color=colors[i]),
            hovertemplate="""Score: %{r:.2f}""",
        )
        fig.add_trace(polar, 1, 1)
    # by language
    for i, (model, model_name) in enumerate(model_names):
        cat_list = lang_df[lang_df['model'] == model_name]['lang'].tolist()
        score_list = lang_df[lang_df['model'] == model_name]['score'].tolist()
        cat_list += [cat_list[0]]
        score_list += [score_list[0]]
        cat_list = [LANG_NAMES[x] for x in cat_list]
        polar = go.Scatterpolar(
            name=model_name,
            r=score_list,
            theta=cat_list,
            legendgroup=f'{i}',
            marker=dict(color=colors[i]),
            hovertemplate="""Score: %{r:.2f}""",
            showlegend=False,
        )
        fig.add_trace(polar, 1, 2)
    polar_config = dict(
        angularaxis=dict(
            rotation=90,  # start position of angular axis
        ),
        radialaxis=dict(
            range=[0, 10],
        ),
    )
    fig.update_layout(
        polar=polar_config,
        polar2=polar_config,
        title='Sea-Bench (rated by GPT-4)',
    )
    return fig
def plot_by_lang_fn():
    """3x3 grid of per-language radar charts over the question categories."""
    df = get_model_df()
    model_names = list(rename_map.items())
    fig = make_subplots(
        rows=3, cols=3,
        specs=[[{'type': 'polar'}] * 3] * 3,
        # titles must follow the order the subplots are filled in (LANGS)
        subplot_titles=[LANG_NAMES[x] for x in LANGS],
        # vertical_spacing=1
    )
    # print(fig.layout.annotations)
    for ano in fig.layout.annotations:
        ano.y = ano.y + 0.02
    # only these languages have a safety category in the benchmark
    has_safety = ['vi', 'id', 'th']
    for lang_id, lang in enumerate(LANGS):
        cat_names = CATEGORIES if lang in has_safety else [x for x in CATEGORIES if x != 'safety']
        cat_lang_df = aggregate_df(df[df['lang'] == lang], rename_map, "category", cat_names)
        row = lang_id // 3 + 1
        col = lang_id % 3 + 1
        polar_subplot(fig, cat_lang_df, model_names, 'category', CATEGORIES_NAMES, row, col, showlegend=(lang_id == 0))
    polar_config = dict(
        angularaxis=dict(
            rotation=90,  # start position of angular axis
        ),
        radialaxis=dict(
            range=[0, 10],
        ),
    )
    layer_kwargs = {f"polar{i}": polar_config for i in range(1, 10)}
    fig.update_layout(
        title='Sea-Bench - By language (rated by GPT-4)',
        height=1000,
        # width=1200,
        **layer_kwargs,
    )
    return fig
def both_plot():
    return plot_agg_fn(), plot_by_lang_fn()


def attach_plot_to_demo(demo: gr.Blocks):
    """Add a collapsed accordion with a button that loads both benchmark plots."""
    with gr.Accordion("Psst... wanna see some performance benchmarks?", open=False) as accord:
        # gr_plot_agg = gr.Plot(plot_agg_fn, label="Aggregated")
        # gr_plot_bylang = gr.Plot(plot_by_lang_fn, label='By language')
        show_result = gr.Button("Load benchmark results")
        gr_plot_agg = gr.Plot(label="Aggregated")
        gr_plot_bylang = gr.Plot(label='By language')
        # def callback():
        #     demo.load(plot_agg_fn, [], gr_plot_agg)
        #     demo.load(plot_by_lang_fn, [], gr_plot_bylang)
        show_result.click(both_plot, [], [gr_plot_agg, gr_plot_bylang])
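
# A minimal, hedged usage sketch (not part of the original file): one way this
# module could be wired into a Gradio Blocks app. It assumes HF_TOKEN and
# DATA_SET_REPO_PATH are set; the Markdown title below is just a placeholder.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("## Sea-Bench results")
        attach_plot_to_demo(demo)
    demo.launch()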