lukasgarbas committed
Commit f80827c · 1 Parent(s): b8c75ba

Files changed (6):
  1. README.md +25 -6
  2. demo/__init__.py +0 -0
  3. demo/app.py +40 -61
  4. demo/utils.py +14 -30
  5. requirements.txt +0 -5
  6. runtime.txt +0 -1
README.md CHANGED
@@ -1,14 +1,33 @@
 ---
 title: TransformerRanker
-emoji: 🎯
+emoji: ⚡️
 colorFrom: yellow
-colorTo: purple
+colorTo: indigo
 sdk: gradio
 sdk_version: 5.44.0
-app_file: demo/app.py
-pinned: false
+app_file: launch.py
+pinned: true
 license: mit
 short_description: Efficient LM Ranking for Downstream Tasks
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+tags:
+- ranking
+- research
+- evaluation
+- llm
+- demo
+---
+
+## Run Demo
+
+Install in editable mode:
+
+``` bash
+pip install -e .
+```
+
+Launch it:
+
+``` bash
+python launch.py
+```
+
 
demo/__init__.py ADDED
File without changes
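
The Space now boots from `launch.py` (see the `app_file` change above), and the `if __name__ == "__main__"` block with `demo.queue(...)` / `demo.launch(...)` is removed from `demo/app.py` below. The commit does not include `launch.py` itself, so the following is only a minimal sketch of what such an entry point could look like, assuming the Gradio `Blocks` object is importable as `demo.app.demo` and that the old queue/launch settings carry over unchanged:

```python
# Hypothetical launch.py (not part of this commit).
# Assumes demo/app.py exposes its gr.Blocks instance as `demo`.
from demo.app import demo

if __name__ == "__main__":
    # serve up to 3 concurrent requests with 6 worker threads
    demo.queue(default_concurrency_limit=3)
    demo.launch(max_threads=6)
```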
demo/app.py CHANGED
@@ -2,11 +2,17 @@ import gradio as gr
 from datasets import disable_caching, load_dataset
 from transformer_ranker import TransformerRanker
 
-from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS, GRADIO_THEME
+from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
 from demo.utils import (
-    BANNER, FOOTER, CSS, UNSET,
-    EmbeddingProgressTracker, compute_ratio,
-    validate_dataset, preprocess_dataset, ensure_dataset_is_loaded
+    BANNER,
+    FOOTER,
+    CSS,
+    UNSET,
+    EmbeddingProgressTracker,
+    compute_ratio,
+    validate_dataset,
+    preprocess_dataset,
+    ensure_dataset_is_loaded,
 )
 
 
@@ -18,31 +24,33 @@ with gr.Blocks(css=CSS, theme=None) as demo:
 
     ##### 1. Load from datasets #####
 
-    gr.Markdown("## Load Downstream Dataset")
+    gr.Markdown("## 📚 Load Data")
 
     gr.Markdown(
-        "Select a dataset from the Hugging Face Hub such as `trec`. "
-        "This defines your downstream task."
+        "Pick a dataset from the Hugging Face Hub (e.g. `trec`). This defines your downstream task."
     )
 
     with gr.Group():
         dataset = gr.State(None)
 
         dataset_id = gr.Textbox(
-            label="Dataset name",
+            label="Dataset identifier",
             placeholder="try: trec, conll2003, ag_news",
             max_lines=1,
         )
 
-        load_dataset_button = gr.Button(value="Load data", variant="primary", interactive=True,)
+        load_dataset_button = gr.Button(
+            value="Load data",
+            variant="primary",
+            interactive=True,
+        )
 
         # enable loading if dataset exists on hub
         dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
 
     gr.Markdown(
-        "Settings auto-configured. "
-        "Adjust the downsampling ratio in Dataset Setup, "
-        "or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
+        "⚡️ Speed mode on: tweak the downsampling ratio in *Dataset Setup* for quicker runs. "
+        "Unlock the full data via [framework](https://github.com/flairNLP/transformer-ranker)."
     )
 
     ##### data preprocessing #####
@@ -52,9 +60,7 @@ with gr.Blocks(css=CSS, theme=None) as demo:
             dataset_id_label = gr.Label("", label="Dataset")
             num_samples = gr.State(0)
             num_samples_label = gr.Label("", label="Dataset size")
-            num_samples.change(
-                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
-            )
+            num_samples.change(lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label])
 
         with gr.Row():
             text_column = gr.Dropdown("", label="Text Column")
@@ -66,9 +72,7 @@ with gr.Blocks(css=CSS, theme=None) as demo:
 
         with gr.Group():
             downsample_ratio = gr.State(0.0)
-            sampling_rate = gr.Slider(
-                20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
-            )
+            sampling_rate = gr.Slider(20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1)
             downsample_ratio_label = gr.Label("", label="Sampling rate")
             downsample_ratio.change(
                 lambda x: f"{x:.1%}",
@@ -87,20 +91,14 @@ with gr.Blocks(css=CSS, theme=None) as demo:
                 outputs=downsample_ratio,
             )
 
-    # load and show details
     def load_hf_dataset(dataset_id):
         try:
            dataset = load_dataset(dataset_id, trust_remote_code=True)
            dataset_details = preprocess_dataset(dataset)
        except ValueError as e:
-            gr.Warning("Collections not supported. Load one dataset only.")
+            gr.Warning(f"Watch out — single datasets only. Cannot load dataset: {e}")
 
-        return (
-            gr.update(value="Loaded"),
-            dataset_id,
-            dataset,
-            *dataset_details
-        )
+        return (gr.update(value="Loaded"), dataset_id, dataset, *dataset_details)
 
     load_dataset_button.click(
         load_hf_dataset,
@@ -120,29 +118,24 @@ with gr.Blocks(css=CSS, theme=None) as demo:
 
     ########## 2. Select LMs ##########
 
-    gr.Markdown("## Select Language Models")
+    gr.Markdown("## 🧠 Select Language Models")
 
     gr.Markdown(
-        "Add two or more pretrained models for ranking. "
-        "Go with small models since this demo runs on CPU."
+        "Add two or more pretrained models to compare. "
+        "Stick to smaller models here since the demo runs on CPU."
     )
 
     with gr.Group():
-        model_options = [
-            (model_handle.split("/")[-1], model_handle)
-            for model_handle in ALL_LMS
-        ]
-        models = gr.CheckboxGroup(
-            choices=model_options, label="Model List", value=PRESELECTED_LMS
-        )
+        model_options = [(model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS]
+        models = gr.CheckboxGroup(choices=model_options, label="Model List", value=PRESELECTED_LMS)
 
     ########## 3. Run ranking ##########
 
-    gr.Markdown("## Rank Language Models")
+    gr.Markdown("## 🏆 Rank Models")
 
     gr.Markdown(
-        "Rank models by transferability to your downstream task. "
-        "Adjust the metric and layer aggregation in Advanced Settings."
+        "Rank models by transferability to your task. "
+        "More control? Tweak transferability metric and layer aggregation in *Settings*."
    )
 
     with gr.Group():
@@ -195,15 +188,13 @@ with gr.Blocks(css=CSS, theme=None) as demo:
     ):
 
         if text_column == UNSET:
-            raise gr.Error("Text column is not set.")
+            raise gr.Error("Text column is required.")
 
         if label_column == UNSET:
-            raise gr.Error("Label column is not set.")
+            raise gr.Error("Label column is required.")
 
         if task_category == UNSET:
-            raise gr.Error(
-                "Task category not set. Dataset must support classification or regression."
-            )
+            raise gr.Error("Task category is required.")
 
         if text_pair_column == UNSET:
             text_pair_column = None
@@ -229,23 +220,19 @@ with gr.Blocks(css=CSS, theme=None) as demo:
                 tracker=tracker,
             )
 
-            sorted_results = sorted(
-                results._results.items(), key=lambda item: item[1], reverse=True
-            )
-            return [
-                (i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
-            ]
+            sorted_results = sorted(results._results.items(), key=lambda item: item[1], reverse=True)
+            return [(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)]
         except Exception as e:
-            print(e)
             gr.Warning(f"Ranking issue: {e}")
             return []
 
-    gr.Markdown("Ranking table → higher scores indicate better downstream performance.")
+    gr.Markdown("**Leaderboard:** higher score better downstream performance.")
 
     ranking_results = gr.Dataframe(
         headers=["Rank", "Model", "Score"],
         datatype=["number", "str", "number"],
-        value=[["-", "-", "-"]]
+        value=[["-", "-", "-"]],
+        interactive=False
    )
 
     submit_button.click(
@@ -266,11 +253,3 @@ with gr.Blocks(css=CSS, theme=None) as demo:
     )
 
     gr.Markdown(FOOTER)
-
-if __name__ == "__main__":
-
-    # run up to 3 requests at once
-    demo.queue(default_concurrency_limit=3)
-
-    # run with 6 workers
-    demo.launch(max_threads=6)
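
The demo intentionally downsamples and runs on CPU; for full-size runs the Markdown above points to the `transformer-ranker` framework. As rough orientation only, direct usage might look like the sketch below; the model handles and keyword arguments (`dataset_downsample`, `batch_size`) follow the framework's public README rather than this commit, so treat them as assumptions:

```python
# Hedged sketch of running transformer-ranker outside the Space (assumed API).
from datasets import load_dataset
from transformer_ranker import TransformerRanker

dataset = load_dataset("trec")  # any classification/regression dataset from the Hub

language_models = [
    "prajjwal1/bert-tiny",                  # example handles; swap in your own
    "google/electra-small-discriminator",
]

ranker = TransformerRanker(dataset, dataset_downsample=0.2)  # keep 20% of the data
results = ranker.run(language_models, batch_size=64)
print(results)  # models sorted by transferability score
```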
 
 
 
 
 
 
 
 
demo/utils.py CHANGED
@@ -9,11 +9,11 @@ from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
 from transformer_ranker.embedder import Embedder
 
 BANNER = """
-<h1 align="center">🔥 TransformerRanker 🔥</h1>
+<h1 align="center">🚀 TransformerRanker ⚡️</h1>
 
 <p align="center" style="max-width: 560px; margin: auto;">
 Find the best language model for your downstream task.
-Load a dataset, select models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
+Load a dataset, pick models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
 </p>
 
 <p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
@@ -35,8 +35,9 @@ BANNER = """
 """
 
 FOOTER = """
-**Note:** CPU-only quick demo. **Built by:** @lukasgarbas & @plonerma
-**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫.
+**Note:** Quick CPU-only demo.
+**Built by** [@lukasgarbas](https://huggingface.co/lukasgarbas) & [@plonerma](https://huggingface.co/plonerma)
+**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫
 """
 
 CSS = """
@@ -53,18 +54,18 @@ preprocessing = DatasetCleaner()
 
 
 def validate_dataset(dataset_name):
-    """Enable if dataset exists on Hub."""
+    """Quick look dataset existence on Hub."""
     try:
-        hf_api.dataset_info(dataset_name) # quick dataset info call
+        hf_api.dataset_info(dataset_name)
         return gr.update(interactive=True)
-
     except (HTTPError, HFValidationError):
         return gr.update(value="Load data", interactive=False)
 
 
 def preprocess_dataset(dataset):
-    """Use data preprocessing to find text/label columns and task category."""
+    """Detect text/label columns and task type."""
     data = concatenate_datasets(list(dataset.values()))
+    sample_size = len(data)
 
     try:
         text_column = preprocessing._find_column(data, "text column")
@@ -83,19 +84,11 @@ def preprocess_dataset(dataset):
     try:
         task_category = preprocessing._find_task_category(data, label_column)
     except ValueError:
-        gr.Warning("Task category not auto-detected — framework supports classification, regression.")
-
-    text_column = gr.update(value=text_column, choices=data.column_names, interactive=True)
-    label_column = gr.update(value=label_column, choices=data.column_names, interactive=True)
-    text_pair = gr.update(value=UNSET, choices=[UNSET, *data.column_names], interactive=True)
-    task_category = gr.update(value=task_category, choices=[str(t) for t in TaskCategory], interactive=True)
-    sample_size = len(data)
-
-    return task_category, text_column, text_pair, label_column, sample_size
+        gr.Warning(
+            "Task category not auto-detected — framework supports classification, regression."
+        )
 
-    """
     return (
-        text_column,
         gr.update(
             value=task_category,
             choices=[str(t) for t in TaskCategory],
@@ -110,9 +103,9 @@ def preprocess_dataset(dataset):
         gr.update(
             value=label_column, choices=data.column_names, interactive=True
         ),
-        num_samples,
+        sample_size,
     )
-    """
+
 
 def compute_ratio(num_samples_to_use, num_samples):
     if num_samples > 0:
@@ -128,12 +121,6 @@ def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
         return gr.update(interactive=False)
 
 
-def ensure_one_lm_selected(checkbox_values, previous_values):
-    if not any(checkbox_values):
-        return previous_values
-    return checkbox_values
-
-
 # apply monkey patch to enable callbacks
 _old_embed = Embedder.embed
 
@@ -185,8 +172,6 @@ class EmbeddingProgressTracker:
             self.progress_bar(1.0, desc="Done")
         else:
             self.progress_bar(1.0, desc="Error")
-
-        # Do not suppress any errors
         return False
 
     def update_num_batches(self, total):
@@ -201,7 +186,6 @@
 
     def update_bar(self):
         i = self.current_model
-
         description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
 
         progress = i / self.total
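
`demo/utils.py` keeps the `# apply monkey patch to enable callbacks` hook: the original `Embedder.embed` is stashed in `_old_embed` so the `EmbeddingProgressTracker` can be notified while embeddings are computed. The wrapper itself is not shown in this diff; the snippet below is only a generic illustration of that monkey-patch pattern, and the wrapper name and the way the tracker is passed in are assumptions:

```python
# Generic monkey-patch illustration, not the Space's actual wrapper code.
from transformer_ranker.embedder import Embedder

_old_embed = Embedder.embed  # keep a handle on the original method


def _embed_with_progress(self, *args, tracker=None, **kwargs):
    # run the original embedding step, then ping the optional tracker
    result = _old_embed(self, *args, **kwargs)
    if tracker is not None:
        tracker.update_bar()  # method shown in EmbeddingProgressTracker above
    return result


Embedder.embed = _embed_with_progress
```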
requirements.txt DELETED
@@ -1,5 +0,0 @@
-gradio>=5.0
-transformer-ranker==0.1.2
-transformers==4.41.0
-datasets==3.6
-protobuf
 
 
 
 
 
 
runtime.txt DELETED
@@ -1 +0,0 @@
-python-3.12