|
import gradio as gr |
|
import pandas as pd |
|
import json |
|
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS |
|
from init import is_model_on_hub, upload_file, load_all_info_from_dataset_hub |
|
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message |
|
from datetime import datetime, timezone |
|
from utils_display import make_best_bold |
|
import plotly.graph_objects as go |
|
|
|
# Date string interpolated into the "Last updated on" note of the page.
LAST_UPDATED = "Sep 11th 2024"

# Datasets whose raw header is "<dataset> WER" and whose display header is
# the bare dataset name.
_WER_DATASETS = (
    "AMI",
    "Earnings22",
    "Gigaspeech",
    "LS Clean",
    "LS Other",
    "SPGISpeech",
)

# Raw column header -> header shown in the leaderboard table.
column_names = {
    "MODEL": "Model",
    "Avg. WER": "Average WER ⬇️ ",
    "Avg. RTFx": "RTFx ⬆️ ",
    **{f"{dataset} WER": dataset for dataset in _WER_DATASETS},
}
|
|
|
# Leaderboard scores, loaded once at import time from the CSV that sits in
# the working directory.
original_df = pd.read_csv(filepath_or_buffer="data.csv")

# Module-level list; nothing in the visible code appends to it.
requested_models: list = []
|
|
|
|
|
def formatter(x):
    """Round a numeric value to two decimals; pass strings through untouched.

    Args:
        x: A string, or any value accepted by the built-in ``round``.

    Returns:
        ``x`` itself when it is a string (or an instance of a ``str``
        subclass), otherwise ``round(x, 2)``.
    """
    # isinstance instead of an exact type comparison: instances of str
    # subclasses are still text and must not be handed to round().
    if isinstance(x, str):
        return x
    return round(x, 2)
|
|
|
def format_df(df: pd.DataFrame):
    """Format a leaderboard frame for display, modifying it in place.

    The ``"model"`` column is passed value by value through
    ``make_clickable_model``; every other column is replaced by the result of
    ``make_best_bold(column, column_name)``.

    Args:
        df: Frame to format. Its columns are overwritten in place, so pass a
            copy if the unformatted values are still needed.

    Returns:
        The same ``df`` object, for call-site convenience.
    """
    for col in df.columns:
        if col == "model":
            # Call the helper directly: for any string x,
            # x.replace(x, link) is always just link.
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = make_best_bold(df[col], col)
    return df
|
|
|
# Prepare the table shown at startup: format the cells, apply the display
# headers, then order rows by the average-WER column.
# NOTE(review): the sort runs on already-formatted values — confirm that
# make_best_bold leaves them orderable the way the raw numbers are.
original_df = (
    format_df(original_df)
    .rename(columns=column_names)
    .sort_values(by='Average WER ⬇️ ')
)

# Names and types of the fields declared on AutoEvalColumn, in the order
# fields() yields them.
COLS = [column.name for column in fields(AutoEvalColumn)]
TYPES = [column.type for column in fields(AutoEvalColumn)]
|
|
|
def request_model(model_text, chbcoco2017):
    """Placeholder handler for the request form; performs no work.

    Args:
        model_text: Accepted but ignored.
        chbcoco2017: Accepted but ignored.

    Returns:
        Always ``None``.
    """
    return None
|
|
|
def update_table(column_selection, search: str):
    """Rebuild the leaderboard table for a column subset and a model search.

    Reads ``data.csv`` afresh, keeps rows whose ``model`` value contains
    ``search`` (case-insensitive), narrows the columns to the chosen category,
    sorts by average WER ascending, and formats the result with ``format_df``.

    Args:
        column_selection: One of "All Columns", "Main Metrics", "Narrated",
            "Oratory" or "Spontaneous". Any other value is treated like
            "All Columns".
        search: Text to look for in the model name. It is matched literally,
            not as a regular expression; an empty value or ``None`` keeps
            every row.

    Returns:
        The filtered, sorted and formatted ``pandas.DataFrame``.

    Raises:
        KeyError: If ``data.csv`` lacks a column required by the selection.
    """
    avg_col = "Average WER ⬇️ "
    base_columns = ["model", avg_col, "RTFx ⬆️ "]
    # Datasets averaged for each category view.
    # NOTE(review): "Tedlium" has no entry in the module's column_names
    # mapping — confirm that data.csv really provides this column.
    category_datasets = {
        "Narrated": ["LS Clean", "LS Other", "Gigaspeech"],
        "Oratory": ["Tedlium", "SPGISpeech", "Earnings22"],
        "Spontaneous": ["Gigaspeech", "SPGISpeech", "Earnings22", "AMI"],
    }

    table = pd.read_csv("data.csv")

    # regex=False: the search box holds a plain model name, and characters
    # such as "(" or "+" must not be compiled as a pattern (that raises
    # re.error on unbalanced input).
    matches = table["model"].str.contains(
        search or "", case=False, na=False, regex=False
    )
    table = table[matches]

    # Explicit copies so the assignments below (and in format_df) write to an
    # independent frame rather than to a slice of another one.
    datasets = category_datasets.get(column_selection)
    if datasets is not None:
        new_df = table[[*base_columns, *datasets]].copy()
        # Within a category the average covers only that category's datasets.
        new_df[avg_col] = new_df[datasets].mean(axis=1).round(2)
    elif column_selection == "Main Metrics":
        new_df = table[base_columns].copy()
    else:
        # "All Columns", plus a safe fallback for unrecognised selections.
        new_df = table.copy()

    new_df = new_df.sort_values(by=avg_col, ascending=True)
    return format_df(new_df)
