Spaces:
Sleeping
Sleeping
Commit
·
f80827c
1
Parent(s):
b8c75ba
new build
Browse files- README.md +25 -6
- demo/__init__.py +0 -0
- demo/app.py +40 -61
- demo/utils.py +14 -30
- requirements.txt +0 -5
- runtime.txt +0 -1
README.md
CHANGED
|
@@ -1,14 +1,33 @@
|
|
| 1 |
---
|
| 2 |
title: TransformerRanker
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: yellow
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.44.0
|
| 8 |
-
app_file:
|
| 9 |
-
pinned:
|
| 10 |
license: mit
|
| 11 |
short_description: Efficient LM Ranking for Downstream Tasks
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: TransformerRanker
|
| 3 |
+
emoji: ⚡️
|
| 4 |
colorFrom: yellow
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.44.0
|
| 8 |
+
app_file: launch.py
|
| 9 |
+
pinned: true
|
| 10 |
license: mit
|
| 11 |
short_description: Efficient LM Ranking for Downstream Tasks
|
| 12 |
+
tags:
|
| 13 |
+
- ranking
|
| 14 |
+
- research
|
| 15 |
+
- evaluation
|
| 16 |
+
- llm
|
| 17 |
+
- demo
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Run Demo
|
| 21 |
+
|
| 22 |
+
Install in editable mode:
|
| 23 |
+
|
| 24 |
+
``` bash
|
| 25 |
+
pip install -e .
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Launch it:
|
| 29 |
+
|
| 30 |
+
``` bash
|
| 31 |
+
python launch.py
|
| 32 |
+
```
|
| 33 |
|
|
|
demo/__init__.py
ADDED
|
File without changes
|
demo/app.py
CHANGED
|
@@ -2,11 +2,17 @@ import gradio as gr
|
|
| 2 |
from datasets import disable_caching, load_dataset
|
| 3 |
from transformer_ranker import TransformerRanker
|
| 4 |
|
| 5 |
-
from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
|
| 6 |
from demo.utils import (
|
| 7 |
-
BANNER,
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
)
|
| 11 |
|
| 12 |
|
|
@@ -18,31 +24,33 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 18 |
|
| 19 |
##### 1. Load from datasets #####
|
| 20 |
|
| 21 |
-
gr.Markdown("## Load
|
| 22 |
|
| 23 |
gr.Markdown(
|
| 24 |
-
"
|
| 25 |
-
"This defines your downstream task."
|
| 26 |
)
|
| 27 |
|
| 28 |
with gr.Group():
|
| 29 |
dataset = gr.State(None)
|
| 30 |
|
| 31 |
dataset_id = gr.Textbox(
|
| 32 |
-
label="Dataset
|
| 33 |
placeholder="try: trec, conll2003, ag_news",
|
| 34 |
max_lines=1,
|
| 35 |
)
|
| 36 |
|
| 37 |
-
load_dataset_button = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# enable loading if dataset exists on hub
|
| 40 |
dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
|
| 41 |
|
| 42 |
gr.Markdown(
|
| 43 |
-
"
|
| 44 |
-
"
|
| 45 |
-
"or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
|
| 46 |
)
|
| 47 |
|
| 48 |
##### data preprocessing #####
|
|
@@ -52,9 +60,7 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 52 |
dataset_id_label = gr.Label("", label="Dataset")
|
| 53 |
num_samples = gr.State(0)
|
| 54 |
num_samples_label = gr.Label("", label="Dataset size")
|
| 55 |
-
num_samples.change(
|
| 56 |
-
lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
|
| 57 |
-
)
|
| 58 |
|
| 59 |
with gr.Row():
|
| 60 |
text_column = gr.Dropdown("", label="Text Column")
|
|
@@ -66,9 +72,7 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 66 |
|
| 67 |
with gr.Group():
|
| 68 |
downsample_ratio = gr.State(0.0)
|
| 69 |
-
sampling_rate = gr.Slider(
|
| 70 |
-
20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
|
| 71 |
-
)
|
| 72 |
downsample_ratio_label = gr.Label("", label="Sampling rate")
|
| 73 |
downsample_ratio.change(
|
| 74 |
lambda x: f"{x:.1%}",
|
|
@@ -87,20 +91,14 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 87 |
outputs=downsample_ratio,
|
| 88 |
)
|
| 89 |
|
| 90 |
-
# load and show details
|
| 91 |
def load_hf_dataset(dataset_id):
|
| 92 |
try:
|
| 93 |
dataset = load_dataset(dataset_id, trust_remote_code=True)
|
| 94 |
dataset_details = preprocess_dataset(dataset)
|
| 95 |
except ValueError as e:
|
| 96 |
-
gr.Warning("
|
| 97 |
|
| 98 |
-
return (
|
| 99 |
-
gr.update(value="Loaded"),
|
| 100 |
-
dataset_id,
|
| 101 |
-
dataset,
|
| 102 |
-
*dataset_details
|
| 103 |
-
)
|
| 104 |
|
| 105 |
load_dataset_button.click(
|
| 106 |
load_hf_dataset,
|
|
@@ -120,29 +118,24 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 120 |
|
| 121 |
########## 2. Select LMs ##########
|
| 122 |
|
| 123 |
-
gr.Markdown("## Select Language Models")
|
| 124 |
|
| 125 |
gr.Markdown(
|
| 126 |
-
"Add two or more pretrained models
|
| 127 |
-
"
|
| 128 |
)
|
| 129 |
|
| 130 |
with gr.Group():
|
| 131 |
-
model_options = [
|
| 132 |
-
|
| 133 |
-
for model_handle in ALL_LMS
|
| 134 |
-
]
|
| 135 |
-
models = gr.CheckboxGroup(
|
| 136 |
-
choices=model_options, label="Model List", value=PRESELECTED_LMS
|
| 137 |
-
)
|
| 138 |
|
| 139 |
########## 3. Run ranking ##########
|
| 140 |
|
| 141 |
-
gr.Markdown("## Rank
|
| 142 |
|
| 143 |
gr.Markdown(
|
| 144 |
-
"Rank models by transferability to your
|
| 145 |
-
"
|
| 146 |
)
|
| 147 |
|
| 148 |
with gr.Group():
|
|
@@ -195,15 +188,13 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 195 |
):
|
| 196 |
|
| 197 |
if text_column == UNSET:
|
| 198 |
-
raise gr.Error("Text column is
|
| 199 |
|
| 200 |
if label_column == UNSET:
|
| 201 |
-
raise gr.Error("Label column is
|
| 202 |
|
| 203 |
if task_category == UNSET:
|
| 204 |
-
raise gr.Error(
|
| 205 |
-
"Task category not set. Dataset must support classification or regression."
|
| 206 |
-
)
|
| 207 |
|
| 208 |
if text_pair_column == UNSET:
|
| 209 |
text_pair_column = None
|
|
@@ -229,23 +220,19 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 229 |
tracker=tracker,
|
| 230 |
)
|
| 231 |
|
| 232 |
-
sorted_results = sorted(
|
| 233 |
-
|
| 234 |
-
)
|
| 235 |
-
return [
|
| 236 |
-
(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
|
| 237 |
-
]
|
| 238 |
except Exception as e:
|
| 239 |
-
print(e)
|
| 240 |
gr.Warning(f"Ranking issue: {e}")
|
| 241 |
return []
|
| 242 |
|
| 243 |
-
gr.Markdown("
|
| 244 |
|
| 245 |
ranking_results = gr.Dataframe(
|
| 246 |
headers=["Rank", "Model", "Score"],
|
| 247 |
datatype=["number", "str", "number"],
|
| 248 |
-
value=[["-", "-", "-"]]
|
|
|
|
| 249 |
)
|
| 250 |
|
| 251 |
submit_button.click(
|
|
@@ -266,11 +253,3 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
| 266 |
)
|
| 267 |
|
| 268 |
gr.Markdown(FOOTER)
|
| 269 |
-
|
| 270 |
-
if __name__ == "__main__":
|
| 271 |
-
|
| 272 |
-
# run up to 3 requests at once
|
| 273 |
-
demo.queue(default_concurrency_limit=3)
|
| 274 |
-
|
| 275 |
-
# run with 6 workers
|
| 276 |
-
demo.launch(max_threads=6)
|
|
|
|
| 2 |
from datasets import disable_caching, load_dataset
|
| 3 |
from transformer_ranker import TransformerRanker
|
| 4 |
|
| 5 |
+
from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
|
| 6 |
from demo.utils import (
|
| 7 |
+
BANNER,
|
| 8 |
+
FOOTER,
|
| 9 |
+
CSS,
|
| 10 |
+
UNSET,
|
| 11 |
+
EmbeddingProgressTracker,
|
| 12 |
+
compute_ratio,
|
| 13 |
+
validate_dataset,
|
| 14 |
+
preprocess_dataset,
|
| 15 |
+
ensure_dataset_is_loaded,
|
| 16 |
)
|
| 17 |
|
| 18 |
|
|
|
|
| 24 |
|
| 25 |
##### 1. Load from datasets #####
|
| 26 |
|
| 27 |
+
gr.Markdown("## 📚 Load Data")
|
| 28 |
|
| 29 |
gr.Markdown(
|
| 30 |
+
"Pick a dataset from the Hugging Face Hub (e.g. `trec`). This defines your downstream task."
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
with gr.Group():
|
| 34 |
dataset = gr.State(None)
|
| 35 |
|
| 36 |
dataset_id = gr.Textbox(
|
| 37 |
+
label="Dataset identifier",
|
| 38 |
placeholder="try: trec, conll2003, ag_news",
|
| 39 |
max_lines=1,
|
| 40 |
)
|
| 41 |
|
| 42 |
+
load_dataset_button = gr.Button(
|
| 43 |
+
value="Load data",
|
| 44 |
+
variant="primary",
|
| 45 |
+
interactive=True,
|
| 46 |
+
)
|
| 47 |
|
| 48 |
# enable loading if dataset exists on hub
|
| 49 |
dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
|
| 50 |
|
| 51 |
gr.Markdown(
|
| 52 |
+
"⚡️ Speed mode on: tweak the downsampling ratio in *Dataset Setup* for quicker runs. "
|
| 53 |
+
"Unlock the full data via [framework](https://github.com/flairNLP/transformer-ranker)."
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
##### data preprocessing #####
|
|
|
|
| 60 |
dataset_id_label = gr.Label("", label="Dataset")
|
| 61 |
num_samples = gr.State(0)
|
| 62 |
num_samples_label = gr.Label("", label="Dataset size")
|
| 63 |
+
num_samples.change(lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label])
|
|
|
|
|
|
|
| 64 |
|
| 65 |
with gr.Row():
|
| 66 |
text_column = gr.Dropdown("", label="Text Column")
|
|
|
|
| 72 |
|
| 73 |
with gr.Group():
|
| 74 |
downsample_ratio = gr.State(0.0)
|
| 75 |
+
sampling_rate = gr.Slider(20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1)
|
|
|
|
|
|
|
| 76 |
downsample_ratio_label = gr.Label("", label="Sampling rate")
|
| 77 |
downsample_ratio.change(
|
| 78 |
lambda x: f"{x:.1%}",
|
|
|
|
| 91 |
outputs=downsample_ratio,
|
| 92 |
)
|
| 93 |
|
|
|
|
| 94 |
def load_hf_dataset(dataset_id):
|
| 95 |
try:
|
| 96 |
dataset = load_dataset(dataset_id, trust_remote_code=True)
|
| 97 |
dataset_details = preprocess_dataset(dataset)
|
| 98 |
except ValueError as e:
|
| 99 |
+
gr.Warning(f"Watch out — single datasets only. Cannot load dataset: {e}")
|
| 100 |
|
| 101 |
+
return (gr.update(value="Loaded"), dataset_id, dataset, *dataset_details)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
load_dataset_button.click(
|
| 104 |
load_hf_dataset,
|
|
|
|
| 118 |
|
| 119 |
########## 2. Select LMs ##########
|
| 120 |
|
| 121 |
+
gr.Markdown("## 🧠 Select Language Models")
|
| 122 |
|
| 123 |
gr.Markdown(
|
| 124 |
+
"Add two or more pretrained models to compare. "
|
| 125 |
+
"Stick to smaller models here since the demo runs on CPU."
|
| 126 |
)
|
| 127 |
|
| 128 |
with gr.Group():
|
| 129 |
+
model_options = [(model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS]
|
| 130 |
+
models = gr.CheckboxGroup(choices=model_options, label="Model List", value=PRESELECTED_LMS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
########## 3. Run ranking ##########
|
| 133 |
|
| 134 |
+
gr.Markdown("## 🏆 Rank Models")
|
| 135 |
|
| 136 |
gr.Markdown(
|
| 137 |
+
"Rank models by transferability to your task. "
|
| 138 |
+
"More control? Tweak transferability metric and layer aggregation in *Settings*."
|
| 139 |
)
|
| 140 |
|
| 141 |
with gr.Group():
|
|
|
|
| 188 |
):
|
| 189 |
|
| 190 |
if text_column == UNSET:
|
| 191 |
+
raise gr.Error("Text column is required.")
|
| 192 |
|
| 193 |
if label_column == UNSET:
|
| 194 |
+
raise gr.Error("Label column is required.")
|
| 195 |
|
| 196 |
if task_category == UNSET:
|
| 197 |
+
raise gr.Error("Task category is required.")
|
|
|
|
|
|
|
| 198 |
|
| 199 |
if text_pair_column == UNSET:
|
| 200 |
text_pair_column = None
|
|
|
|
| 220 |
tracker=tracker,
|
| 221 |
)
|
| 222 |
|
| 223 |
+
sorted_results = sorted(results._results.items(), key=lambda item: item[1], reverse=True)
|
| 224 |
+
return [(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
except Exception as e:
|
|
|
|
| 226 |
gr.Warning(f"Ranking issue: {e}")
|
| 227 |
return []
|
| 228 |
|
| 229 |
+
gr.Markdown("**Leaderboard:** higher score → better downstream performance.")
|
| 230 |
|
| 231 |
ranking_results = gr.Dataframe(
|
| 232 |
headers=["Rank", "Model", "Score"],
|
| 233 |
datatype=["number", "str", "number"],
|
| 234 |
+
value=[["-", "-", "-"]],
|
| 235 |
+
interactive=False
|
| 236 |
)
|
| 237 |
|
| 238 |
submit_button.click(
|
|
|
|
| 253 |
)
|
| 254 |
|
| 255 |
gr.Markdown(FOOTER)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo/utils.py
CHANGED
|
@@ -9,11 +9,11 @@ from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
|
|
| 9 |
from transformer_ranker.embedder import Embedder
|
| 10 |
|
| 11 |
BANNER = """
|
| 12 |
-
<h1 align="center"
|
| 13 |
|
| 14 |
<p align="center" style="max-width: 560px; margin: auto;">
|
| 15 |
Find the best language model for your downstream task.
|
| 16 |
-
Load a dataset,
|
| 17 |
</p>
|
| 18 |
|
| 19 |
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
|
|
@@ -35,8 +35,9 @@ BANNER = """
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
FOOTER = """
|
| 38 |
-
**Note:** CPU-only
|
| 39 |
-
**
|
|
|
|
| 40 |
"""
|
| 41 |
|
| 42 |
CSS = """
|
|
@@ -53,18 +54,18 @@ preprocessing = DatasetCleaner()
|
|
| 53 |
|
| 54 |
|
| 55 |
def validate_dataset(dataset_name):
|
| 56 |
-
"""
|
| 57 |
try:
|
| 58 |
-
hf_api.dataset_info(dataset_name)
|
| 59 |
return gr.update(interactive=True)
|
| 60 |
-
|
| 61 |
except (HTTPError, HFValidationError):
|
| 62 |
return gr.update(value="Load data", interactive=False)
|
| 63 |
|
| 64 |
|
| 65 |
def preprocess_dataset(dataset):
|
| 66 |
-
"""
|
| 67 |
data = concatenate_datasets(list(dataset.values()))
|
|
|
|
| 68 |
|
| 69 |
try:
|
| 70 |
text_column = preprocessing._find_column(data, "text column")
|
|
@@ -83,19 +84,11 @@ def preprocess_dataset(dataset):
|
|
| 83 |
try:
|
| 84 |
task_category = preprocessing._find_task_category(data, label_column)
|
| 85 |
except ValueError:
|
| 86 |
-
gr.Warning(
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
label_column = gr.update(value=label_column, choices=data.column_names, interactive=True)
|
| 90 |
-
text_pair = gr.update(value=UNSET, choices=[UNSET, *data.column_names], interactive=True)
|
| 91 |
-
task_category = gr.update(value=task_category, choices=[str(t) for t in TaskCategory], interactive=True)
|
| 92 |
-
sample_size = len(data)
|
| 93 |
-
|
| 94 |
-
return task_category, text_column, text_pair, label_column, sample_size
|
| 95 |
|
| 96 |
-
"""
|
| 97 |
return (
|
| 98 |
-
text_column,
|
| 99 |
gr.update(
|
| 100 |
value=task_category,
|
| 101 |
choices=[str(t) for t in TaskCategory],
|
|
@@ -110,9 +103,9 @@ def preprocess_dataset(dataset):
|
|
| 110 |
gr.update(
|
| 111 |
value=label_column, choices=data.column_names, interactive=True
|
| 112 |
),
|
| 113 |
-
|
| 114 |
)
|
| 115 |
-
|
| 116 |
|
| 117 |
def compute_ratio(num_samples_to_use, num_samples):
|
| 118 |
if num_samples > 0:
|
|
@@ -128,12 +121,6 @@ def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
|
|
| 128 |
return gr.update(interactive=False)
|
| 129 |
|
| 130 |
|
| 131 |
-
def ensure_one_lm_selected(checkbox_values, previous_values):
|
| 132 |
-
if not any(checkbox_values):
|
| 133 |
-
return previous_values
|
| 134 |
-
return checkbox_values
|
| 135 |
-
|
| 136 |
-
|
| 137 |
# apply monkey patch to enable callbacks
|
| 138 |
_old_embed = Embedder.embed
|
| 139 |
|
|
@@ -185,8 +172,6 @@ class EmbeddingProgressTracker:
|
|
| 185 |
self.progress_bar(1.0, desc="Done")
|
| 186 |
else:
|
| 187 |
self.progress_bar(1.0, desc="Error")
|
| 188 |
-
|
| 189 |
-
# Do not suppress any errors
|
| 190 |
return False
|
| 191 |
|
| 192 |
def update_num_batches(self, total):
|
|
@@ -201,7 +186,6 @@ class EmbeddingProgressTracker:
|
|
| 201 |
|
| 202 |
def update_bar(self):
|
| 203 |
i = self.current_model
|
| 204 |
-
|
| 205 |
description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
|
| 206 |
|
| 207 |
progress = i / self.total
|
|
|
|
| 9 |
from transformer_ranker.embedder import Embedder
|
| 10 |
|
| 11 |
BANNER = """
|
| 12 |
+
<h1 align="center">🚀 TransformerRanker ⚡️</h1>
|
| 13 |
|
| 14 |
<p align="center" style="max-width: 560px; margin: auto;">
|
| 15 |
Find the best language model for your downstream task.
|
| 16 |
+
Load a dataset, pick models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
|
| 17 |
</p>
|
| 18 |
|
| 19 |
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
|
|
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
FOOTER = """
|
| 38 |
+
**Note:** Quick CPU-only demo.
|
| 39 |
+
**Built by** [@lukasgarbas](https://huggingface.co/lukasgarbas) & [@plonerma](https://huggingface.co/plonerma)
|
| 40 |
+
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫
|
| 41 |
"""
|
| 42 |
|
| 43 |
CSS = """
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def validate_dataset(dataset_name):
|
| 57 |
+
"""Quick look dataset existence on Hub."""
|
| 58 |
try:
|
| 59 |
+
hf_api.dataset_info(dataset_name)
|
| 60 |
return gr.update(interactive=True)
|
|
|
|
| 61 |
except (HTTPError, HFValidationError):
|
| 62 |
return gr.update(value="Load data", interactive=False)
|
| 63 |
|
| 64 |
|
| 65 |
def preprocess_dataset(dataset):
|
| 66 |
+
"""Detect text/label columns and task type."""
|
| 67 |
data = concatenate_datasets(list(dataset.values()))
|
| 68 |
+
sample_size = len(data)
|
| 69 |
|
| 70 |
try:
|
| 71 |
text_column = preprocessing._find_column(data, "text column")
|
|
|
|
| 84 |
try:
|
| 85 |
task_category = preprocessing._find_task_category(data, label_column)
|
| 86 |
except ValueError:
|
| 87 |
+
gr.Warning(
|
| 88 |
+
"Task category not auto-detected — framework supports classification, regression."
|
| 89 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
|
|
|
| 91 |
return (
|
|
|
|
| 92 |
gr.update(
|
| 93 |
value=task_category,
|
| 94 |
choices=[str(t) for t in TaskCategory],
|
|
|
|
| 103 |
gr.update(
|
| 104 |
value=label_column, choices=data.column_names, interactive=True
|
| 105 |
),
|
| 106 |
+
sample_size,
|
| 107 |
)
|
| 108 |
+
|
| 109 |
|
| 110 |
def compute_ratio(num_samples_to_use, num_samples):
|
| 111 |
if num_samples > 0:
|
|
|
|
| 121 |
return gr.update(interactive=False)
|
| 122 |
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
# apply monkey patch to enable callbacks
|
| 125 |
_old_embed = Embedder.embed
|
| 126 |
|
|
|
|
| 172 |
self.progress_bar(1.0, desc="Done")
|
| 173 |
else:
|
| 174 |
self.progress_bar(1.0, desc="Error")
|
|
|
|
|
|
|
| 175 |
return False
|
| 176 |
|
| 177 |
def update_num_batches(self, total):
|
|
|
|
| 186 |
|
| 187 |
def update_bar(self):
|
| 188 |
i = self.current_model
|
|
|
|
| 189 |
description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
|
| 190 |
|
| 191 |
progress = i / self.total
|
requirements.txt
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
gradio>=5.0
|
| 2 |
-
transformer-ranker==0.1.2
|
| 3 |
-
transformers==4.41.0
|
| 4 |
-
datasets==3.6
|
| 5 |
-
protobuf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
runtime.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
python-3.12
|
|
|
|
|
|