lukasgarbas commited on
Commit
2018b94
·
1 Parent(s): f5ee3a9
.gitattributes CHANGED
@@ -8,6 +8,8 @@
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
 
11
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +35,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
12
+ *.mds filter=lfs diff=lfs merge=lfs -text
13
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
14
  *.model filter=lfs diff=lfs merge=lfs -text
15
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
35
  *.zip filter=lfs diff=lfs merge=lfs -text
36
  *.zst filter=lfs diff=lfs merge=lfs -text
37
  *tfevents* filter=lfs diff=lfs merge=lfs -text
38
+ # Audio files - uncompressed
39
+ *.pcm filter=lfs diff=lfs merge=lfs -text
40
+ *.sam filter=lfs diff=lfs merge=lfs -text
41
+ *.raw filter=lfs diff=lfs merge=lfs -text
42
+ # Audio files - compressed
43
+ *.aac filter=lfs diff=lfs merge=lfs -text
44
+ *.flac filter=lfs diff=lfs merge=lfs -text
45
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
46
+ *.ogg filter=lfs diff=lfs merge=lfs -text
47
+ *.wav filter=lfs diff=lfs merge=lfs -text
48
+ # Image files - uncompressed
49
+ *.bmp filter=lfs diff=lfs merge=lfs -text
50
+ *.gif filter=lfs diff=lfs merge=lfs -text
51
+ *.png filter=lfs diff=lfs merge=lfs -text
52
+ *.tiff filter=lfs diff=lfs merge=lfs -text
53
+ # Image files - compressed
54
+ *.jpg filter=lfs diff=lfs merge=lfs -text
55
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
56
+ *.webp filter=lfs diff=lfs merge=lfs -text
57
+ # Video files - compressed
58
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
59
+ *.webm filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -153,3 +153,9 @@ dmypy.json
153
 
154
  # Cython debug symbols
155
  cython_debug/
 
 
 
 
 
 
 
153
 
154
  # Cython debug symbols
155
  cython_debug/
156
+
157
+ # macOS
158
+ .DS_Store
159
+
160
+ # Ruff
161
+ .ruff_cache/
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
  title: TransformerRanker
3
- emoji: ⚖️
4
  colorFrom: yellow
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.44.1
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: Efficiently find the best-suited LM for your NLP task
12
- ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: TransformerRanker
3
+ emoji: 🎯🧩
4
  colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.44.0
8
+ app_file: demo/app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Efficient LM Ranking for Downstream Tasks
12
+ ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py → demo/app.py RENAMED
@@ -1,127 +1,113 @@
1
  import gradio as gr
2
  from datasets import disable_caching, load_dataset
3
- from transformer_ranker import TransformerRanker, prepare_popular_models
4
- import traceback
5
 
6
- from utils import (
7
- DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER,
8
- EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded,
9
- compute_ratio, ensure_one_lm_selected, get_dataset_info
 
10
  )
11
 
12
- disable_caching()
13
-
14
- THEME = "pseudolab/huggingface-korea-theme"
15
- DEFAULT_SAMPLES = 1000
16
- MAX_SAMPLES = 5000
17
- LANGUAGE_MODELS = prepare_popular_models('base') + prepare_popular_models('large')
18
 
