import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns
import plotnine as p9
import sys
import numpy as np

script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append('..')
sys.path.append('.')

from about import *
from saving_utils import download_from_hub

global data_component, filter_component
def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    # Dispatch to the plot function for the selected benchmark type
    if benchmark_type == 'similarity':
        return plot_similarity_results(methods_selected, x_metric, y_metric)
    elif benchmark_type == 'function':
        return plot_function_results(methods_selected, aspect, single_metric)
    elif benchmark_type == 'family':
        return plot_family_results(methods_selected, dataset)
    elif benchmark_type == 'affinity':
        return plot_affinity_results(methods_selected, single_metric)
    else:
        return -1  # Unknown benchmark type
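# A hypothetical call for illustration (the method and metric names below are
# assumptions, not values guaranteed to exist in the benchmark data):
#
#   benchmark_plot('similarity', ['method_a', 'method_b'],
#                  'sparse_MF_correlation', 'sparse_BP_correlation',
#                  aspect=None, dataset=None, single_metric=None)
#
# dispatches to plot_similarity_results and returns the path of the saved PNG.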
def get_method_color(method):
    return color_dict.get(method, 'black')  # Fall back to black if the method is not in color_dict
def get_labels_and_title(x_metric, y_metric):
    # Map GO aspect abbreviations to their long forms
    long_form_mapping = {
        "MF": "Molecular Function",
        "BP": "Biological Process",
        "CC": "Cellular Component"
    }

    # Metric names follow the pattern <dataset>_<aspect>_<measure>
    def parse_metric(metric):
        parts = metric.split("_")
        dataset = parts[0]   # sparse/200/500
        category = parts[1]  # MF/BP/CC
        measure = parts[2]   # pvalue/correlation
        return dataset, category, measure

    x_dataset, x_category, x_measure = parse_metric(x_metric)
    y_dataset, y_category, y_measure = parse_metric(y_metric)

    # Determine the title
    if x_category == y_category:
        title = long_form_mapping[x_category]
    else:
        title = f"{long_form_mapping[x_category]} (x) vs {long_form_mapping[y_category]} (y)"

    # Determine the axis labels
    x_label = f"{x_measure.capitalize()} on {x_dataset.capitalize()} Dataset"
    y_label = f"{y_measure.capitalize()} on {y_dataset.capitalize()} Dataset"
    return title, x_label, y_label
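# A worked example, assuming the <dataset>_<aspect>_<measure> naming above:
#
#   get_labels_and_title("sparse_MF_correlation", "200_MF_pvalue")
#
# returns ("Molecular Function",
#          "Correlation on Sparse Dataset",
#          "Pvalue on 200 Dataset")
# since both metrics share the MF aspect, the title is the aspect's long form.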
def plot_similarity_results(methods_selected, x_metric, y_metric, similarity_path="/tmp/similarity_results.csv"):
    if not os.path.exists(similarity_path):
        # Download all benchmark files at once so later plots are faster
        benchmark_types = ["similarity", "function", "family", "affinity"]
        download_from_hub(benchmark_types)

    similarity_df = pd.read_csv(similarity_path)

    # Filter the dataframe based on selected methods (copy to avoid SettingWithCopyWarning)
    filtered_df = similarity_df[similarity_df['Method'].isin(methods_selected)].copy()

    # Replace None or NaN values with 0 in relevant columns
    filtered_df = filtered_df.fillna(0)

    # Add a color column so each method keeps a consistent color
    filtered_df['color'] = filtered_df['Method'].apply(get_method_color)

    title, x_label, y_label = get_labels_and_title(x_metric, y_metric)

    # adjustText-style placement parameters (defined here but not passed to the plot)
    adjust_text_dict = {
        'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
        'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
        'force_text': (.0, 1.), 'force_objects': (.0, 1.),
        'lim': 500000, 'precision': 1., 'avoid_points': True, 'avoid_text': True
    }

    # Create the scatter plot using plotnine (ggplot)
    g = (p9.ggplot(data=filtered_df,
                   mapping=p9.aes(x=x_metric,       # Use the selected x metric
                                  y=y_metric,       # Use the selected y metric
                                  color='color',    # Use the per-method color column
                                  label='Method'))  # Label each point with the method name
         + p9.geom_point(size=3)               # Add points with no jitter, set point size
         + p9.geom_text(nudge_y=0.02, size=8)  # Nudge labels slightly above the points
         + p9.labs(title=title, x=x_label, y=y_label)  # Dynamic labels for the x and y axes
         + p9.scale_color_identity()           # Use colors directly from the dataframe
         + p9.theme(legend_position='none',
                    figure_size=(8, 8),
                    axis_text=p9.element_text(size=10),
                    axis_title_x=p9.element_text(size=12),
                    axis_title_y=p9.element_text(size=12))
         )

    # Save the plot as an image
    save_path = "/tmp"
    filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
    g.save(filename=filename, dpi=400)
    return filename
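# Inferred input schema (read off the filtering above, not a documented contract):
# similarity_results.csv holds one row per method, with a 'Method' column plus one
# numeric column per metric named '<dataset>_<aspect>_<measure>',
# e.g. 'sparse_MF_correlation'.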
def plot_function_results(method_names, aspect, metric, function_path="/tmp/function_results.csv"):
    if not os.path.exists(function_path):
        # Download all benchmark files at once so later plots are faster
        benchmark_types = ["similarity", "function", "family", "affinity"]
        download_from_hub(benchmark_types)

    # Load data and filter for the selected methods
    df = pd.read_csv(function_path)
    df = df[df['Method'].isin(method_names)]

    # Keep only the columns for the specified aspect and metric
    columns_to_plot = [col for col in df.columns if col.startswith(f"{aspect}_") and col.endswith(f"_{metric}")]
    df = df[['Method'] + columns_to_plot].set_index('Method')

    # Fill missing values with 0 and transpose so methods become columns
    df = df.fillna(0)
    df = df.T

    # Per-method colors (currently unused; note that after the transpose df.index
    # holds the metric column names, not the method names)
    row_color_dict = {method: get_method_color(method) for method in df.index}

    long_form_mapping = {
        "MF": "Molecular Function",
        "BP": "Biological Process",
        "CC": "Cellular Component"
    }

    # Draw an annotated heatmap (clustering disabled on both axes)
    g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
    title = f"{long_form_mapping[aspect.upper()]} Results for {metric.capitalize()}"
    g.fig.suptitle(title, x=0.5, y=1.02, fontsize=16, ha='center')  # Center the title above the plot

    # Clear the default axis labels on the heatmap
    ax = g.ax_heatmap
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Save the plot as an image
    save_path = "/tmp"
    filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
    plt.savefig(filename, dpi=400, bbox_inches='tight')
    plt.close()  # Close the figure to free memory
    return filename
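# Inferred input schema: function_results.csv holds a 'Method' column plus numeric
# columns prefixed with the aspect and suffixed with the metric ('<aspect>_..._<metric>');
# any middle segments are not constrained by this code.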
def plot_family_results(method_names, dataset, family_path="/tmp/family_results.csv"):
    if not os.path.exists(family_path):
        # Download all benchmark files at once so later plots are faster
        benchmark_types = ["similarity", "function", "family", "affinity"]
        download_from_hub(benchmark_types)

    df = pd.read_csv(family_path)

    # Filter by method names and keep the columns for the selected dataset
    df = df[df['Method'].isin(method_names)]
    value_vars = [col for col in df.columns if col.startswith(f"{dataset}_")]

    # Reshape the DataFrame to long format
    df_long = pd.melt(df, id_vars=["Method"], value_vars=value_vars,
                      var_name="Dataset_Metric_Fold", value_name="Value")

    # Convert the "Value" column to numeric and drop rows that fail to parse
    df_long["Value"] = pd.to_numeric(df_long["Value"], errors="coerce")
    df_long = df_long.dropna(subset=["Value"])

    # Split the "Dataset_Metric_Fold" column into "Metric" and "Fold"
    df_long[["Metric", "Fold"]] = df_long["Dataset_Metric_Fold"].str[len(dataset) + 1:].str.split("_", expand=True)
    df_long["Fold"] = df_long["Fold"].astype(int)

    # Set up the theme and figure size in one call
    sns.set_theme(style="whitegrid", color_codes=True, rc={"figure.figsize": (13.7, 18.27)})

    # Create a horizontal boxplot: one row per method, grouped by metric
    ax = sns.boxplot(data=df_long, x="Value", y="Method", hue="Metric", whis=np.inf, orient="h")

    # Customize grid and ticks
    ax.xaxis.set_major_locator(ticker.MultipleLocator(0.2))
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.grid(visible=True, which="major", color="gainsboro", linewidth=1.0)
    ax.grid(visible=True, which="minor", color="whitesmoke", linewidth=0.5)
    ax.set_xlim(0, 1)

    # Add dashed separator lines between methods
    for ytick in ax.get_yticks():
        ax.hlines(ytick + 0.5, -0.1, 1, linestyles="dashed", color="gray")

    # Color each y-axis label with its method's color
    for label in ax.get_yticklabels():
        label.set_color(get_method_color(label.get_text()))

    # Save the plot
    save_path = "/tmp"
    filename = os.path.join(save_path, f"{dataset}_family_results.png")
    ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
    plt.close()  # Close the figure to free memory
    return filename
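# Inferred input schema: family_results.csv holds a 'Method' column plus columns named
# '<dataset>_<metric>_<fold>'; the split above recovers Metric and Fold from each name,
# and Fold must parse as an integer.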
def plot_affinity_results(method_names, metric, affinity_path="/tmp/affinity_results.csv"):
    if not os.path.exists(affinity_path):
        # Download all benchmark files at once so later plots are faster
        benchmark_types = ["similarity", "function", "family", "affinity"]
        download_from_hub(benchmark_types)

    df = pd.read_csv(affinity_path)

    # Filter for selected methods
    df = df[df['Method'].isin(method_names)]

    # Gather the columns belonging to the specified metric
    metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
    df = df[['Method'] + metric_columns].set_index('Method')
    df = df.fillna(0)
    df = df.T  # Transpose so each method becomes one box

    # Set up the theme and figure size in one call
    sns.set_theme(style="whitegrid", color_codes=True, rc={'figure.figsize': (11.7, 8.27)})

    # Create the boxplot with a swarmplot of individual values on top
    ax = sns.boxplot(data=df, whis=np.inf, orient="h")
    sns.swarmplot(data=df, orient="h", color=".1", ax=ax)

    # Set labels and x-axis formatting
    ax.set_xlabel("Percent Pearson Correlation")
    ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.grid(visible=True, which='major', color='gainsboro', linewidth=1.0)
    ax.grid(visible=True, which='minor', color='whitesmoke', linewidth=0.5)

    # Color each y-axis label with its method's color
    for label in ax.get_yticklabels():
        label.set_color(get_method_color(label.get_text()))

    # Add a legend (has no entries unless labeled artists are present)
    ax.legend(loc='best', frameon=True)

    # Save the plot
    save_path = "/tmp"
    filename = os.path.join(save_path, f"{metric}_affinity_results.png")
    ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
    plt.close()  # Close the figure to free memory
    return filename
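# Inferred input schema: affinity_results.csv holds a 'Method' column plus numeric
# columns prefixed with the metric name ('<metric>_...'); the suffixes are not
# constrained by this code. After the transpose, seaborn draws one horizontal box
# per method from those values.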
def update_metric_choices(benchmark_type):
    # Returns visibility/choice updates for, in order:
    # (x_metric, y_metric, aspect, dataset, single_metric) selectors
    if benchmark_type == 'similarity':
        # Show x and y metric selectors for similarity
        metric_names = benchmark_specific_metrics.get(benchmark_type, [])
        return (
            gr.update(choices=metric_names, value=metric_names[0], visible=True),
            gr.update(choices=metric_names, value=metric_names[1], visible=True),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
        )
    elif benchmark_type == 'function':
        # Show aspect and single-metric selectors for function
        aspect_types = benchmark_specific_metrics[benchmark_type]['aspect_types']
        metric_types = benchmark_specific_metrics[benchmark_type]['dataset_types']
        return (
            gr.update(visible=False), gr.update(visible=False),
            gr.update(choices=aspect_types, value=aspect_types[0], visible=True),
            gr.update(visible=False),
            gr.update(choices=metric_types, value=metric_types[0], visible=True)
        )
    elif benchmark_type == 'family':
        # Show the dataset selector for family
        datasets = benchmark_specific_metrics[benchmark_type]['datasets']
        metrics = benchmark_specific_metrics[benchmark_type]['metrics']  # currently unused
        return (
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(choices=datasets, value=datasets[0], visible=True),
            gr.update(visible=False)
        )
    elif benchmark_type == 'affinity':
        # Show the single-metric selector for affinity
        metrics = benchmark_specific_metrics[benchmark_type]
        return (
            gr.update(visible=False), gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False), gr.update(choices=metrics, value=metrics[0], visible=True)
        )
    # Unknown benchmark type: hide all selectors
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
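# A minimal wiring sketch (the component names here are assumptions about the app
# layout, shown only to illustrate how the five returned updates map onto selectors):
#
#   with gr.Blocks() as demo:
#       benchmark_type_dd = gr.Dropdown(choices=["similarity", "function", "family", "affinity"])
#       ...
#       benchmark_type_dd.change(
#           update_metric_choices,
#           inputs=benchmark_type_dd,
#           outputs=[x_metric_dd, y_metric_dd, aspect_dd, dataset_dd, single_metric_dd],
#       )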