Spaces:
Running
Running
Commit
·
73d9a01
1
Parent(s):
0b6543b
add gradio app
Browse files
app.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from datasets import disable_caching, load_dataset
|
| 3 |
+
from transformer_ranker import TransformerRanker, prepare_popular_models
|
| 4 |
+
import traceback
|
| 5 |
+
|
| 6 |
+
from utils import (
|
| 7 |
+
DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER,
|
| 8 |
+
EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded,
|
| 9 |
+
compute_ratio, ensure_one_lm_selected, get_dataset_info
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
disable_caching()
|
| 13 |
+
|
| 14 |
+
THEME = "pseudolab/huggingface-korea-theme"
|
| 15 |
+
DEFAULT_SAMPLES = 1000
|
| 16 |
+
MAX_SAMPLES = 5000
|
| 17 |
+
LANGUAGE_MODELS = prepare_popular_models('base') + prepare_popular_models('large')
|
| 18 |
+
|
| 19 |
+
# Add a tiny model for demonstration on CPU
|
| 20 |
+
LANGUAGE_MODELS = ['prajjwal1/bert-tiny'] + list(dict.fromkeys(LANGUAGE_MODELS))
|
| 21 |
+
LANGUAGE_MODELS.insert(LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased")
|
| 22 |
+
|
| 23 |
+
# Preselect some small models
|
| 24 |
+
DEFAULT_MODELS = [
|
| 25 |
+
"prajjwal1/bert-tiny", "google/electra-small-discriminator",
|
| 26 |
+
"distilbert-base-cased", "sentence-transformers/all-MiniLM-L12-v2"
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
with gr.Blocks(css=CSS, theme=THEME) as demo:
|
| 31 |
+
|
| 32 |
+
########## STEP 1: Load the Dataset ##########
|
| 33 |
+
|
| 34 |
+
gr.Markdown(HEADLINE)
|
| 35 |
+
|
| 36 |
+
gr.Markdown("## Step 1: Load a Dataset")
|
| 37 |
+
with gr.Group():
|
| 38 |
+
dataset = gr.State(None)
|
| 39 |
+
|
| 40 |
+
dataset_name = gr.Textbox(
|
| 41 |
+
label="Enter the name of your dataset",
|
| 42 |
+
placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
|
| 43 |
+
max_lines=1,
|
| 44 |
+
)
|
| 45 |
+
select_dataset_button = gr.Button(
|
| 46 |
+
value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Activate the "Load dataset" button if dataset was found
|
| 50 |
+
dataset_name.change(
|
| 51 |
+
check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
gr.Markdown(
|
| 55 |
+
"*The number of samples that can be used in this demo is limited to save resources. "
|
| 56 |
+
"To run an estimate on the full dataset, check out the "
|
| 57 |
+
"[library](https://github.com/flairNLP/transformer-ranker).*"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
########## Step 1.1 Dataset preprocessing ##########
|
| 61 |
+
|
| 62 |
+
with gr.Accordion("Dataset settings", open=False) as dataset_config:
|
| 63 |
+
with gr.Row() as dataset_details:
|
| 64 |
+
dataset_name_label = gr.Label("", label="Dataset Name")
|
| 65 |
+
num_samples = gr.State(0)
|
| 66 |
+
num_samples_label = gr.Label("", label="Number of Samples")
|
| 67 |
+
num_samples.change(
|
| 68 |
+
lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
with gr.Row():
|
| 72 |
+
text_column = gr.Dropdown("", label="Text Column")
|
| 73 |
+
text_pair_column = gr.Dropdown("", label="Text Pair Column")
|
| 74 |
+
|
| 75 |
+
with gr.Row():
|
| 76 |
+
label_column = gr.Dropdown("", label="Label Column")
|
| 77 |
+
task_category = gr.Dropdown("", label="Task Type")
|
| 78 |
+
|
| 79 |
+
with gr.Group():
|
| 80 |
+
downsample_ratio = gr.State(0.0)
|
| 81 |
+
num_samples_to_use = gr.Slider(
|
| 82 |
+
20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
|
| 83 |
+
)
|
| 84 |
+
downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
|
| 85 |
+
downsample_ratio.change(
|
| 86 |
+
lambda x: f"{x:.1%}",
|
| 87 |
+
inputs=[downsample_ratio],
|
| 88 |
+
outputs=[downsample_ratio_label],
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
num_samples_to_use.change(
|
| 92 |
+
compute_ratio,
|
| 93 |
+
inputs=[num_samples_to_use, num_samples],
|
| 94 |
+
outputs=downsample_ratio,
|
| 95 |
+
)
|
| 96 |
+
num_samples.change(
|
| 97 |
+
compute_ratio,
|
| 98 |
+
inputs=[num_samples_to_use, num_samples],
|
| 99 |
+
outputs=downsample_ratio,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Download the dataset and show details
|
| 103 |
+
def select_dataset(dataset_name):
    """Download the named dataset and build the updates for the settings panel.

    Args:
        dataset_name: Dataset identifier typed into the textbox.

    Returns:
        A tuple with, in order: an update that relabels and disables the
        load button, an opened settings accordion, the dataset name (for the
        name label), the loaded dataset object, and then every value
        returned by ``get_dataset_info`` for that dataset.

    Raises:
        gr.Error: If loading or inspecting the dataset raises ``ValueError``.
            Raising (rather than only warning) stops the handler before the
            return statement, which would otherwise reference values that
            were never assigned.
    """
    try:
        # SECURITY NOTE(review): trust_remote_code=True allows a dataset's
        # loading script to execute on this host, and dataset_name comes
        # straight from user input. Confirm this is acceptable for the
        # deployment before keeping it enabled.
        loaded_dataset = load_dataset(dataset_name, trust_remote_code=True)
        dataset_info = get_dataset_info(loaded_dataset)
    except ValueError as err:
        raise gr.Error(
            "Dataset collections are not supported. Please use a single dataset."
        ) from err

    return (
        gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
        gr.Accordion(open=True),
        dataset_name,
        loaded_dataset,
        *dataset_info,
    )
|
| 117 |
+
|
| 118 |
+
select_dataset_button.click(
|
| 119 |
+
select_dataset,
|
| 120 |
+
inputs=[dataset_name],
|
| 121 |
+
outputs=[
|
| 122 |
+
select_dataset_button,
|
| 123 |
+
dataset_config,
|
| 124 |
+
dataset_name_label,
|
| 125 |
+
dataset,
|
| 126 |
+
task_category,
|
| 127 |
+
text_column,
|
| 128 |
+
text_pair_column,
|
| 129 |
+
label_column,
|
| 130 |
+
num_samples,
|
| 131 |
+
],
|
| 132 |
+
scroll_to_output=True,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
########## STEP 2 ##########
|
| 136 |
+
|
| 137 |
+
gr.Markdown("## Step 2: Select a List of Language Models")
|
| 138 |
+
with gr.Group():
|
| 139 |
+
model_options = [
|
| 140 |
+
(model_handle.split("/")[-1], model_handle)
|
| 141 |
+
for model_handle in LANGUAGE_MODELS
|
| 142 |
+
]
|
| 143 |
+
models = gr.CheckboxGroup(
|
| 144 |
+
choices=model_options, label="Select Models", value=DEFAULT_MODELS
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
########## STEP 3: Run Language Model Ranking ##########
|
| 148 |
+
|
| 149 |
+
gr.Markdown("## Step 3: Rank LMs")
|
| 150 |
+
|
| 151 |
+
with gr.Group():
|
| 152 |
+
with gr.Accordion("Advanced settings", open=False):
|
| 153 |
+
with gr.Row():
|
| 154 |
+
estimator = gr.Dropdown(
|
| 155 |
+
choices=["hscore", "logme", "knn"],
|
| 156 |
+
label="Transferability metric",
|
| 157 |
+
value="hscore",
|
| 158 |
+
)
|
| 159 |
+
layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
|
| 160 |
+
layer_pooling = gr.Dropdown(
|
| 161 |
+
choices=["lastlayer", "layermean", "bestlayer"],
|
| 162 |
+
label="Layer pooling",
|
| 163 |
+
value="layermean",
|
| 164 |
+
)
|
| 165 |
+
submit_button = gr.Button("Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT)
|
| 166 |
+
|
| 167 |
+
# Make button active if the dataset is loaded
|
| 168 |
+
dataset.change(
|
| 169 |
+
check_dataset_is_loaded,
|
| 170 |
+
inputs=[dataset, text_column, label_column, task_category],
|
| 171 |
+
outputs=submit_button
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
label_column.change(
|
| 175 |
+
check_dataset_is_loaded,
|
| 176 |
+
inputs=[dataset, text_column, label_column, task_category],
|
| 177 |
+
outputs=submit_button
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
text_column.change(
|
| 181 |
+
check_dataset_is_loaded,
|
| 182 |
+
inputs=[dataset, text_column, label_column, task_category],
|
| 183 |
+
outputs=submit_button
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def rank_models(
    dataset,
    downsample_ratio,
    selected_models,
    layer_pooling,
    estimator,
    text_column,
    text_pair_column,
    label_column,
    task_category,
    progress=gr.Progress(),
):
    """Rank the selected language models on the loaded dataset.

    Args:
        dataset: Loaded dataset handed to ``TransformerRanker``.
        downsample_ratio: Passed as ``dataset_downsample``.
        selected_models: Model handles to rank.
        layer_pooling: Passed to ``ranker.run`` as ``layer_aggregator``.
        estimator: Passed to ``ranker.run`` as ``estimator``.
        text_column: Name of the text column; ``"-"`` means unset.
        text_pair_column: Name of the optional second text column; ``"-"``
            is translated to ``None``.
        label_column: Name of the label column; ``"-"`` means unset.
        task_category: Task type; ``"-"`` means unset.
        progress: Gradio progress object used for status updates.

    Returns:
        A list of ``(rank, model, score)`` tuples ordered by descending
        score, with ranks starting at 1.

    Raises:
        gr.Error: If a required setting is missing, no model is selected, or
            the ranking itself fails.
    """
    if text_column == "-":
        raise gr.Error("Text column is not set.")

    if label_column == "-":
        raise gr.Error("Label column is not set.")

    if task_category == "-":
        raise gr.Error(
            "Task category is not set. The dataset must support classification or regression tasks."
        )

    # Nothing to rank (and nothing for the progress tracker to report on)
    # when the selection is empty.
    if not selected_models:
        raise gr.Error("Select at least one language model.")

    if text_pair_column == "-":
        text_pair_column = None

    progress(0.0, "Starting")

    with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
        try:
            ranker = TransformerRanker(
                dataset,
                dataset_downsample=downsample_ratio,
                text_column=text_column,
                text_pair_column=text_pair_column,
                label_column=label_column,
                task_category=task_category,
            )

            results = ranker.run(
                models=selected_models,
                layer_aggregator=layer_pooling,
                estimator=estimator,
                batch_size=64,
                tracker=tracker,
            )

            # NOTE(review): reads the private ``_results`` attribute, which is
            # presumably a mapping of model name -> score; confirm against
            # the library's Result class.
            sorted_results = sorted(
                results._results.items(), key=lambda item: item[1], reverse=True
            )
            return [
                (rank, model, score)
                for rank, (model, score) in enumerate(sorted_results, start=1)
            ]
        except Exception as err:
            # gr.Error is an exception class: it must be raised to reach the
            # UI. Chaining keeps the original cause in the server traceback.
            raise gr.Error("The dataset is not supported.") from err
|
| 242 |
+
|
| 243 |
+
gr.Markdown("## Results")
|
| 244 |
+
ranking_results = gr.Dataframe(
|
| 245 |
+
headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
submit_button.click(
|
| 249 |
+
rank_models,
|
| 250 |
+
inputs=[
|
| 251 |
+
dataset,
|
| 252 |
+
downsample_ratio,
|
| 253 |
+
models,
|
| 254 |
+
layer_pooling,
|
| 255 |
+
estimator,
|
| 256 |
+
text_column,
|
| 257 |
+
text_pair_column,
|
| 258 |
+
label_column,
|
| 259 |
+
task_category,
|
| 260 |
+
],
|
| 261 |
+
outputs=ranking_results,
|
| 262 |
+
scroll_to_output=True,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
gr.Markdown(
|
| 266 |
+
"*The results are ranked by their transferability score, with the most suitable model listed first. "
|
| 267 |
+
"This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
gr.Markdown(FOOTER)
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
demo.queue(default_concurrency_limit=3)
|
| 274 |
+
demo.launch(max_threads=6)
|
utils.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from datasets import concatenate_datasets
|
| 3 |
+
from huggingface_hub import HfApi
|
| 4 |
+
from huggingface_hub.errors import HFValidationError
|
| 5 |
+
from requests.exceptions import HTTPError
|
| 6 |
+
from transformer_ranker import Result
|
| 7 |
+
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
|
| 8 |
+
from transformer_ranker.embedder import Embedder
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
DISABLED_BUTTON_VARIANT = "huggingface"
|
| 12 |
+
ENABLED_BUTTON_VARIANT = "primary"
|
| 13 |
+
|
| 14 |
+
HEADLINE = """
|
| 15 |
+
<h1 align="center">TransformerRanker</h1>
|
| 16 |
+
<p align="center" style="max-width: 560px; margin: auto;">
|
| 17 |
+
A very simple library that helps you find the best-suited language model for your NLP task.
|
| 18 |
+
All you need to do is to select a dataset and a list of pre-trained language models (LMs) from the 🤗 HuggingFace Hub.
|
| 19 |
+
TransformerRanker will quickly estimate which of these LMs will perform best on the given dataset!
|
| 20 |
+
</p>
|
| 21 |
+
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
|
| 22 |
+
<a href="https://github.com/flairNLP/transformer-ranker">
|
| 23 |
+
<img src="https://img.shields.io/github/stars/flairNLP/transformer-ranker?style=social&label=Repository" alt="GitHub Badge">
|
| 24 |
+
</a>
|
| 25 |
+
<a href="https://pypi.org/project/transformer-ranker/">
|
| 26 |
+
<img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="Package Badge">
|
| 27 |
+
</a>
|
| 28 |
+
<a href="https://github.com/flairNLP/transformer-ranker/blob/main/examples/01-walkthrough.md">
|
| 29 |
+
<img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="Tutorials Badge">
|
| 30 |
+
</a>
|
| 31 |
+
<img src="https://img.shields.io/badge/license-MIT-green?style=flat" alt="License: MIT">
|
| 32 |
+
</p>
|
| 33 |
+
<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
FOOTER = """
|
| 37 |
+
**Note:** This demonstration currently runs on a CPU and is suited for smaller models only.
|
| 38 |
+
**Developers:** [@plonerma](https://huggingface.co/plonerma) and [@lukasgarbas](https://huggingface.co/lukasgarbas).
|
| 39 |
+
For feedback, suggestions, or contributions, reach out via GitHub or leave a message in the [discussions](https://huggingface.co/spaces/lukasgarbas/transformer-ranker/discussions).
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
CSS = """
|
| 43 |
+
.gradio-container{max-width: 800px !important}
|
| 44 |
+
a {color: #ff9d00;}
|
| 45 |
+
@media (prefers-color-scheme: dark) { a {color: #be185d;} }
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
hf_api = HfApi()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def check_dataset_exists(dataset_name):
    """Enable the load button only when the dataset can be found on the Hub.

    Args:
        dataset_name: Dataset identifier currently typed into the textbox.

    Returns:
        A ``gr.update`` for the load button: enabled with the primary variant
        when ``hf_api.dataset_info`` succeeds, disabled otherwise. The caption
        is reset to "Load dataset" on both paths so that a caption left over
        from an earlier load does not stay on a re-enabled button.
    """
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        return gr.update(
            value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

    return gr.update(
        value="Load dataset", interactive=True, variant=ENABLED_BUTTON_VARIANT
    )
|
| 60 |
+
|
| 61 |
+
def check_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Return a button update that is enabled only when ranking can start.

    Ranking can start when ``dataset`` is truthy and none of the text column,
    label column or task category still holds the ``"-"`` placeholder.
    """
    ready = bool(dataset) and all(
        choice != "-" for choice in (text_column, label_column, task_category)
    )
    if ready:
        return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
    return gr.update(interactive=False, variant=DISABLED_BUTTON_VARIANT)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_dataset_info(dataset):
    """Collect the values shown in the dataset settings panel.

    All splits of ``dataset`` are concatenated, then a ``DatasetCleaner`` is
    asked for the text column, the label column and the task category. Each
    lookup that raises ``ValueError`` produces a UI warning and leaves the
    ``"-"`` placeholder in place.

    Returns:
        A 5-tuple: dropdown updates for the task category, text column, text
        pair column and label column, followed by the total sample count.
    """
    merged = concatenate_datasets(list(dataset.values()))
    cleaner = DatasetCleaner()
    columns = merged.column_names

    text_column = "-"
    try:
        text_column = cleaner._find_column(merged, "text column")
    except ValueError:
        gr.Warning("Text column can not be found. Select it in the dataset settings.")

    label_column = "-"
    try:
        label_column = cleaner._find_column(merged, "label column")
    except ValueError:
        gr.Warning("Label column can not be found. Select it in the dataset settings.")

    # The task category is only looked up once a label column is known.
    task_category = "-"
    if label_column != "-":
        try:
            task_category = cleaner._find_task_category(merged, label_column)
        except ValueError:
            gr.Warning(
                "Task category could not be determined. The dataset must support classification or regression tasks.",
            )

    task_update = gr.update(
        value=task_category,
        choices=[str(t) for t in TaskCategory],
        interactive=True,
    )
    text_update = gr.update(value=text_column, choices=columns, interactive=True)
    # The pair column is optional, so "-" is offered as an explicit choice.
    text_pair_update = gr.update(value="-", choices=["-", *columns], interactive=True)
    label_update = gr.update(value=label_column, choices=columns, interactive=True)

    return (task_update, text_update, text_pair_update, label_update, len(merged))
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of the dataset that should be used.

    Args:
        num_samples_to_use: Number of samples requested.
        num_samples: Total number of samples available.

    Returns:
        ``num_samples_to_use / num_samples`` capped at 1.0, or 0.0 when
        ``num_samples`` is not positive. The cap keeps the ratio meaningful
        when more samples are requested than the dataset contains.
    """
    if num_samples > 0:
        return min(num_samples_to_use / num_samples, 1.0)
    return 0.0
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def ensure_one_lm_selected(checkbox_values, previous_values):
    """Return the new selection, or the previous one if the new one has no truthy entry."""
    return checkbox_values if any(checkbox_values) else previous_values
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Monkey-patch Embedder so an optional tracker object receives callbacks.
# References to the unpatched methods are captured before any replacement.
_old_embed = Embedder.embed
_old_embed_batch = Embedder.embed_batch
_old_init = Embedder.__init__


def _new_init(self, *args, tracker=None, **kwargs):
    """Run the original initializer, then store the optional tracker."""
    _old_init(self, *args, **kwargs)
    self.tracker = tracker


def _new_embed(self, sentences, batch_size: int = 32, **kwargs):
    """Tell the tracker how many batches are coming, then embed as before."""
    if self.tracker is not None:
        num_batches = math.ceil(len(sentences) / batch_size)
        self.tracker.update_num_batches(num_batches)

    return _old_embed(self, sentences, batch_size=batch_size, **kwargs)


def _new_embed_batch(self, *args, **kwargs):
    """Embed one batch as before, then tell the tracker it has finished."""
    result = _old_embed_batch(self, *args, **kwargs)
    if self.tracker is not None:
        self.tracker.update_batch_complete()
    return result


Embedder.__init__ = _new_init
Embedder.embed = _new_embed
Embedder.embed_batch = _new_embed_batch
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class EmbeddingProgressTracker:
    """Context manager that turns per-model batch callbacks into progress updates.

    ``update_num_batches`` is called when a new model starts embedding and
    ``update_batch_complete`` after each batch; both refresh the progress bar
    with a fraction that combines the model index and the batch position.
    """

    def __init__(self, *, progress, model_names):
        """Store the progress callable and the names of the models to track.

        Args:
            progress: Callable accepting ``progress`` and ``desc``.
            model_names: Sequence of model names, in processing order.
        """
        self.model_names = model_names
        self.progress_bar = progress
        # Counters are initialised here as well as in __enter__ so the
        # tracker is usable even when it is not entered as a context manager.
        self._reset()

    def _reset(self):
        """Reset the per-run counters to their initial state."""
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    @property
    def total(self):
        """Number of models being tracked."""
        return len(self.model_names)

    def __enter__(self):
        # NOTE(review): this replaces the progress object passed to __init__
        # with a fresh gr.Progress; kept as in the original, but confirm
        # that discarding the injected object is intended.
        self.progress_bar = gr.Progress(track_tqdm=False)
        self._reset()
        return self

    def __exit__(self, typ, value, tb):
        """Report completion or failure; exceptions are never suppressed."""
        if typ is None:
            self.progress_bar(1.0, desc="Done")
        else:
            self.progress_bar(1.0, desc="Error")

        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        """Advance to the next model and record how many batches it has."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch for the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the current overall progress and description to the bar."""
        total_models = self.total
        if total_models == 0:
            # No models to report on; avoids dividing by zero below.
            return

        # Keep the index inside the list even if more callbacks arrive than
        # there are model names.
        i = min(max(self.current_model, 0), total_models - 1)

        description = f"Running {self.model_names[i]} ({i + 1} / {total_models})"

        progress = i / total_models
        # batches_total may be None (not started) or 0 (empty input); in both
        # cases there is no within-model fraction to add.
        if self.batches_total:
            progress += (self.batches_complete / self.batches_total) / total_models

        self.progress_bar(progress=min(progress, 1.0), desc=description)
|
| 205 |
+
|