Anonumous's picture
Update benchmark code
3b7d44a
raw
history blame
5.34 kB
import json
import pandas as pd
from statistics import mean
from huggingface_hub import HfApi, create_repo
from datasets import load_dataset, Dataset
from datasets.data_files import EmptyDatasetError
from constants import (
REPO_ID,
HF_TOKEN,
DATASETS,
SHORT_DATASET_NAMES,
DATASET_DESCRIPTIONS,
)
# Shared Hugging Face Hub client authenticated with HF_TOKEN; init_repo() uses
# it to check whether the leaderboard dataset repository exists.
api = HfApi(token=HF_TOKEN)
def init_repo():
    """Make sure the leaderboard dataset repository exists on the Hub.

    Looks up ``REPO_ID`` as a dataset repository. If the Hub reports that the
    repository is not found, it is created as a private dataset repository
    using ``HF_TOKEN``.

    Only the "repository not found" condition triggers creation; any other
    failure of the lookup propagates to the caller instead of being
    misinterpreted as a missing repository.
    """
    # Local import keeps this block self-contained; huggingface_hub is
    # already a dependency of this module.
    from huggingface_hub.utils import RepositoryNotFoundError

    try:
        api.repo_info(REPO_ID, repo_type="dataset")
    except RepositoryNotFoundError:
        create_repo(REPO_ID, repo_type="dataset", private=True, token=HF_TOKEN)
def load_data():
    """Render the leaderboard stored in the Hub dataset as an HTML table.

    Loads the ``train`` split of ``REPO_ID``, sorts rows by ``overall_wer``
    ascending, adds a 1-based rank, converts every metric to a percentage and
    collapses each dataset's WER/CER pair into a single cell that shows the WER
    with the CER in a hover tooltip. The model name becomes a link to the
    submitted URL and the license text is reduced to an open/closed label.

    Returns:
        str: The table HTML wrapped in ``leaderboard-wrapper`` /
        ``leaderboard-table`` divs. When there are no rows, a header-only
        table with the same column titles is returned.
    """
    # Local import: used to neutralise markup in submitter-controlled fields.
    from html import escape

    try:
        dataset = load_dataset(REPO_ID, token=HF_TOKEN)
        df = dataset["train"].to_pandas()
    except EmptyDatasetError:
        # A repository without data files yet: fall back to an empty frame.
        df = pd.DataFrame(
            columns=(
                ["model_name", "link", "license", "overall_wer", "overall_cer"]
                + [f"wer_{ds}" for ds in DATASETS]
                + [f"cer_{ds}" for ds in DATASETS]
            )
        )

    if df.empty:
        return (
            '<div class="leaderboard-wrapper"><div class="leaderboard-table"><table><thead><tr><th>Ранг</th><th>Модель</th><th>Тип модели</th><th>Средний WER ⬇️</th><th>Средний CER ⬇️</th>'
            + "".join(f"<th>{short}</th>" for short in SHORT_DATASET_NAMES)
            + "</tr></thead><tbody></tbody></table></div></div>"
        )

    # Best (lowest) overall WER first; the rank is the 1-based row position.
    df = df.sort_values("overall_wer").reset_index(drop=True)
    df.insert(0, "rank", df.index + 1)

    # Fractions -> percentages. The overall columns are stringified as-is.
    df["overall_wer"] = (df["overall_wer"] * 100).round(2).apply(lambda x: f"{x}")
    df["overall_cer"] = (df["overall_cer"] * 100).round(2).apply(lambda x: f"{x}")
    for ds in DATASETS:
        df[f"wer_{ds}"] = (df[f"wer_{ds}"] * 100).round(2)
        df[f"cer_{ds}"] = (df[f"cer_{ds}"] * 100).round(2)

    # One visible column per dataset: WER in the cell, CER in the tooltip.
    for short_ds, ds in zip(SHORT_DATASET_NAMES, DATASETS):
        df[short_ds] = [
            f'<span title="CER: {cer:.2f}" style="cursor: help;">{wer:.2f}</span>'
            for wer, cer in zip(df[f"wer_{ds}"], df[f"cer_{ds}"])
        ]
        df = df.drop(columns=[f"wer_{ds}", f"cer_{ds}"])

    # model_name and link come from user submissions and the table is rendered
    # with escape=False, so they must be HTML-escaped here to prevent markup
    # injection through the cell text or the href attribute.
    # NOTE(review): the URL scheme is not restricted (e.g. "javascript:");
    # consider validating it at submission time.
    df["model_name"] = [
        f'<a href="{escape(str(link), quote=True)}" target="_blank">{escape(str(name))}</a>'
        for name, link in zip(df["model_name"], df["link"])
    ]
    df = df.drop(columns=["link"])

    # Substring heuristic for open licenses. str() guards against missing or
    # non-string license values, which would otherwise raise on .lower().
    open_terms = ("mit", "apache", "bsd", "gpl", "open")
    df["license"] = df["license"].apply(
        lambda value: "Открытая"
        if any(term in str(value).lower() for term in open_terms)
        else "Закрытая"
    )

    df = df.rename(
        columns={
            "overall_wer": "Средний WER ⬇️",
            "overall_cer": "Средний CER ⬇️",
            "license": "Тип модели",
            "model_name": "Модель",
            "rank": "Ранг",
        }
    )

    # escape=False is required because the cells contain generated markup.
    table_html = df.to_html(escape=False, index=False)
    return f'<div class="leaderboard-wrapper"><div class="leaderboard-table">{table_html}</div></div>'