19
- # Add a tiny model for demonstration on CPU
20
- LANGUAGE_MODELS = ['prajjwal1/bert-tiny'] + list(dict.fromkeys(LANGUAGE_MODELS))
21
- LANGUAGE_MODELS.insert(LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased")
22
 
23
- # Preselect some small models
24
- DEFAULT_MODELS = [
25
- "prajjwal1/bert-tiny", "google/electra-small-discriminator",
26
- "distilbert-base-cased", "sentence-transformers/all-MiniLM-L12-v2"
27
- ]
28
 
 
29
 
30
- with gr.Blocks(css=CSS, theme=THEME) as demo:
31
 
32
- ########## STEP 1: Load the Dataset ##########
33
 
34
- gr.Markdown(HEADLINE)
 
 
 
35
 
36
- gr.Markdown("## Step 1: Load a Dataset")
37
  with gr.Group():
38
  dataset = gr.State(None)
39
 
40
- dataset_name = gr.Textbox(
41
- label="Enter the name of your dataset",
42
- placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
43
  max_lines=1,
44
  )
45
- select_dataset_button = gr.Button(
46
- value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
47
- )
48
 
49
- # Activate the "Load dataset" button if dataset was found
50
- dataset_name.change(
51
- check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
52
- )
53
 
54
  gr.Markdown(
55
- "*The number of samples that can be used in this demo is limited to save resources. "
56
- "To run an estimate on the full dataset, check out the "
57
- "[library](https://github.com/flairNLP/transformer-ranker).*"
58
  )
59
 
60
- ########## Step 1.1 Dataset preprocessing ##########
61
 
62
- with gr.Accordion("Dataset settings", open=False) as dataset_config:
63
  with gr.Row() as dataset_details:
64
- dataset_name_label = gr.Label("", label="Dataset Name")
65
  num_samples = gr.State(0)
66
- num_samples_label = gr.Label("", label="Number of Samples")
67
  num_samples.change(
68
  lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
69
  )
70
 
71
  with gr.Row():
72
  text_column = gr.Dropdown("", label="Text Column")
73
- text_pair_column = gr.Dropdown("", label="Text Pair Column")
74
 
75
  with gr.Row():
76
- label_column = gr.Dropdown("", label="Label Column")
77
- task_category = gr.Dropdown("", label="Task Type")
78
 
79
  with gr.Group():
80
  downsample_ratio = gr.State(0.0)
81
- num_samples_to_use = gr.Slider(
82
- 20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
83
  )
84
- downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
85
  downsample_ratio.change(
86
  lambda x: f"{x:.1%}",
87
  inputs=[downsample_ratio],
88
  outputs=[downsample_ratio_label],
89
  )
90
 
91
- num_samples_to_use.change(
92
  compute_ratio,
93
- inputs=[num_samples_to_use, num_samples],
94
  outputs=downsample_ratio,
95
  )
96
  num_samples.change(
97
  compute_ratio,
98
- inputs=[num_samples_to_use, num_samples],
99
  outputs=downsample_ratio,
100
  )
101
 
102
- # Download the dataset and show details
103
- def select_dataset(dataset_name):
104
  try:
105
- dataset = load_dataset(dataset_name, trust_remote_code=True)
106
- dataset_info = get_dataset_info(dataset)
107
- except ValueError:
108
- gr.Warning("Dataset collections are not supported. Please use a single dataset.")
109
 
110
  return (
111
- gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
112
- gr.Accordion(open=True),
113
- dataset_name,
114
  dataset,
115
- *dataset_info
116
  )
117
 
118
- select_dataset_button.click(
119
- select_dataset,
120
- inputs=[dataset_name],
121
  outputs=[
122
- select_dataset_button,
123
- dataset_config,
124
- dataset_name_label,
125
  dataset,
126
  task_category,
127
  text_column,
@@ -132,53 +118,65 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
132
  scroll_to_output=True,
133
  )
134
 
135
- ########## STEP 2 ##########
 
 
 
 
 
 
 
136
 
137
- gr.Markdown("## Step 2: Select a List of Language Models")
138
  with gr.Group():
139
  model_options = [
140
  (model_handle.split("/")[-1], model_handle)
141
- for model_handle in LANGUAGE_MODELS
142
  ]
143
  models = gr.CheckboxGroup(
144
- choices=model_options, label="Select Models", value=DEFAULT_MODELS
145
  )
146
 
147
- ########## STEP 3: Run Language Model Ranking ##########
148
 
149
- gr.Markdown("## Step 3: Rank LMs")
 
 
 
 
 
150
 
151
  with gr.Group():
152
- with gr.Accordion("Advanced settings", open=False):
 
 
 
153
  with gr.Row():
154
  estimator = gr.Dropdown(
155
  choices=["hscore", "logme", "knn"],
156
  label="Transferability metric",
157
  value="hscore",
158
  )
159
- layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
160
- layer_pooling = gr.Dropdown(
161
  choices=["lastlayer", "layermean", "bestlayer"],
162
- label="Layer pooling",
163
  value="layermean",
164
  )
165
- submit_button = gr.Button("Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT)
166
 
167
- # Make button active if the dataset is loaded
168
  dataset.change(
169
- check_dataset_is_loaded,
170
  inputs=[dataset, text_column, label_column, task_category],
171
  outputs=submit_button
172
  )
173
 
174
  label_column.change(
175
- check_dataset_is_loaded,
176
  inputs=[dataset, text_column, label_column, task_category],
177
  outputs=submit_button
178
  )
179
 
180
  text_column.change(
181
- check_dataset_is_loaded,
182
  inputs=[dataset, text_column, label_column, task_category],
183
  outputs=submit_button
184
  )
@@ -187,7 +185,7 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
187
  dataset,
188
  downsample_ratio,
189
  selected_models,
190
- layer_pooling,
191
  estimator,
192
  text_column,
193
  text_pair_column,
@@ -196,18 +194,18 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
196
  progress=gr.Progress(),
197
  ):
198
 
199
- if text_column == "-":
200
  raise gr.Error("Text column is not set.")
201
 
202
- if label_column == "-":
203
  raise gr.Error("Label column is not set.")
204
 
205
- if task_category == "-":
206
  raise gr.Error(
207
- "Task category is not set. The dataset must support classification or regression tasks."
208
  )
209
 
210
- if text_pair_column == "-":
211
  text_pair_column = None
212
 
213
  progress(0.0, "Starting")
@@ -225,7 +223,7 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
225
 
226
  results = ranker.run(
227
  models=selected_models,
228
- layer_aggregator=layer_pooling,
229
  estimator=estimator,
230
  batch_size=64,
231
  tracker=tracker,
@@ -238,11 +236,16 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
238
  (i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
239
  ]
240
  except Exception as e:
241
- gr.Error("The dataset is not supported.")
 
 
 
 
242
 
243
- gr.Markdown("## Results")
244
  ranking_results = gr.Dataframe(
245
- headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
 
 
246
  )
247
 
248
  submit_button.click(
@@ -251,7 +254,7 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
251
  dataset,
252
  downsample_ratio,
253
  models,
254
- layer_pooling,
255
  estimator,
256
  text_column,
257
  text_pair_column,
@@ -262,13 +265,12 @@ with gr.Blocks(css=CSS, theme=THEME) as demo:
262
  scroll_to_output=True,
263
  )
264
 
265
- gr.Markdown(
266
- "*The results are ranked by their transferability score, with the most suitable model listed first. "
267
- "This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
268
- )
269
-
270
  gr.Markdown(FOOTER)
271
 
272
  if __name__ == "__main__":
 
 
273
  demo.queue(default_concurrency_limit=3)
 
 
274
  demo.launch(max_threads=6)
 
1
  import gradio as gr
2
  from datasets import disable_caching, load_dataset
3
+ from transformer_ranker import TransformerRanker
 
4
 
5
+ from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS, GRADIO_THEME
6
+ from demo.utils import (
7
+ BANNER, FOOTER, CSS, UNSET,
8
+ EmbeddingProgressTracker, compute_ratio,
9
+ validate_dataset, preprocess_dataset, ensure_dataset_is_loaded
10
  )
11
 
 
 
 
 
 
 
12
 
13
+ disable_caching()
 
 
14
 
15
+ with gr.Blocks(css=CSS, theme=None) as demo:
 
 
 
 
16
 
17
+ gr.Markdown(BANNER)
18
 
19
+ ##### 1. Load from datasets #####
20
 
21
+ gr.Markdown("## Load Downstream Dataset")
22
 
23
+ gr.Markdown(
24
+ "Select a dataset from the Hugging Face Hub such as `trec`. "
25
+ "This defines your downstream task."
26
+ )
27
 
 
28
  with gr.Group():
29
  dataset = gr.State(None)
30
 
31
+ dataset_id = gr.Textbox(
32
+ label="Dataset name",
33
+ placeholder="try: trec, conll2003, ag_news",
34
  max_lines=1,
35
  )
 
 
 
36
 
37
+ load_dataset_button = gr.Button(value="Load data", variant="primary", interactive=True,)
38
+
39
+ # enable loading if dataset exists on hub
40
+ dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
41
 
42
  gr.Markdown(
43
+ "Settings auto-configured. "
44
+ "Adjust the downsampling ratio in Dataset Setup, "
45
+ "or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
46
  )
47
 
48
+ ##### data preprocessing #####
49
 
50
+ with gr.Accordion("Dataset Setup", open=False) as dataset_config:
51
  with gr.Row() as dataset_details:
52
+ dataset_id_label = gr.Label("", label="Dataset")
53
  num_samples = gr.State(0)
54
+ num_samples_label = gr.Label("", label="Dataset size")
55
  num_samples.change(
56
  lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
57
  )
58
 
59
  with gr.Row():
60
  text_column = gr.Dropdown("", label="Text Column")
61
+ text_pair_column = gr.Dropdown("", label="Text Pair")
62
 
63
  with gr.Row():
64
+ label_column = gr.Dropdown("", label="Labels")
65
+ task_category = gr.Dropdown("", label="Downstream Task")
66
 
67
  with gr.Group():
68
  downsample_ratio = gr.State(0.0)
69
+ sampling_rate = gr.Slider(
70
+ 20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
71
  )
72
+ downsample_ratio_label = gr.Label("", label="Sampling rate")
73
  downsample_ratio.change(
74
  lambda x: f"{x:.1%}",
75
  inputs=[downsample_ratio],
76
  outputs=[downsample_ratio_label],
77
  )
78
 
79
+ sampling_rate.change(
80
  compute_ratio,
81
+ inputs=[sampling_rate, num_samples],
82
  outputs=downsample_ratio,
83
  )
84
  num_samples.change(
85
  compute_ratio,
86
+ inputs=[sampling_rate, num_samples],
87
  outputs=downsample_ratio,
88
  )
89
 
90
+ # load and show details
91
+ def load_hf_dataset(dataset_id):
92
  try:
93
+ dataset = load_dataset(dataset_id, trust_remote_code=True)
94
+ dataset_details = preprocess_dataset(dataset)
95
+ except ValueError as e:
96
+ gr.Warning("Collections not supported. Load one dataset only.")
97
 
98
  return (
99
+ gr.update(value="Loaded"),
100
+ dataset_id,
 
101
  dataset,
102
+ *dataset_details
103
  )
104
 
105
+ load_dataset_button.click(
106
+ load_hf_dataset,
107
+ inputs=[dataset_id],
108
  outputs=[
109
+ load_dataset_button,
110
+ dataset_id_label,
 
111
  dataset,
112
  task_category,
113
  text_column,
 
118
  scroll_to_output=True,
119
  )
120
 
121
+ ########## 2. Select LMs ##########
122
+
123
+ gr.Markdown("## Select Language Models")
124
+
125
+ gr.Markdown(
126
+ "Add two or more pretrained models for ranking. "
127
+ "Go with small models since this demo runs on CPU."
128
+ )
129
 
 
130
  with gr.Group():
131
  model_options = [
132
  (model_handle.split("/")[-1], model_handle)
133
+ for model_handle in ALL_LMS
134
  ]
135
  models = gr.CheckboxGroup(
136
+ choices=model_options, label="Model List", value=PRESELECTED_LMS
137
  )
138
 
139
+ ########## 3. Run ranking ##########
140
 
141
+ gr.Markdown("## Rank Language Models")
142
+
143
+ gr.Markdown(
144
+ "Rank models by transferability to your downstream task. "
145
+ "Adjust the metric and layer aggregation in Advanced Settings."
146
+ )
147
 
148
  with gr.Group():
149
+
150
+ submit_button = gr.Button("Run ranking", variant="primary", interactive=False)
151
+
152
+ with gr.Accordion("Advanced Settings", open=False):
153
  with gr.Row():
154
  estimator = gr.Dropdown(
155
  choices=["hscore", "logme", "knn"],
156
  label="Transferability metric",
157
  value="hscore",
158
  )
159
+ layer_aggregator = gr.Dropdown(
 
160
  choices=["lastlayer", "layermean", "bestlayer"],
161
+ label="Layer aggregation",
162
  value="layermean",
163
  )
 
164
 
165
+ # ranking button works after dataset loads
166
  dataset.change(
167
+ ensure_dataset_is_loaded,
168
  inputs=[dataset, text_column, label_column, task_category],
169
  outputs=submit_button
170
  )
171
 
172
  label_column.change(
173
+ ensure_dataset_is_loaded,
174
  inputs=[dataset, text_column, label_column, task_category],
175
  outputs=submit_button
176
  )
177
 
178
  text_column.change(
179
+ ensure_dataset_is_loaded,
180
  inputs=[dataset, text_column, label_column, task_category],
181
  outputs=submit_button
182
  )
 
185
  dataset,
186
  downsample_ratio,
187
  selected_models,
188
+ layer_aggregator,
189
  estimator,
190
  text_column,
191
  text_pair_column,
 
194
  progress=gr.Progress(),
195
  ):
196
 
197
+ if text_column == UNSET:
198
  raise gr.Error("Text column is not set.")
199
 
200
+ if label_column == UNSET:
201
  raise gr.Error("Label column is not set.")
202
 
203
+ if task_category == UNSET:
204
  raise gr.Error(
205
+ "Task category not set. Dataset must support classification or regression."
206
  )
207
 
208
+ if text_pair_column == UNSET:
209
  text_pair_column = None
210
 
211
  progress(0.0, "Starting")
 
223
 
224
  results = ranker.run(
225
  models=selected_models,
226
+ layer_aggregator=layer_aggregator,
227
  estimator=estimator,
228
  batch_size=64,
229
  tracker=tracker,
 
236
  (i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
237
  ]
238
  except Exception as e:
239
+ print(e)
240
+ gr.Warning(f"Ranking issue: {e}")
241
+ return []
242
+
243
+ gr.Markdown("Ranking table → higher scores indicate better downstream performance.")
244
 
 
245
  ranking_results = gr.Dataframe(
246
+ headers=["Rank", "Model", "Score"],
247
+ datatype=["number", "str", "number"],
248
+ value=[["-", "-", "-"]]
249
  )
250
 
251
  submit_button.click(
 
254
  dataset,
255
  downsample_ratio,
256
  models,
257
+ layer_aggregator,
258
  estimator,
259
  text_column,
260
  text_pair_column,
 
265
  scroll_to_output=True,
266
  )
267
 
 
 
 
 
 
268
  gr.Markdown(FOOTER)
269
 
270
  if __name__ == "__main__":
271
+
272
+ # run up to 3 requests at once
273
  demo.queue(default_concurrency_limit=3)
274
+
275
+ # run with 6 workers
276
  demo.launch(max_threads=6)
demo/config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SAMPLE_SIZE = 1000
2
+ MAX_SAMPLE_SIZE = 5000
3
+ GRADIO_THEME = None
4
+
5
+ ALL_LMS = [
6
+ # tiny
7
+ "prajjwal1/bert-tiny", "arnir0/Tiny-LLM",
8
+ "sentence-transformers/all-MiniLM-L12-v2", "google/electra-small-discriminator",
9
+ "distilbert-base-cased", "typeform/distilroberta-base-v2",
10
+
11
+ # small
12
+ "bert-base-cased", "roberta-base", "google/electra-base-discriminator", "microsoft/deberta-v3-base",
13
+ "KISTI-AI/scideberta", "sentence-transformers/all-mpnet-base-v2", "huggingface/CodeBERTa-small-v1",
14
+ "FacebookAI/xlm-roberta-base", "microsoft/mdeberta-v3-base", "HuggingFaceTB/SmolLM2-135M"
15
+ ]
16
+
17
+ PRESELECTED_LMS = [
18
+ "prajjwal1/bert-tiny",
19
+ "sentence-transformers/all-MiniLM-L12-v2",
20
+ "arnir0/Tiny-LLM",
21
+ "google/electra-small-discriminator",
22
+ ]
utils.py → demo/utils.py RENAMED
@@ -1,118 +1,118 @@
 
 
1
  import gradio as gr
2
  from datasets import concatenate_datasets
3
  from huggingface_hub import HfApi
4
  from huggingface_hub.errors import HFValidationError
5
  from requests.exceptions import HTTPError
6
- from transformer_ranker import Result
7
  from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
8
  from transformer_ranker.embedder import Embedder
9
- import math
10
 
11
- DISABLED_BUTTON_VARIANT = "huggingface"
12
- ENABLED_BUTTON_VARIANT = "primary"
13
 
14
- HEADLINE = """
15
- <h1 align="center">TransformerRanker</h1>
16
  <p align="center" style="max-width: 560px; margin: auto;">
17
- A very simple library that helps you find the best-suited language model for your NLP task.
18
- All you need to do is to select a dataset and a list of pre-trained language models (LMs) from the 🤗 HuggingFace Hub.
19
- TransformerRanker will quickly estimate which of these LMs will perform best on the given dataset!
20
  </p>
 
21
  <p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
22
  <a href="https://github.com/flairNLP/transformer-ranker">
23
- <img src="https://img.shields.io/badge/Repository-black?style=flat&logo=github" alt="GitHub Repo">
 
 
 
24
  </a>
25
  <a href="https://pypi.org/project/transformer-ranker/">
26
- <img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="Package Badge">
27
  </a>
28
- <a href="https://github.com/flairNLP/transformer-ranker/blob/main/examples/01-walkthrough.md">
29
- <img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="Tutorials Badge">
30
  </a>
31
- <img src="https://img.shields.io/badge/license-MIT-green?style=flat" alt="License: MIT">
32
  </p>
 
33
  <p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
34
  """
35
 
36
  FOOTER = """
37
- **Note:** This demonstration currently runs on a CPU and is suited for smaller models only.
38
- **Developers:** [@plonerma](https://huggingface.co/plonerma) and [@lukasgarbas](https://huggingface.co/lukasgarbas).
39
- For feedback, suggestions, or contributions, reach out via GitHub or leave a message in the [discussions](https://huggingface.co/spaces/lukasgarbas/transformer-ranker/discussions).
40
  """
41
 
42
  CSS = """
43
- .gradio-container{max-width: 800px !important}
44
- a {color: #ff9d00;}
45
- @media (prefers-color-scheme: dark) { a {color: #be185d;} }
 
46
  """
47
 
 
48
 
49
  hf_api = HfApi()
 
50
 
51
 
52
- def check_dataset_exists(dataset_name):
53
- """Update loading button if dataset can be found"""
54
  try:
55
- hf_api.dataset_info(dataset_name)
56
- return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
57
 
58
  except (HTTPError, HFValidationError):
59
- return gr.update(value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT)
60
-
61
- def check_dataset_is_loaded(dataset, text_column, label_column, task_category):
62
- if dataset and text_column != "-" and label_column != "-" and task_category != "-":
63
- return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
64
- else:
65
- return gr.update(interactive=False, variant=DISABLED_BUTTON_VARIANT)
66
 
67
 
68
- def get_dataset_info(dataset):
69
- """Show information for dataset settings"""
70
- joined_dataset = concatenate_datasets(list(dataset.values()))
71
- datacleaner = DatasetCleaner()
72
 
73
  try:
74
- text_column = datacleaner._find_column(joined_dataset, "text column")
75
  except ValueError:
76
- gr.Warning("Text column can not be found. Select it in the dataset settings.")
77
- text_column = "-"
78
 
79
  try:
80
- label_column = datacleaner._find_column(joined_dataset, "label column")
81
  except ValueError:
82
- gr.Warning("Label column can not be found. Select it in the dataset settings.")
83
- label_column = "-"
84
 
85
- task_category = "-"
86
- if label_column != "-":
87
  try:
88
- # Find or set the task_category
89
- task_category = datacleaner._find_task_category(joined_dataset, label_column)
90
  except ValueError:
91
- gr.Warning(
92
- "Task category could not be determined. The dataset must support classification or regression tasks.",
93
- )
94
- pass
95
 
96
- num_samples = len(joined_dataset)
 
 
 
 
97
 
 
 
 
98
  return (
 
99
  gr.update(
100
  value=task_category,
101
  choices=[str(t) for t in TaskCategory],
102
  interactive=True,
103
  ),
104
  gr.update(
105
- value=text_column, choices=joined_dataset.column_names, interactive=True
106
  ),
107
  gr.update(
108
- value="-", choices=["-", *joined_dataset.column_names], interactive=True
109
  ),
110
  gr.update(
111
- value=label_column, choices=joined_dataset.column_names, interactive=True
112
  ),
113
  num_samples,
114
  )
115
-
116
 
117
  def compute_ratio(num_samples_to_use, num_samples):
118
  if num_samples > 0:
@@ -121,13 +121,20 @@ def compute_ratio(num_samples_to_use, num_samples):
121
  return 0.0
122
 
123
 
 
 
 
 
 
 
 
124
  def ensure_one_lm_selected(checkbox_values, previous_values):
125
  if not any(checkbox_values):
126
  return previous_values
127
  return checkbox_values
128
 
129
 
130
- # Apply monkey patch to enable callbacks
131
  _old_embed = Embedder.embed
132
 
133
  def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
@@ -202,4 +209,3 @@ class EmbeddingProgressTracker:
202
  progress += (self.batches_complete / self.batches_total) / self.total
203
 
204
  self.progress_bar(progress=progress, desc=description)
205
-
 
1
+ import math
2
+
3
  import gradio as gr
4
  from datasets import concatenate_datasets
5
  from huggingface_hub import HfApi
6
  from huggingface_hub.errors import HFValidationError
7
  from requests.exceptions import HTTPError
 
8
  from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
9
  from transformer_ranker.embedder import Embedder
 
10
 
11
+ BANNER = """
12
+ <h1 align="center">🔥 TransformerRanker 🔥</h1>
13
 
 
 
14
  <p align="center" style="max-width: 560px; margin: auto;">
15
+ Find the best language model for your downstream task.
16
+ Load a dataset, select models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
 
17
  </p>
18
+
19
  <p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
20
  <a href="https://github.com/flairNLP/transformer-ranker">
21
+ <img src="https://img.shields.io/badge/Code Repo-black?style=flat&logo=github" alt="repository">
22
+ </a>
23
+ <a href="https://opensource.org/licenses/MIT">
24
+ <img src="https://img.shields.io/badge/License-MIT-brightgreen?style=flat" alt="license">
25
  </a>
26
  <a href="https://pypi.org/project/transformer-ranker/">
27
+ <img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="package">
28
  </a>
29
+ <a href="https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md">
30
+ <img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="tutorials">
31
  </a>
 
32
  </p>
33
+
34
  <p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
35
  """
36
 
37
  FOOTER = """
38
+ **Note:** CPU-only quick demo. **Built by:** @lukasgarbas & @plonerma
39
+ **Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫.
 
40
  """
41
 
42
  CSS = """
43
+ .gradio-container {
44
+ max-width: 800px;
45
+ margin: auto;
46
+ }
47
  """
48
 
49
+ UNSET = "-"
50
 
51
  hf_api = HfApi()
52
+ preprocessing = DatasetCleaner()
53
 
54
 
55
+ def validate_dataset(dataset_name):
56
+ """Enable if dataset exists on Hub."""
57
  try:
58
+ hf_api.dataset_info(dataset_name) # quick dataset info call
59
+ return gr.update(interactive=True)
60
 
61
  except (HTTPError, HFValidationError):
62
+ return gr.update(value="Load data", interactive=False)
 
 
 
 
 
 
63
 
64
 
65
+ def preprocess_dataset(dataset):
66
+ """Use data preprocessing to find text/label columns and task category."""
67
+ data = concatenate_datasets(list(dataset.values()))
 
68
 
69
  try:
70
+ text_column = preprocessing._find_column(data, "text column")
71
  except ValueError:
72
+ gr.Warning("Text column not auto-detected select in settings.")
73
+ text_column = UNSET
74
 
75
  try:
76
+ label_column = preprocessing._find_column(data, "label column")
77
  except ValueError:
78
+ gr.Warning("Label column not auto-detected select in settings.")
79
+ label_column = UNSET
80
 
81
+ task_category = UNSET
82
+ if label_column != UNSET:
83
  try:
84
+ task_category = preprocessing._find_task_category(data, label_column)
 
85
  except ValueError:
86
+ gr.Warning("Task category not auto-detected — framework supports classification, regression.")
 
 
 
87
 
88
+ text_column = gr.update(value=text_column, choices=data.column_names, interactive=True)
89
+ label_column = gr.update(value=label_column, choices=data.column_names, interactive=True)
90
+ text_pair = gr.update(value=UNSET, choices=[UNSET, *data.column_names], interactive=True)
91
+ task_category = gr.update(value=task_category, choices=[str(t) for t in TaskCategory], interactive=True)
92
+ sample_size = len(data)
93
 
94
+ return task_category, text_column, text_pair, label_column, sample_size
95
+
96
+ """
97
  return (
98
+ text_column,
99
  gr.update(
100
  value=task_category,
101
  choices=[str(t) for t in TaskCategory],
102
  interactive=True,
103
  ),
104
  gr.update(
105
+ value=text_column, choices=data.column_names, interactive=True
106
  ),
107
  gr.update(
108
+ value=UNSET, choices=[UNSET, *data.column_names], interactive=True
109
  ),
110
  gr.update(
111
+ value=label_column, choices=data.column_names, interactive=True
112
  ),
113
  num_samples,
114
  )
115
+ """
116
 
117
  def compute_ratio(num_samples_to_use, num_samples):
118
  if num_samples > 0:
 
121
  return 0.0
122
 
123
 
124
+ def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
125
+ if dataset and text_column != UNSET and label_column != UNSET and task_category != UNSET:
126
+ return gr.update(interactive=True)
127
+ else:
128
+ return gr.update(interactive=False)
129
+
130
+
131
  def ensure_one_lm_selected(checkbox_values, previous_values):
132
  if not any(checkbox_values):
133
  return previous_values
134
  return checkbox_values
135
 
136
 
137
+ # apply monkey patch to enable callbacks
138
  _old_embed = Embedder.embed
139
 
140
  def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
 
209
  progress += (self.batches_complete / self.batches_total) / self.total
210
 
211
  self.progress_bar(progress=progress, desc=description)
 
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
- gradio>=4.44.1
2
  transformer-ranker==0.1.2
 
 
 
 
1
+ gradio>=5.0
2
  transformer-ranker==0.1.2
3
+ transformers==4.41.0
4
+ datasets==3.6
5
+ protobuf
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.12