Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import gradio as gr | |
| from datasets import disable_caching, load_dataset | |
| from transformer_ranker import TransformerRanker, prepare_popular_models | |
| import traceback | |
| from utils import ( | |
| DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER, | |
| EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded, | |
| compute_ratio, ensure_one_lm_selected, get_dataset_info | |
| ) | |
# Turn off caching in the `datasets` library for this process.
disable_caching()

# Gradio theme applied to the whole demo.
THEME = "pseudolab/huggingface-korea-theme"

# Default value and upper bound of the "Samples to use" slider.
DEFAULT_SAMPLES = 1000
MAX_SAMPLES = 5000

# Models offered in the UI: a tiny model for demonstration on CPU comes first,
# followed by the "base" and "large" popular models with duplicates removed
# (dict.fromkeys keeps the first occurrence and the original order).
LANGUAGE_MODELS = [
    "prajjwal1/bert-tiny",
    *dict.fromkeys(prepare_popular_models("base") + prepare_popular_models("large")),
]
# Put the uncased BERT variant directly after the cased one.
LANGUAGE_MODELS.insert(LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased")

# Preselect some small models
DEFAULT_MODELS = [
    "prajjwal1/bert-tiny",
    "google/electra-small-discriminator",
    "distilbert-base-cased",
    "sentence-transformers/all-MiniLM-L12-v2",
]
with gr.Blocks(css=CSS, theme=THEME) as demo:
    # NOTE(review): the text this block was recovered from had lost its
    # indentation. The container nesting below is reconstructed from the order
    # of the components -- confirm it against the rendered layout.

    ########## STEP 1: Load the Dataset ##########

    gr.Markdown(HEADLINE)
    gr.Markdown("## Step 1: Load a Dataset")

    with gr.Group():
        # Holds the loaded dataset object between events (None until loaded).
        dataset = gr.State(None)

        dataset_name = gr.Textbox(
            label="Enter the name of your dataset",
            placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
            max_lines=1,
        )

        select_dataset_button = gr.Button(
            value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

        # Activate the "Load dataset" button if dataset was found
        dataset_name.change(
            check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
        )

    gr.Markdown(
        "*The number of samples that can be used in this demo is limited to save resources. "
        "To run an estimate on the full dataset, check out the "
        "[library](https://github.com/flairNLP/transformer-ranker).*"
    )

    ########## Step 1.1 Dataset preprocessing ##########

    with gr.Accordion("Dataset settings", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_name_label = gr.Label("", label="Dataset Name")
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Number of Samples")
            # Mirror the sample count held in state into its visible label.
            num_samples.change(
                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
            )

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair Column")

        with gr.Row():
            label_column = gr.Dropdown("", label="Label Column")
            task_category = gr.Dropdown("", label="Task Type")

        with gr.Group():
            downsample_ratio = gr.State(0.0)
            num_samples_to_use = gr.Slider(
                20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
            )
            downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")

            # Show the ratio held in state as a percentage with one decimal.
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

            # Recompute the ratio whenever either of its two inputs changes.
            for _ratio_trigger in (num_samples_to_use, num_samples):
                _ratio_trigger.change(
                    compute_ratio,
                    inputs=[num_samples_to_use, num_samples],
                    outputs=downsample_ratio,
                )

    # Download the dataset and show details
    def select_dataset(dataset_name):
        """Load the named dataset and return updates for the Step 1 outputs.

        Args:
            dataset_name: Name typed into the dataset textbox.

        Returns:
            A tuple in the order of the ``outputs`` wired to
            ``select_dataset_button.click``: a button update (marked "Loaded"
            and disabled), the settings accordion opened, the dataset name,
            the loaded dataset, and then the values of ``get_dataset_info``
            spread into the remaining outputs (task_category, text_column,
            text_pair_column, label_column, num_samples).

        Raises:
            gr.Error: If loading or inspecting the dataset raises ValueError.
        """
        try:
            # SECURITY NOTE(review): trust_remote_code=True is applied to a
            # name supplied by the user -- confirm this is acceptable for the
            # deployment.
            loaded = load_dataset(dataset_name, trust_remote_code=True)
            dataset_info = get_dataset_info(loaded)
        except ValueError as err:
            # Previously only a warning was shown and execution fell through
            # to the return below with unbound locals (UnboundLocalError).
            # Raising aborts the event and leaves the outputs untouched.
            raise gr.Error(
                "Dataset collections are not supported. Please use a single dataset."
            ) from err

        return (
            gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
            gr.Accordion(open=True),
            dataset_name,
            loaded,
            *dataset_info,
        )

    select_dataset_button.click(
        select_dataset,
        inputs=[dataset_name],
        outputs=[
            select_dataset_button,
            dataset_config,
            dataset_name_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## STEP 2 ##########

    gr.Markdown("## Step 2: Select a List of Language Models")

    with gr.Group():
        # Display only the part after the last "/", keep the full handle as value.
        model_options = [
            (model_handle.split("/")[-1], model_handle)
            for model_handle in LANGUAGE_MODELS
        ]
        models = gr.CheckboxGroup(
            choices=model_options, label="Select Models", value=DEFAULT_MODELS
        )

    ########## STEP 3: Run Language Model Ranking ##########

    gr.Markdown("## Step 3: Rank LMs")

    with gr.Group():
        with gr.Accordion("Advanced settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
                layer_pooling = gr.Dropdown(
                    choices=layer_pooling_options,
                    label="Layer pooling",
                    value="layermean",
                )

        submit_button = gr.Button(
            "Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

        # Make button active if the dataset is loaded
        for _readiness_trigger in (dataset, label_column, text_column):
            _readiness_trigger.change(
                check_dataset_is_loaded,
                inputs=[dataset, text_column, label_column, task_category],
                outputs=submit_button,
            )

    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_pooling,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        """Rank the selected language models on the loaded dataset.

        Args:
            dataset: The loaded dataset held in state.
            downsample_ratio: Passed to TransformerRanker as ``dataset_downsample``.
            selected_models: Model handles ticked in the checkbox group.
            layer_pooling: Passed to the ranker as ``layer_aggregator``.
            estimator: Passed to the ranker as ``estimator``.
            text_column: Text column name; "-" means it is not set.
            text_pair_column: Optional second text column; "-" is mapped to None.
            label_column: Label column name; "-" means it is not set.
            task_category: Task type; "-" means it is not set.
            progress: Gradio progress tracker.

        Returns:
            A list of ``(rank, model, score)`` tuples, highest score first,
            with ranks starting at 1.

        Raises:
            gr.Error: If a required setting is "-" or if the ranking fails.
        """
        if text_column == "-":
            raise gr.Error("Text column is not set.")

        if label_column == "-":
            raise gr.Error("Label column is not set.")

        if task_category == "-":
            raise gr.Error(
                "Task category is not set. The dataset must support classification or regression tasks."
            )

        if text_pair_column == "-":
            text_pair_column = None

        progress(0.0, "Starting")

        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )

                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_pooling,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )

                # NOTE(review): this reads the private ``_results`` mapping of
                # the result object -- switch to a public accessor if one exists.
                sorted_results = sorted(
                    results._results.items(), key=lambda item: item[1], reverse=True
                )
                return [
                    (rank, model, score)
                    for rank, (model, score) in enumerate(sorted_results, start=1)
                ]
            except Exception as e:
                # gr.Error must be raised to reach the user; building it
                # without raising swallowed the failure and returned None.
                raise gr.Error("The dataset is not supported.") from e

    gr.Markdown("## Results")

    ranking_results = gr.Dataframe(
        headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
    )

    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_pooling,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(
        "*The results are ranked by their transferability score, with the most suitable model listed first. "
        "This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
    )

    gr.Markdown(FOOTER)
if __name__ == "__main__":
    # Enable the request queue (default concurrency limit of 3 per event),
    # then start the server with at most 6 worker threads.
    demo.queue(default_concurrency_limit=3).launch(max_threads=6)