def _build_submission_row(data):
    """Validate a decoded submission and flatten it into one leaderboard row.

    Args:
        data: The object decoded from the submitted JSON.

    Returns:
        dict: ``model_name``, ``link``, ``license``, a ``wer_<ds>`` and
        ``cer_<ds>`` entry for every dataset in ``DATASETS``, plus
        ``overall_wer`` / ``overall_cer`` (the means over all datasets).

    Raises:
        ValueError: If the structure, the text fields or any metric value is
            invalid.
    """
    import math

    required_keys = ["model_name", "link", "license", "metrics"]
    if not isinstance(data, dict) or not all(key in data for key in required_keys):
        raise ValueError(
            "Неверная структура JSON. Требуемые поля: model_name, link, license, metrics"
        )

    # These fields are later rendered as text/HTML; a non-string value stored
    # in the dataset would break every subsequent leaderboard render.
    if not all(isinstance(data[key], str) for key in ("model_name", "link", "license")):
        raise ValueError("Поля model_name, link и license должны быть строками")

    metrics = data["metrics"]
    if not isinstance(metrics, dict) or set(metrics) != set(DATASETS):
        raise ValueError(
            f"Метрики должны быть для всех датасетов: {', '.join(DATASETS)}"
        )

    row = {
        "model_name": data["model_name"],
        "link": data["link"],
        "license": data["license"],
    }
    wers = []
    cers = []
    for ds in DATASETS:
        entry = metrics[ds]
        if not isinstance(entry, dict) or "wer" not in entry or "cer" not in entry:
            raise ValueError(f"Для {ds} требуются wer и cer")
        wer, cer = entry["wer"], entry["cer"]
        for value in (wer, cer):
            # json.loads accepts NaN/Infinity and JSON booleans are ints in
            # Python; none of those are meaningful error rates and NaN would
            # corrupt the averages and the sort order.
            if (
                isinstance(value, bool)
                or not isinstance(value, (int, float))
                or not math.isfinite(value)
                or value < 0
            ):
                raise ValueError(
                    f"Значения wer и cer для {ds} должны быть конечными неотрицательными числами"
                )
        row[f"wer_{ds}"] = wer
        row[f"cer_{ds}"] = cer
        wers.append(wer)
        cers.append(cer)

    row["overall_wer"] = mean(wers)
    row["overall_cer"] = mean(cers)
    return row


def process_submit(json_str):
    """Validate a JSON submission, append it to the Hub dataset and re-render.

    Args:
        json_str: JSON text with the keys ``model_name``, ``link``,
            ``license`` and ``metrics``; ``metrics`` maps every dataset in
            ``DATASETS`` to an object holding ``wer`` and ``cer``.

    Returns:
        tuple: ``(html, message)``. On success ``html`` is the refreshed
        leaderboard from ``load_data()`` and ``message`` confirms the
        addition. On any failure ``html`` is ``None`` and ``message`` carries
        the error text; this function does not raise.
    """
    try:
        row = _build_submission_row(json.loads(json_str))

        try:
            dataset = load_dataset(REPO_ID, token=HF_TOKEN)
            df = dataset["train"].to_pandas()
        except EmptyDatasetError:
            # First submission: start from an empty frame with the full schema.
            columns = (
                ["model_name", "link", "license", "overall_wer", "overall_cer"]
                + [f"wer_{ds}" for ds in DATASETS]
                + [f"cer_{ds}" for ds in DATASETS]
            )
            df = pd.DataFrame(columns=columns)

        new_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
        new_dataset = Dataset.from_pandas(new_df)
        new_dataset.push_to_hub(REPO_ID, token=HF_TOKEN)

        updated_html = load_data()
        return updated_html, "Успешно добавлено!"
    except Exception as e:
        # UI boundary: report every failure as a message instead of raising.
        return None, f"Ошибка: {str(e)}"
def get_datasets_description():
    """Return a Markdown overview of the benchmark datasets.

    The text starts with a top-level heading and is followed by one section
    per entry of ``DATASET_DESCRIPTIONS``: a heading with the short and full
    name, the description, and the number of rows.
    """
    sections = [
        f"### {short_name} ({meta['full_name']})\n{meta['description']}\n- Количество записей: {meta['num_rows']}\n\n"
        for short_name, meta in DATASET_DESCRIPTIONS.items()
    ]
    return "# Описание датасетов\n\n" + "".join(sections)