import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys

sys.path.append('./src')
sys.path.append('.')

from huggingface_hub import HfApi

repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
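# The star imports above provide the leaderboard constants and helpers used below
# (e.g. benchmark_type_map, color_dict, base_methods, get_baseline_df, benchmark_plot,
# download_from_hub, update_metric_choices) from the src/ package.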
# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------
def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate inputs, run the evaluation and (optionally) save the results."""
    # map the user-facing labels back to the original codes
    try:
        benchmark_types_mapped = [benchmark_type_map[b] for b in benchmark_types]
        similarity_tasks_mapped = [similarity_tasks_map[s] for s in similarity_tasks]
        function_prediction_aspect_mapped = function_prediction_aspect_map[function_prediction_aspect]
        family_prediction_dataset_mapped = [family_prediction_dataset_map[f] for f in family_prediction_dataset]
    except KeyError as e:
        gr.Warning(f"Unrecognized option: {e.args[0]}")
        return -1

    # validate inputs against the mapped benchmark codes
    if any(task in benchmark_types_mapped for task in ['similarity', 'family', 'function']) and human_file is None:
        gr.Warning("Human representations are required for the similarity, family, and function benchmarks!")
        return -1
    if 'affinity' in benchmark_types_mapped and skempi_file is None:
        gr.Warning("SKEMPI representations are required for the affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")
    representation_name = model_name_textbox

    try:
        # pass the mapped codes (rather than the raw UI labels) to the benchmark runner
        results = run_probe(
            benchmark_types_mapped,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks_mapped,
            function_prediction_aspect_mapped,
            function_prediction_dataset,
            family_prediction_dataset_mapped,
        )
    except Exception:
        gr.Warning("Your submission could not be processed. Please check your representation files!")
        return -1

    if save:
        save_results(representation_name, benchmark_types_mapped, results)
        gr.Info("Your submission has been processed and the results are saved!")
    else:
        gr.Info("Your submission has been processed!")

    return 0
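# Restarting the Space via the Hub API (see refresh_data below) requires a token with
# write access to this repo; HfApi() is assumed to pick it up from the Space's environment.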
def refresh_data():
    """Re-start the Space, pull fresh leaderboard CSVs from the HF Hub and rebuild the table."""
    api.restart_space(repo_id=repo_id)

    benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]
    for benchmark_type in benchmark_types:
        path = f"/tmp/{benchmark_type}_results.csv"
        if os.path.exists(path):
            os.remove(path)

    benchmark_types.remove("leaderboard")
    download_from_hub(benchmark_types)
    # return a freshly built table so the Refresh button can update the leaderboard component
    return build_leaderboard_styler()
# ------- Leaderboard helpers -----------------------------------------------
def update_metrics(selected_benchmarks):
    updated_metrics = set()
    for benchmark in selected_benchmarks:
        updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
    return list(updated_metrics)


def update_leaderboard(selected_methods, selected_metrics):
    return build_leaderboard_styler(selected_methods, selected_metrics)
def colour_method_html(name: str) -> str:
    """Return the method string wrapped in a coloured <span>. Handles raw names
    and markdown links like '[T5](https://…)' transparently."""
    colour = color_dict.get(re.sub(r"\[|\]|\(.*?\)", "", name), "black")  # strip the markdown link
    return f"<span style='color:{colour}; font-weight:bold;'>{name}</span>"


# darkest → lightest green
TOP5_GREENS = ["#006400", "#228B22", "#32CD32", "#7CFC00", "#ADFF2F"]


def shade_top5(col: pd.Series) -> list[str]:
    """Return a CSS list for one column: a background colour for ranks 1-5, blank otherwise."""
    if not pd.api.types.is_numeric_dtype(col):
        return [""] * len(col)
    ranks = col.rank(ascending=False, method="first")
    return [
        f"background-color:{TOP5_GREENS[int(r) - 1]};" if r <= 5 else ""
        for r in ranks
    ]
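# shade_top5 ranks each column in descending order, i.e. it assumes that a higher
# score is better for every numeric metric displayed in the table.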
def build_leaderboard_styler(selected_methods=None, selected_metrics=None):
    df = get_baseline_df(selected_methods, selected_metrics).round(4)
    df = (
        df.sort_values("Method", key=lambda s: s.str.lower())  # A -> Z
        .reset_index(drop=True)  # tidy row index
    )
    df["Method"] = df["Method"].apply(colour_method_html)

    numeric_cols = [c for c in df.columns if c != "Method"]
    styler = (
        df.style
        .apply(shade_top5, axis=0, subset=numeric_cols)
        .format(precision=4)
    )
    return styler
# ------- Visualisation helpers ---------------------------------------------
def generate_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    plot_path = benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )
    return plot_path
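# benchmark_plot comes from the star imports above; as used here, it is expected to
# render the figure to disk and return the image file path that gr.Image can display.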
# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
    position: sticky;
    left: 0;
    z-index: 2;
    /* wider 'Method' column */
    min-width: 190px;
    width: 190px;
    white-space: nowrap;
}

/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
    text-align: center;
}

/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
    max-height: 1200px;
    overflow-y: auto;
    overflow-x: auto;
}
"""
# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
block = gr.Blocks(css=CUSTOM_CSS)
with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣ Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # ── header ─────────────────────────────────────────────────────
            gr.Image(
                value="./src/data/PROBE_workflow_figure.jpg",
                show_label=False,
                height=1000,
                container=False,
            )
            gr.Markdown(
                "## For detailed explanations of the metrics and benchmarks, please refer to the 📖 About tab.",
                elem_classes="leaderboard-note",
            )
            # ── data prep ──────────────────────────────────────────────────
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard["Method"].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            metric_names.remove("Method")
            base_method_names = [m for m in method_names if m in base_methods]
            user_method_names = [m for m in method_names if m not in base_methods]

            benchmark_metric_mapping = {
                "Semantic Similarity Inference": [m for m in metric_names if m.startswith("sim_")],
                "Ontology-based Protein Function Prediction": [m for m in metric_names if m.startswith("func")],
                "Drug Target Protein Family Classification": [m for m in metric_names if m.startswith("fam_")],
                "Protein-Protein Binding Affinity Estimation": [m for m in metric_names if m.startswith("aff_")],
            }
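            # The mapping above groups leaderboard columns by their name prefixes
            # (sim_, func, fam_, aff_), so the metric columns in the result CSVs
            # must follow that naming convention.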
            # ── callback helper ────────────────────────────────────────────
            def update_leaderboard_combined(selected_base, selected_user, selected_metrics):
                selected_methods = (selected_base or []) + (selected_user or [])
                return build_leaderboard_styler(selected_methods, selected_metrics)
            # ── collapsible selectors ──────────────────────────────────────
            with gr.Accordion("📦 Base Methods", open=False):
                leaderboard_method_selector_base = gr.CheckboxGroup(
                    choices=base_method_names,
                    label="Base Methods",
                    value=base_method_names,  # ← all selected
                    interactive=True,
                )

            with gr.Accordion("🛠️ User-defined Methods", open=False):
                leaderboard_method_selector_user = gr.CheckboxGroup(
                    choices=user_method_names,
                    label="User Methods",
                    value=[],  # ← none selected
                    interactive=True,
                )

            with gr.Accordion("🧪 Benchmark Types", open=False):
                benchmark_type_selector_lb = gr.CheckboxGroup(
                    choices=list(benchmark_metric_mapping.keys()),
                    label="Benchmark Types",
                    value=list(benchmark_metric_mapping.keys()),  # all selected
                    interactive=True,
                )

            with gr.Accordion("📊 Metrics", open=False):
                leaderboard_metric_selector = gr.CheckboxGroup(
                    choices=metric_names,
                    label="Select Metrics",
                    value=metric_names,  # ← all selected
                    interactive=True,
                )
            # ── colour / shading legend ────────────────────────────────────
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(
                        """
## Method-name colours
<span style='color:green; font-weight:bold; font-size:1.1rem;'>🟢 Classical representations</span><br>
<span style='color:blue; font-weight:bold; font-size:1.1rem;'>🔵 Small-scale Protein LMs</span><br>
<span style='color:red; font-weight:bold; font-size:1.1rem;'>🔴 Large-scale Protein LMs</span><br>
<span style='color:orange; font-weight:bold; font-size:1.1rem;'>🟠 Multimodal Protein LMs</span>
""",
                        elem_classes="leaderboard-note",
                    )
                with gr.Column(scale=1):
                    gr.Markdown(
                        """
## Metric-cell shading
<span style='background-color:#006400; color:white; padding:0.4rem 0.8rem; border-radius:0.4rem;
      font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>1</span>
<span style='background-color:#228B22; color:white; padding:0.4rem 0.8rem; border-radius:0.4rem;
      font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>2</span>
<span style='background-color:#32CD32; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
      font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>3</span>
<span style='background-color:#7CFC00; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
      font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>4</span>
<span style='background-color:#ADFF2F; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
      font-size:1.1rem; display:inline-block; text-align:center;'>5</span>
<br>
<span style='font-size:1.1rem;'>top-five scores (darker → better)</span>
""",
                        elem_classes="leaderboard-note",
                    )
            # ── dataframe ──────────────────────────────────────────────────
            styler = build_leaderboard_styler(base_method_names, metric_names)
            data_component = gr.Dataframe(
                value=styler,
                headers=["Method"] + metric_names,
                type="pandas",
                datatype=["markdown"] + ["number"] * len(metric_names),
                interactive=False,
                elem_id="leaderboard-table",
                pinned_columns=1,
                max_height=1000,
                show_fullscreen_button=True,
            )
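            # Note: pinned_columns, max_height and show_fullscreen_button are relatively
            # recent gr.Dataframe parameters; this layout assumes an up-to-date Gradio release.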
| gr.Markdown("#### If a method name ends with **^**, it suggests potential suspicions of data leakage related to ***similarity***, ***function***, or ***family*** benchmarks.") | |
            # ── callbacks ──────────────────────────────────────────────────
            leaderboard_method_selector_base.change(
                update_leaderboard_combined,
                inputs=[leaderboard_method_selector_base, leaderboard_method_selector_user, leaderboard_metric_selector],
                outputs=data_component,
            )
            leaderboard_method_selector_user.change(
                update_leaderboard_combined,
                inputs=[leaderboard_method_selector_base, leaderboard_method_selector_user, leaderboard_metric_selector],
                outputs=data_component,
            )
            leaderboard_metric_selector.change(
                update_leaderboard_combined,
                inputs=[leaderboard_method_selector_base, leaderboard_method_selector_user, leaderboard_metric_selector],
                outputs=data_component,
            )
            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
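            # Changing the benchmark-type filter only rewrites the metric checkbox list;
            # the metric selector's own .change event then refreshes the table.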
        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            gr.Markdown(
                """## **Interactive Visualizations**
Choose a benchmark type; context-specific options will appear.""",
                elem_classes="markdown-text",
            )

            # ── benchmark-type selector ────────────────────────────────────
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="🧪 Benchmark Type",
                value=None,
            )

            # ── metric / dataset selectors (appear contextually) ───────────
            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="X-axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Y-axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)
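            # update_metric_choices (one of the star imports above) repopulates these
            # dropdowns and toggles their visibility based on the selected benchmark type.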
            # ── method selectors (two accordions) ──────────────────────────
            base_method_names = [m for m in method_names if m in base_methods]
            user_method_names = [m for m in method_names if m not in base_methods]

            with gr.Accordion("📦 Base methods", open=False):
                vis_method_selector_base = gr.CheckboxGroup(
                    choices=base_method_names,
                    label="Base Methods",
                    value=base_method_names,  # default: all selected
                    interactive=True,
                )

            with gr.Accordion("🛠️ User-defined methods", open=False):
                vis_method_selector_user = gr.CheckboxGroup(
                    choices=user_method_names,
                    label="User Methods",
                    value=[],  # default: none selected
                    interactive=True,
                )
            # ── plot button & output ───────────────────────────────────────
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            gr.Markdown(
                "#### If a method name ends with **^**, the method is suspected of data leakage "
                "on the ***similarity***, ***function***, or ***family*** benchmarks."
            )
            # ── callbacks ──────────────────────────────────────────────────
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )
            # combine the two method lists, then call the original helper
            plot_button.click(
                lambda bt, base_sel, user_sel, xm, ym, asp, ds, sm: generate_plot(
                    benchmark_type_map.get(bt, bt),
                    (base_sel or []) + (user_sel or []),  # merged method list
                    xm, ym, asp, ds, sm,
                ),
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector_base,
                    vis_method_selector_user,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output],
            )
        # ------------------------------------------------------------------
        # 3️⃣ About tab
        # ------------------------------------------------------------------
        with gr.TabItem("📖 About", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )
        # ------------------------------------------------------------------
        # 4️⃣ Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here!", elem_id="probe-benchmark-tab-table", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")

            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
                    similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Datasets", interactive=True)
                    function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
                    family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
                    function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
                    save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)
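                    # function_dataset is hidden and fixed to "All_Data_Sets"; it is passed
                    # to add_new_eval in place of a user-selectable dataset list.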
            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')

            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )
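            # Feedback is surfaced through the gr.Info / gr.Warning pop-ups raised inside add_new_eval.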
    # global refresh + citation ---------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )
# ---------------------------------------------------------------------------
block.launch()