Spaces:
Running
Running
Commit
·
f80827c
1
Parent(s):
b8c75ba
new build
Browse files
- README.md +25 -6
- demo/__init__.py +0 -0
- demo/app.py +40 -61
- demo/utils.py +14 -30
- requirements.txt +0 -5
- runtime.txt +0 -1
README.md
CHANGED
@@ -1,14 +1,33 @@
|
|
1 |
---
|
2 |
title: TransformerRanker
|
3 |
-
emoji:
|
4 |
colorFrom: yellow
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.44.0
|
8 |
-
app_file:
|
9 |
-
pinned:
|
10 |
license: mit
|
11 |
short_description: Efficient LM Ranking for Downstream Tasks
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: TransformerRanker
|
3 |
+
emoji: ⚡️
|
4 |
colorFrom: yellow
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.44.0
|
8 |
+
app_file: launch.py
|
9 |
+
pinned: true
|
10 |
license: mit
|
11 |
short_description: Efficient LM Ranking for Downstream Tasks
|
12 |
+
tags:
|
13 |
+
- ranking
|
14 |
+
- research
|
15 |
+
- evaluation
|
16 |
+
- llm
|
17 |
+
- demo
|
18 |
+
---
|
19 |
+
|
20 |
+
## Run Demo
|
21 |
+
|
22 |
+
Install in editable mode:
|
23 |
+
|
24 |
+
``` bash
|
25 |
+
pip install -e .
|
26 |
+
```
|
27 |
+
|
28 |
+
Launch it:
|
29 |
+
|
30 |
+
``` bash
|
31 |
+
python launch.py
|
32 |
+
```
|
33 |
|
|
demo/__init__.py
ADDED
File without changes
|
demo/app.py
CHANGED
@@ -2,11 +2,17 @@ import gradio as gr
|
|
2 |
from datasets import disable_caching, load_dataset
|
3 |
from transformer_ranker import TransformerRanker
|
4 |
|
5 |
-
from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
|
6 |
from demo.utils import (
|
7 |
-
BANNER,
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
)
|
11 |
|
12 |
|
@@ -18,31 +24,33 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
18 |
|
19 |
##### 1. Load from datasets #####
|
20 |
|
21 |
-
gr.Markdown("## Load
|
22 |
|
23 |
gr.Markdown(
|
24 |
-
"
|
25 |
-
"This defines your downstream task."
|
26 |
)
|
27 |
|
28 |
with gr.Group():
|
29 |
dataset = gr.State(None)
|
30 |
|
31 |
dataset_id = gr.Textbox(
|
32 |
-
label="Dataset
|
33 |
placeholder="try: trec, conll2003, ag_news",
|
34 |
max_lines=1,
|
35 |
)
|
36 |
|
37 |
-
load_dataset_button = gr.Button(
|
|
|
|
|
|
|
|
|
38 |
|
39 |
# enable loading if dataset exists on hub
|
40 |
dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
|
41 |
|
42 |
gr.Markdown(
|
43 |
-
"
|
44 |
-
"
|
45 |
-
"or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
|
46 |
)
|
47 |
|
48 |
##### data preprocessing #####
|
@@ -52,9 +60,7 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
52 |
dataset_id_label = gr.Label("", label="Dataset")
|
53 |
num_samples = gr.State(0)
|
54 |
num_samples_label = gr.Label("", label="Dataset size")
|
55 |
-
num_samples.change(
|
56 |
-
lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
|
57 |
-
)
|
58 |
|
59 |
with gr.Row():
|
60 |
text_column = gr.Dropdown("", label="Text Column")
|
@@ -66,9 +72,7 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
66 |
|
67 |
with gr.Group():
|
68 |
downsample_ratio = gr.State(0.0)
|
69 |
-
sampling_rate = gr.Slider(
|
70 |
-
20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
|
71 |
-
)
|
72 |
downsample_ratio_label = gr.Label("", label="Sampling rate")
|
73 |
downsample_ratio.change(
|
74 |
lambda x: f"{x:.1%}",
|
@@ -87,20 +91,14 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
87 |
outputs=downsample_ratio,
|
88 |
)
|
89 |
|
90 |
-
# load and show details
|
91 |
def load_hf_dataset(dataset_id):
|
92 |
try:
|
93 |
dataset = load_dataset(dataset_id, trust_remote_code=True)
|
94 |
dataset_details = preprocess_dataset(dataset)
|
95 |
except ValueError as e:
|
96 |
-
gr.Warning("
|
97 |
|
98 |
-
return (
|
99 |
-
gr.update(value="Loaded"),
|
100 |
-
dataset_id,
|
101 |
-
dataset,
|
102 |
-
*dataset_details
|
103 |
-
)
|
104 |
|
105 |
load_dataset_button.click(
|
106 |
load_hf_dataset,
|
@@ -120,29 +118,24 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
120 |
|
121 |
########## 2. Select LMs ##########
|
122 |
|
123 |
-
gr.Markdown("## Select Language Models")
|
124 |
|
125 |
gr.Markdown(
|
126 |
-
"Add two or more pretrained models
|
127 |
-
"
|
128 |
)
|
129 |
|
130 |
with gr.Group():
|
131 |
-
model_options = [
|
132 |
-
|
133 |
-
for model_handle in ALL_LMS
|
134 |
-
]
|
135 |
-
models = gr.CheckboxGroup(
|
136 |
-
choices=model_options, label="Model List", value=PRESELECTED_LMS
|
137 |
-
)
|
138 |
|
139 |
########## 3. Run ranking ##########
|
140 |
|
141 |
-
gr.Markdown("## Rank
|
142 |
|
143 |
gr.Markdown(
|
144 |
-
"Rank models by transferability to your
|
145 |
-
"
|
146 |
)
|
147 |
|
148 |
with gr.Group():
|
@@ -195,15 +188,13 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
195 |
):
|
196 |
|
197 |
if text_column == UNSET:
|
198 |
-
raise gr.Error("Text column is
|
199 |
|
200 |
if label_column == UNSET:
|
201 |
-
raise gr.Error("Label column is
|
202 |
|
203 |
if task_category == UNSET:
|
204 |
-
raise gr.Error(
|
205 |
-
"Task category not set. Dataset must support classification or regression."
|
206 |
-
)
|
207 |
|
208 |
if text_pair_column == UNSET:
|
209 |
text_pair_column = None
|
@@ -229,23 +220,19 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
229 |
tracker=tracker,
|
230 |
)
|
231 |
|
232 |
-
sorted_results = sorted(
|
233 |
-
|
234 |
-
)
|
235 |
-
return [
|
236 |
-
(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
|
237 |
-
]
|
238 |
except Exception as e:
|
239 |
-
print(e)
|
240 |
gr.Warning(f"Ranking issue: {e}")
|
241 |
return []
|
242 |
|
243 |
-
gr.Markdown("
|
244 |
|
245 |
ranking_results = gr.Dataframe(
|
246 |
headers=["Rank", "Model", "Score"],
|
247 |
datatype=["number", "str", "number"],
|
248 |
-
value=[["-", "-", "-"]]
|
|
|
249 |
)
|
250 |
|
251 |
submit_button.click(
|
@@ -266,11 +253,3 @@ with gr.Blocks(css=CSS, theme=None) as demo:
|
|
266 |
)
|
267 |
|
268 |
gr.Markdown(FOOTER)
|
269 |
-
|
270 |
-
if __name__ == "__main__":
|
271 |
-
|
272 |
-
# run up to 3 requests at once
|
273 |
-
demo.queue(default_concurrency_limit=3)
|
274 |
-
|
275 |
-
# run with 6 workers
|
276 |
-
demo.launch(max_threads=6)
|
|
|
2 |
from datasets import disable_caching, load_dataset
|
3 |
from transformer_ranker import TransformerRanker
|
4 |
|
5 |
+
from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
|
6 |
from demo.utils import (
|
7 |
+
BANNER,
|
8 |
+
FOOTER,
|
9 |
+
CSS,
|
10 |
+
UNSET,
|
11 |
+
EmbeddingProgressTracker,
|
12 |
+
compute_ratio,
|
13 |
+
validate_dataset,
|
14 |
+
preprocess_dataset,
|
15 |
+
ensure_dataset_is_loaded,
|
16 |
)
|
17 |
|
18 |
|
|
|
24 |
|
25 |
##### 1. Load from datasets #####
|
26 |
|
27 |
+
gr.Markdown("## 📚 Load Data")
|
28 |
|
29 |
gr.Markdown(
|
30 |
+
"Pick a dataset from the Hugging Face Hub (e.g. `trec`). This defines your downstream task."
|
|
|
31 |
)
|
32 |
|
33 |
with gr.Group():
|
34 |
dataset = gr.State(None)
|
35 |
|
36 |
dataset_id = gr.Textbox(
|
37 |
+
label="Dataset identifier",
|
38 |
placeholder="try: trec, conll2003, ag_news",
|
39 |
max_lines=1,
|
40 |
)
|
41 |
|
42 |
+
load_dataset_button = gr.Button(
|
43 |
+
value="Load data",
|
44 |
+
variant="primary",
|
45 |
+
interactive=True,
|
46 |
+
)
|
47 |
|
48 |
# enable loading if dataset exists on hub
|
49 |
dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
|
50 |
|
51 |
gr.Markdown(
|
52 |
+
"⚡️ Speed mode on: tweak the downsampling ratio in *Dataset Setup* for quicker runs. "
|
53 |
+
"Unlock the full data via [framework](https://github.com/flairNLP/transformer-ranker)."
|
|
|
54 |
)
|
55 |
|
56 |
##### data preprocessing #####
|
|
|
60 |
dataset_id_label = gr.Label("", label="Dataset")
|
61 |
num_samples = gr.State(0)
|
62 |
num_samples_label = gr.Label("", label="Dataset size")
|
63 |
+
num_samples.change(lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label])
|
|
|
|
|
64 |
|
65 |
with gr.Row():
|
66 |
text_column = gr.Dropdown("", label="Text Column")
|
|
|
72 |
|
73 |
with gr.Group():
|
74 |
downsample_ratio = gr.State(0.0)
|
75 |
+
sampling_rate = gr.Slider(20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1)
|
|
|
|
|
76 |
downsample_ratio_label = gr.Label("", label="Sampling rate")
|
77 |
downsample_ratio.change(
|
78 |
lambda x: f"{x:.1%}",
|
|
|
91 |
outputs=downsample_ratio,
|
92 |
)
|
93 |
|
|
|
94 |
def load_hf_dataset(dataset_id):
|
95 |
try:
|
96 |
dataset = load_dataset(dataset_id, trust_remote_code=True)
|
97 |
dataset_details = preprocess_dataset(dataset)
|
98 |
except ValueError as e:
|
99 |
+
gr.Warning(f"Watch out — single datasets only. Cannot load dataset: {e}")
|
100 |
|
101 |
+
return (gr.update(value="Loaded"), dataset_id, dataset, *dataset_details)
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
load_dataset_button.click(
|
104 |
load_hf_dataset,
|
|
|
118 |
|
119 |
########## 2. Select LMs ##########
|
120 |
|
121 |
+
gr.Markdown("## 🧠 Select Language Models")
|
122 |
|
123 |
gr.Markdown(
|
124 |
+
"Add two or more pretrained models to compare. "
|
125 |
+
"Stick to smaller models here since the demo runs on CPU."
|
126 |
)
|
127 |
|
128 |
with gr.Group():
|
129 |
+
model_options = [(model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS]
|
130 |
+
models = gr.CheckboxGroup(choices=model_options, label="Model List", value=PRESELECTED_LMS)
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
########## 3. Run ranking ##########
|
133 |
|
134 |
+
gr.Markdown("## 🏆 Rank Models")
|
135 |
|
136 |
gr.Markdown(
|
137 |
+
"Rank models by transferability to your task. "
|
138 |
+
"More control? Tweak transferability metric and layer aggregation in *Settings*."
|
139 |
)
|
140 |
|
141 |
with gr.Group():
|
|
|
188 |
):
|
189 |
|
190 |
if text_column == UNSET:
|
191 |
+
raise gr.Error("Text column is required.")
|
192 |
|
193 |
if label_column == UNSET:
|
194 |
+
raise gr.Error("Label column is required.")
|
195 |
|
196 |
if task_category == UNSET:
|
197 |
+
raise gr.Error("Task category is required.")
|
|
|
|
|
198 |
|
199 |
if text_pair_column == UNSET:
|
200 |
text_pair_column = None
|
|
|
220 |
tracker=tracker,
|
221 |
)
|
222 |
|
223 |
+
sorted_results = sorted(results._results.items(), key=lambda item: item[1], reverse=True)
|
224 |
+
return [(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)]
|
|
|
|
|
|
|
|
|
225 |
except Exception as e:
|
|
|
226 |
gr.Warning(f"Ranking issue: {e}")
|
227 |
return []
|
228 |
|
229 |
+
gr.Markdown("**Leaderboard:** higher score → better downstream performance.")
|
230 |
|
231 |
ranking_results = gr.Dataframe(
|
232 |
headers=["Rank", "Model", "Score"],
|
233 |
datatype=["number", "str", "number"],
|
234 |
+
value=[["-", "-", "-"]],
|
235 |
+
interactive=False
|
236 |
)
|
237 |
|
238 |
submit_button.click(
|
|
|
253 |
)
|
254 |
|
255 |
gr.Markdown(FOOTER)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo/utils.py
CHANGED
@@ -9,11 +9,11 @@ from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
|
|
9 |
from transformer_ranker.embedder import Embedder
|
10 |
|
11 |
BANNER = """
|
12 |
-
<h1 align="center"
|
13 |
|
14 |
<p align="center" style="max-width: 560px; margin: auto;">
|
15 |
Find the best language model for your downstream task.
|
16 |
-
Load a dataset,
|
17 |
</p>
|
18 |
|
19 |
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
|
@@ -35,8 +35,9 @@ BANNER = """
|
|
35 |
"""
|
36 |
|
37 |
FOOTER = """
|
38 |
-
**Note:** CPU-only
|
39 |
-
**
|
|
|
40 |
"""
|
41 |
|
42 |
CSS = """
|
@@ -53,18 +54,18 @@ preprocessing = DatasetCleaner()
|
|
53 |
|
54 |
|
55 |
def validate_dataset(dataset_name):
|
56 |
-
"""
|
57 |
try:
|
58 |
-
hf_api.dataset_info(dataset_name)
|
59 |
return gr.update(interactive=True)
|
60 |
-
|
61 |
except (HTTPError, HFValidationError):
|
62 |
return gr.update(value="Load data", interactive=False)
|
63 |
|
64 |
|
65 |
def preprocess_dataset(dataset):
|
66 |
-
"""
|
67 |
data = concatenate_datasets(list(dataset.values()))
|
|
|
68 |
|
69 |
try:
|
70 |
text_column = preprocessing._find_column(data, "text column")
|
@@ -83,19 +84,11 @@ def preprocess_dataset(dataset):
|
|
83 |
try:
|
84 |
task_category = preprocessing._find_task_category(data, label_column)
|
85 |
except ValueError:
|
86 |
-
gr.Warning(
|
87 |
-
|
88 |
-
|
89 |
-
label_column = gr.update(value=label_column, choices=data.column_names, interactive=True)
|
90 |
-
text_pair = gr.update(value=UNSET, choices=[UNSET, *data.column_names], interactive=True)
|
91 |
-
task_category = gr.update(value=task_category, choices=[str(t) for t in TaskCategory], interactive=True)
|
92 |
-
sample_size = len(data)
|
93 |
-
|
94 |
-
return task_category, text_column, text_pair, label_column, sample_size
|
95 |
|
96 |
-
"""
|
97 |
return (
|
98 |
-
text_column,
|
99 |
gr.update(
|
100 |
value=task_category,
|
101 |
choices=[str(t) for t in TaskCategory],
|
@@ -110,9 +103,9 @@ def preprocess_dataset(dataset):
|
|
110 |
gr.update(
|
111 |
value=label_column, choices=data.column_names, interactive=True
|
112 |
),
|
113 |
-
|
114 |
)
|
115 |
-
|
116 |
|
117 |
def compute_ratio(num_samples_to_use, num_samples):
|
118 |
if num_samples > 0:
|
@@ -128,12 +121,6 @@ def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
|
|
128 |
return gr.update(interactive=False)
|
129 |
|
130 |
|
131 |
-
def ensure_one_lm_selected(checkbox_values, previous_values):
|
132 |
-
if not any(checkbox_values):
|
133 |
-
return previous_values
|
134 |
-
return checkbox_values
|
135 |
-
|
136 |
-
|
137 |
# apply monkey patch to enable callbacks
|
138 |
_old_embed = Embedder.embed
|
139 |
|
@@ -185,8 +172,6 @@ class EmbeddingProgressTracker:
|
|
185 |
self.progress_bar(1.0, desc="Done")
|
186 |
else:
|
187 |
self.progress_bar(1.0, desc="Error")
|
188 |
-
|
189 |
-
# Do not suppress any errors
|
190 |
return False
|
191 |
|
192 |
def update_num_batches(self, total):
|
@@ -201,7 +186,6 @@ class EmbeddingProgressTracker:
|
|
201 |
|
202 |
def update_bar(self):
|
203 |
i = self.current_model
|
204 |
-
|
205 |
description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
|
206 |
|
207 |
progress = i / self.total
|
|
|
9 |
from transformer_ranker.embedder import Embedder
|
10 |
|
11 |
BANNER = """
|
12 |
+
<h1 align="center">🚀 TransformerRanker ⚡️</h1>
|
13 |
|
14 |
<p align="center" style="max-width: 560px; margin: auto;">
|
15 |
Find the best language model for your downstream task.
|
16 |
+
Load a dataset, pick models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
|
17 |
</p>
|
18 |
|
19 |
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
|
|
|
35 |
"""
|
36 |
|
37 |
FOOTER = """
|
38 |
+
**Note:** Quick CPU-only demo.
|
39 |
+
**Built by** [@lukasgarbas](https://huggingface.co/lukasgarbas) & [@plonerma](https://huggingface.co/plonerma)
|
40 |
+
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫
|
41 |
"""
|
42 |
|
43 |
CSS = """
|
|
|
54 |
|
55 |
|
56 |
def validate_dataset(dataset_name):
|
57 |
+
"""Check whether the dataset exists on the Hub."""
|
58 |
try:
|
59 |
+
hf_api.dataset_info(dataset_name)
|
60 |
return gr.update(interactive=True)
|
|
|
61 |
except (HTTPError, HFValidationError):
|
62 |
return gr.update(value="Load data", interactive=False)
|
63 |
|
64 |
|
65 |
def preprocess_dataset(dataset):
|
66 |
+
"""Detect text/label columns and task type."""
|
67 |
data = concatenate_datasets(list(dataset.values()))
|
68 |
+
sample_size = len(data)
|
69 |
|
70 |
try:
|
71 |
text_column = preprocessing._find_column(data, "text column")
|
|
|
84 |
try:
|
85 |
task_category = preprocessing._find_task_category(data, label_column)
|
86 |
except ValueError:
|
87 |
+
gr.Warning(
|
88 |
+
"Task category not auto-detected — framework supports classification, regression."
|
89 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
|
|
91 |
return (
|
|
|
92 |
gr.update(
|
93 |
value=task_category,
|
94 |
choices=[str(t) for t in TaskCategory],
|
|
|
103 |
gr.update(
|
104 |
value=label_column, choices=data.column_names, interactive=True
|
105 |
),
|
106 |
+
sample_size,
|
107 |
)
|
108 |
+
|
109 |
|
110 |
def compute_ratio(num_samples_to_use, num_samples):
|
111 |
if num_samples > 0:
|
|
|
121 |
return gr.update(interactive=False)
|
122 |
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
# apply monkey patch to enable callbacks
|
125 |
_old_embed = Embedder.embed
|
126 |
|
|
|
172 |
self.progress_bar(1.0, desc="Done")
|
173 |
else:
|
174 |
self.progress_bar(1.0, desc="Error")
|
|
|
|
|
175 |
return False
|
176 |
|
177 |
def update_num_batches(self, total):
|
|
|
186 |
|
187 |
def update_bar(self):
|
188 |
i = self.current_model
|
|
|
189 |
description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
|
190 |
|
191 |
progress = i / self.total
|
requirements.txt
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
gradio>=5.0
|
2 |
-
transformer-ranker==0.1.2
|
3 |
-
transformers==4.41.0
|
4 |
-
datasets==3.6
|
5 |
-
protobuf
|
|
|
|
|
|
|
|
|
|
|
|
runtime.txt
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
python-3.12
|
|
|
|