|
import requests |
|
import pandas as pd |
|
from tqdm.auto import tqdm |
|
import streamlit as st |
|
from huggingface_hub import HfApi, hf_hub_download |
|
from huggingface_hub.repocard import metadata_load |
|
|
|
# Languages whose leaderboard is ranked by CER instead of WER.
cer_langs = [
    "ja",
    "zh-CN",
    "zh-HK",
    "zh-TW",
]
|
|
|
|
|
def make_clickable(model_name):
    """Return an HTML anchor that opens the model's Hub page in a new tab.

    Args:
        model_name: Model id; appended to ``https://huggingface.co/`` to
            form the link and also used as the visible link text.

    Returns:
        The ``<a>`` element as a string.
    """
    link = "".join(("https://huggingface.co/", model_name))
    return f'<a target="_blank" href="{link}">{model_name}</a>'
|
|
|
|
|
def get_model_ids():
    """Return the ids of all Hub models matching the ``robust-speech-event`` filter.

    Returns:
        A list with the ``modelId`` of every model returned by
        ``HfApi.list_models``.
    """
    hub = HfApi()
    return [
        model.modelId
        for model in hub.list_models(filter="robust-speech-event")
    ]
|
|
|
|
|
def get_metadata(model_id):
    """Fetch a model's ``README.md`` from the Hub and load its metadata.

    Args:
        model_id: Id of the model repository to read.

    Returns:
        Whatever ``metadata_load`` extracts from the downloaded README, or
        ``None`` when an HTTP error is raised (e.g. the file cannot be
        downloaded).
    """
    try:
        card_path = hf_hub_download(model_id, filename="README.md")
        metadata = metadata_load(card_path)
    except requests.exceptions.HTTPError:
        metadata = None
    return metadata
|
|
|
|
|
def parse_metric_value(value):
    """Normalise a raw metric value from a model card to a rounded number.

    Args:
        value: The raw metric value. Handled forms:

            * ``str`` -- any ``%`` sign is removed and the rest is parsed
              as a float (the parsed number is not rescaled);
            * ``float`` below 1.1 -- treated as a fraction and multiplied
              by 100;
            * ``list`` -- the first element is used as-is, an empty list
              means "no value";
            * any other ``int``/``float`` -- used unchanged.

    Returns:
        The value rounded to two decimals, or ``None`` when the input
        cannot be interpreted as a number.
    """
    if isinstance(value, str):
        try:
            # Strip the percent sign *before* parsing; "12.5%" must become 12.5.
            value = float(value.replace("%", ""))
        except ValueError:
            value = None
    elif isinstance(value, float) and value < 1.1:
        # Values this small are taken to be fractions (0.13 meaning 13%).
        value = 100 * value
    elif isinstance(value, list):
        value = value[0] if value else None

    # Anything that is still not a number (None, a str or dict taken from a
    # list, ...) cannot be rounded or ranked.
    if not isinstance(value, (int, float)):
        return None
    return round(value, 2)
|
|
|
|
|
def parse_metrics_rows(meta):
    """Yield one leaderboard row per evaluated dataset found in card metadata.

    Only the first entry of ``meta["model-index"]`` is read. A result is
    used when its ``dataset`` mapping has both ``type`` and ``args`` and at
    least one ``wer``/``cer`` metric parses to a number; for each metric
    type the lowest reported value is kept.

    Malformed entries (wrong container type, missing keys, empty lists)
    are skipped instead of raising, so a single broken model card cannot
    abort the caller's loop.

    Args:
        meta: Metadata mapping of a model card.

    Yields:
        dict with the keys ``dataset`` and ``lang`` plus ``wer`` and/or
        ``cer``.
    """
    if "model-index" not in meta or "language" not in meta:
        return

    lang = meta["language"]
    if isinstance(lang, list):
        if not lang:
            return
        # Cards may list several languages; only the first one is used.
        lang = lang[0]

    model_index = meta["model-index"]
    if not isinstance(model_index, list) or not model_index:
        return
    first_entry = model_index[0]
    results = first_entry.get("results") if isinstance(first_entry, dict) else None
    if not isinstance(results, list):
        return

    for result in results:
        if not isinstance(result, dict):
            continue
        dataset_info = result.get("dataset")
        metrics = result.get("metrics")
        if not isinstance(dataset_info, dict) or not isinstance(metrics, list):
            continue
        # Results without a dataset type or dataset args are ignored.
        if "type" not in dataset_info or "args" not in dataset_info:
            continue

        row = {"dataset": dataset_info["type"], "lang": lang}
        for metric in metrics:
            if not isinstance(metric, dict):
                continue
            metric_type = str(metric.get("type", "")).lower().strip()
            if metric_type not in ("wer", "cer"):
                continue
            value = parse_metric_value(metric.get("value"))
            if value is None:
                continue
            # Keep the best (lowest) value when a metric is reported twice.
            if metric_type not in row or value < row[metric_type]:
                row[metric_type] = value

        if "wer" in row or "cer" in row:
            yield row
