# TransformerRanker demo Space (CPU-only Gradio app)
import math

import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder
# Markdown shown at the top of the app. The original badge images were lost
# in extraction, leaving empty-text links; visible link text restores them.
BANNER = """
# 🌐 TransformerRanker ⚡️

Find the best language model for your downstream task.
Load a dataset, pick models from the 🤗 Hub, and rank them by **transferability**.

[GitHub](https://github.com/flairNLP/transformer-ranker)
[MIT License](https://opensource.org/licenses/MIT)
[PyPI](https://pypi.org/project/transformer-ranker/)
[Walkthrough](https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md)

Developed at [Humboldt University of Berlin](https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/).
"""

# Markdown shown at the bottom of the app.
FOOTER = """
**Note:** Quick CPU-only demo.

**Built by** [@lukasgarbas](https://huggingface.co/lukasgarbas) & [@plonerma](https://huggingface.co/plonerma)

**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫
"""
# Custom CSS: constrain the app width and center the banner content.
CSS = """
.gradio-container {
    max-width: 800px;
    margin: auto;
}
.banner {
    text-align: center;
}
.banner img {
    display: inline-block;
}
"""
# Placeholder value shown in dropdowns before a real column/task is chosen.
UNSET = "-"

# Shared clients: Hub API for dataset lookups, cleaner for column/task detection.
hf_api = HfApi()
preprocessing = DatasetCleaner()
def validate_dataset(dataset_name):
    """Enable the load button only if `dataset_name` exists on the Hub.

    Returns a gr.update() for the load button: interactive when the dataset
    is found, disabled (with its label reset) otherwise.
    """
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        # Unknown or malformed dataset id — keep the button disabled.
        return gr.update(value="Load data", interactive=False)
    return gr.update(interactive=True)
def preprocess_dataset(dataset):
    """Detect text/label columns and the task category for a loaded dataset.

    Concatenates all splits, then tries to auto-detect the text column, the
    label column, and the task category; each failure falls back to UNSET
    with a UI warning so the user can pick manually.

    Returns gr.update()s for the task-category, text-column, text-pair-column,
    and label-column dropdowns, plus the total number of samples.
    """
    data = concatenate_datasets(list(dataset.values()))
    sample_size = len(data)

    try:
        text_column = preprocessing._find_column(data, "text column")
    except ValueError:
        gr.Warning("Text column not auto-detected — select in settings.")
        text_column = UNSET

    try:
        label_column = preprocessing._find_column(data, "label column")
    except ValueError:
        gr.Warning("Label column not auto-detected — select in settings.")
        label_column = UNSET

    # Task category can only be inferred once a label column is known.
    task_category = UNSET
    if label_column != UNSET:
        try:
            task_category = preprocessing._find_task_category(data, label_column)
        except ValueError:
            gr.Warning(
                "Task category not auto-detected — framework supports classification, regression."
            )

    return (
        gr.update(
            value=task_category,
            choices=[str(t) for t in TaskCategory],
            interactive=True,
        ),
        gr.update(value=text_column, choices=data.column_names, interactive=True),
        # Optional second text column (e.g. sentence pairs): always starts unset.
        gr.update(value=UNSET, choices=[UNSET, *data.column_names], interactive=True),
        gr.update(value=label_column, choices=data.column_names, interactive=True),
        sample_size,
    )
def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of the dataset selected; 0.0 for an empty dataset."""
    return num_samples_to_use / num_samples if num_samples > 0 else 0.0
def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable the ranking button only once the dataset and all settings are set.

    Returns a gr.update() making the button interactive when a dataset is
    loaded and none of the three selections is still the UNSET placeholder.
    """
    ready = (
        bool(dataset)
        and text_column != UNSET
        and label_column != UNSET
        and task_category != UNSET
    )
    return gr.update(interactive=ready)
# --- Monkey-patch Embedder so a progress tracker can observe embedding. ---
# Each wrapper delegates to the original method and reports to the optional
# `tracker` attribute installed by the patched __init__.

_old_embed = Embedder.embed


def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    """Report the upcoming number of batches, then delegate to the original."""
    if embedder.tracker is not None:
        embedder.tracker.update_num_batches(math.ceil(len(sentences) / batch_size))
    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)


Embedder.embed = _new_embed

_old_embed_batch = Embedder.embed_batch


def _new_embed_batch(embedder, *args, **kw):
    """Notify the tracker after each completed batch."""
    result = _old_embed_batch(embedder, *args, **kw)
    if embedder.tracker is not None:
        embedder.tracker.update_batch_complete()
    return result


Embedder.embed_batch = _new_embed_batch

_old_init = Embedder.__init__


def _new_init(embedder, *args, tracker=None, **kw):
    """Accept an optional keyword-only `tracker` and store it on the instance."""
    _old_init(embedder, *args, **kw)
    embedder.tracker = tracker


Embedder.__init__ = _new_init
class EmbeddingProgressTracker:
    """Context manager driving a Gradio progress bar across several models.

    One "unit" of the bar is a model; within a model, progress advances by
    completed embedding batches (reported via the Embedder monkey-patches).
    """

    def __init__(self, *, progress, model_names):
        self.model_names = model_names
        # NOTE(review): __enter__ replaces this with a fresh gr.Progress, so
        # the `progress` argument appears unused — confirm intended.
        self.progress_bar = progress

    def total(self):
        """Number of models to embed."""
        return len(self.model_names)

    def __enter__(self):
        self.progress_bar = gr.Progress(track_tqdm=False)
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None
        return self

    def __exit__(self, typ, value, tb):
        # Finish the bar either way; never suppress the exception.
        self.progress_bar(1.0, desc="Done" if typ is None else "Error")
        return False

    def update_num_batches(self, total):
        """Start tracking the next model, which will run `total` batches."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Advance within-model progress by one finished batch."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the combined model/batch progress to the progress bar."""
        i = self.current_model
        # Bug fix: the original used `self.total` (a bound method) as a
        # number, which raised TypeError on division and printed a method
        # repr in the description. It must be called.
        n = self.total()
        description = f"Running {self.model_names[i]} ({i + 1} / {n})"
        progress = i / n
        if self.batches_total is not None:
            progress += (self.batches_complete / self.batches_total) / n
        self.progress_bar(progress=progress, desc=description)