|
|
|
def generate_plot():
    """Build the scatter plot of average WER against RTFx.

    Reads ``data.csv`` and draws one marker per row, with the row's ``model``
    value as hover text.

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """
    scores = pd.read_csv("data.csv")

    # Hover box: model name, then both metrics to two decimals; the empty
    # <extra> tag is the last fragment of the template.
    hover_fragments = (
        '<b>%{text}</b><br>',
        'Average WER: %{x:.2f}<br>',
        'RTFx: %{y:.2f}<br>',
        '<extra></extra>',
    )

    points = go.Scatter(
        x=scores['Average WER ⬇️ '],
        y=scores['RTFx ⬆️ '],
        mode='markers',
        text=scores['model'],
        hovertemplate=''.join(hover_fragments),
        marker=dict(size=10, colorscale='Viridis'),
    )

    fig = go.Figure()
    fig.add_trace(points)
    fig.update_layout(
        title='ASR Model Performance: Average WER vs RTFx',
        xaxis_title='Average WER (lower is better)',
        yaxis_title='RTFx (higher is better)',
        hovermode='closest',
    )
    return fig
|
|
|
|
|
|
|
# Page layout. Components are created in the same order as before, because
# creation order inside the context managers determines on-page order.
# NOTE(review): nesting was reconstructed from a paste that lost its
# indentation — confirm in particular that the search box belongs directly
# in the Leaderboard tab rather than inside the accordion.
with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # Tab 0: the leaderboard table with its subset selector and search.
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            leaderboard_table = gr.components.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
                height=500,
            )
            with gr.Accordion("📌 Select a more detailed subset",open=False):
                column_radio = gr.Radio(
                    ["All Columns", "Main Metrics", "Narrated", "Oratory", "Spontaneous"],
                    label="Categories",
                    value="All Columns",
                )

            search_bar = gr.Textbox(label="Search models", placeholder="Enter model name...")

            # Changing the category or submitting the search box both
            # recompute the table through update_table.
            column_radio.change(
                fn=update_table,
                inputs=[column_radio, search_bar],
                outputs=[leaderboard_table],
            )
            search_bar.submit(
                fn=update_table,
                inputs=[column_radio, search_bar],
                outputs=[leaderboard_table],
            )

        # Tab 1: static description of the metrics.
        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        # Tab 2: model request form, wired to request_model.
        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
            with gr.Column():
                gr.Markdown("Select a dataset:", elem_classes="markdown-text")
            with gr.Column():
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                # Created hidden, checked and non-interactive; its value is
                # still passed to request_model.
                chb_coco2017 = gr.Checkbox(
                    label="COCO validation 2017 dataset",
                    visible=False,
                    value=True,
                    interactive=False,
                )
            with gr.Column():
                mdw_submission_result = gr.Markdown()
                btn_submitt = gr.Button(value="🚀 Request")
                btn_submitt.click(
                    fn=request_model,
                    inputs=[model_name_textbox, chb_coco2017],
                    outputs=mdw_submission_result,
                )

        # Tab 3: scatter plot; the figure comes from calling generate_plot.
        with gr.TabItem("📊 Plots", elem_id="od-benchmark-tab-table", id=3):
            plot = gr.Plot(generate_plot)

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT,
                lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

demo.launch()