|
|
|
|
|
# NOTE(review): newer Streamlit releases deprecate st.cache in favour of
# st.cache_data -- confirm against the Streamlit version this app targets.
@st.cache(ttl=600)
def get_data():
    """Collect the WER/CER rows of every listed model into one DataFrame.

    For each id from ``get_model_ids`` the card metadata is fetched and
    every row produced by ``parse_metrics_rows`` is tagged with the model
    id. Models whose metadata could not be loaded are skipped. The result
    is cached by Streamlit with ``ttl=600``.

    Returns:
        A DataFrame with the columns ``dataset``, ``lang``, ``wer``,
        ``cer`` and ``model_id``. The columns are always present, even
        when no rows were found or no model reports one of the metrics,
        so callers can select them without a KeyError.
    """
    records = []
    for model_id in tqdm(get_model_ids()):
        meta = get_metadata(model_id)
        if meta is None:
            continue
        for row in parse_metrics_rows(meta):
            row["model_id"] = model_id
            records.append(row)
    return pd.DataFrame(
        records, columns=["dataset", "lang", "wer", "cer", "model_id"]
    )
|
|
|
|
|
# ---------------------------------------------------------------------------
# Page body: load the data, render the header, let the user pick a language
# and a dataset, then show the matching models ranked by WER or CER.
# ---------------------------------------------------------------------------
dataframe = get_data()

# NOTE(review): the original indentation was lost; the title is assumed to
# sit outside the centred column -- confirm against the deployed layout.
_, col_center = st.columns([3, 6])
with col_center:
    st.image("logo.png", width=200)
st.markdown("# Speech Models Leaderboard")

if dataframe.empty:
    # With no rows there is nothing to select from, and the column lookups
    # below could raise KeyError.
    st.warning("No evaluation results are available.")
    st.stop()

# Blank out missing cells so they render as empty strings.
dataframe = dataframe.fillna("")
dataframe["model_id"] = dataframe["model_id"].apply(make_clickable)

lang = st.selectbox(
    "Language",
    sorted(dataframe["lang"].unique()),
    index=0,
)
lang_df = dataframe[dataframe.lang == lang]

dataset = st.selectbox(
    "Dataset",
    sorted(lang_df["dataset"].unique()),
    index=0,
)

# Languages listed in cer_langs are ranked by CER, all others by WER.
metric = "cer" if lang in cer_langs else "wer"

# reindex selects the two displayed columns and creates the metric column
# (blank) if no model reported it at all, instead of raising KeyError.
dataset_df = lang_df[lang_df.dataset == dataset].reindex(
    columns=["model_id", metric], fill_value=""
)

# The metric column mixes floats with "" for models that did not report it;
# comparing those directly raises TypeError. Sorting on a numeric view puts
# the blank entries last while leaving the displayed values untouched.
dataset_df = dataset_df.sort_values(
    metric, key=lambda column: pd.to_numeric(column, errors="coerce")
)
dataset_df = dataset_df.rename(
    columns={
        "model_id": "Model",
        "wer": "WER (lower is better)",
        "cer": "CER (lower is better)",
    }
)

# escape=False keeps the <a> tags produced by make_clickable as real links.
st.write(dataset_df.to_html(escape=False, index=False), unsafe_allow_html=True)
|
|