"""Render the final benchmark bar charts comparing embedding models.

Run as a script, this module reads per-task result CSVs from
``results/final`` and writes one horizontal bar chart, plus a CSV of the
plotted values, per task into ``results/final/figures``.
"""

import os

import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from fuson_plm.utils.visualizing import set_font

# Single-row metric tables for the 'fo_puncta_ml_literature' model.
# NOTE(review): the variable names suggest these are the training metrics at a
# 0.31 threshold and the verification metrics at a 0.83 threshold -- confirm.
fo_puncta_db_training_thresh31 = pd.DataFrame(data={
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,
    'Accuracy': 0.81,
    'Precision': 0.78,
    'Recall': 0.98,
    'F1 Score': 0.87,
    'AUROC': 0.88,
    'AUPRC': 0.94
})

fo_puncta_db_verification_thresh83 = pd.DataFrame(data={
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,
    'Accuracy': 0.79,
    'Precision': 0.81,
    'Recall': 0.89,
    'F1 Score': 0.85,
    'AUROC': 0.73,
    'AUPRC': 0.82
})

# Raw 'Model Name' values that are plotted, mapped to their display names.
_DISPLAY_NAMES = {
    'fo_puncta_ml': 'FOdb',
    'esm2_t33_650M_UR50D': 'ESM-2-650M',
    'best': 'FusOn-pLM',
    'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50',
}

# (column in the results CSV, metric label used on the chart), in output order.
_METRIC_COLUMNS = (
    ('Accuracy', 'Accuracy'),
    ('Precision', 'Precision'),
    ('Recall', 'Recall'),
    ('F1 Score', 'F1'),
    ('AUROC', 'AUROC'),
)


# Method for lengthening the model name
def lengthen_model_name(row):
    """Return the model name with its epoch appended where applicable.

    Args:
        row: Mapping-like record with 'Model Name' and 'Model Epoch' entries.

    Returns:
        The name unchanged when it contains 'esm' or 'puncta'; otherwise
        ``f'{name}_e{epoch}'``.
    """
    name = row['Model Name']
    epoch = row['Model Epoch']
    if 'esm' in name:
        return name
    if 'puncta' in name:
        return name
    return f'{name}_e{epoch}'


# Method for shortening the model name for display
def shorten_model_name(row):
    """Return a compact display label for a model.

    Args:
        row: Mapping-like record with 'Model Name' and 'Model Epoch' entries.

    Returns:
        A fixed label for the ESM, FO-Puncta-ML and ProtT5 names; otherwise
        ``f'{prob_type}_{layers}L_{dt}_e{epoch}'``, with the parts parsed out
        of the name.

    Raises:
        ValueError: If the name is not one of the fixed names and cannot be
            parsed (no 'snp_'/'uniform_' marker, or no '-' after 'mask').
    """
    name = row['Model Name']
    epoch = row['Model Epoch']
    if 'esm' in name:
        return 'ESM-2-650M'
    if name == 'fo_puncta_ml':
        return 'FO-Puncta-ML'
    if name == 'fo_puncta_ml_literature':
        return 'FO-Puncta-ML Lit'
    if name == "prot_t5_xl_half_uniref50_enc":
        return 'ProtT5-XL-U50'  # this is what they call it in the paper

    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Previously fell through and failed with an UnboundLocalError.
        raise ValueError(
            f"Cannot shorten model name {name!r}: expected it to contain "
            "'snp_' or 'uniform_'"
        )

    # Last '_'-separated token of the text preceding 'layers'.
    layers = name.split('layers')[0].split('_')[-1]
    try:
        # Everything after the first '-' that follows 'mask'.
        dt = name.split('mask')[1].split('-', 1)[1]
    except IndexError as err:
        raise ValueError(
            f"Cannot shorten model name {name!r}: expected 'mask' followed "
            "by a '-'-separated suffix"
        ) from err
    return f'{prob_type}_{layers}L_{dt}_e{epoch}'


def _style_legend(legend):
    """Set the entry font size and handle sizes of *legend* in place."""
    for text in legend.get_texts():
        text.set_fontsize(22)

    # Matplotlib 3.7 renamed ``legendHandles`` to ``legend_handles`` and 3.9
    # removed the old attribute; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles

    for handle in handles:
        if isinstance(handle, mpatches.Patch):
            handle.set_height(15)  # Adjust height
            handle.set_width(20)  # Adjust width
        elif hasattr(handle, '_sizes'):
            handle._sizes = [200]  # Increase marker size in the legend


def make_final_bar(dataframe, title, save_path):
    """Draw a grouped horizontal bar chart of metrics per model and save it.

    Args:
        dataframe: Long-format DataFrame with 'Name', 'Metric' and 'Value'
            columns. Only the names 'FOdb', 'ProtT5-XL-U50', 'ESM-2-650M'
            and 'FusOn-pLM' are plotted.
        title: Title placed above the axes.
        save_path: Destination path passed to ``savefig``.
    """
    set_font()
    df = dataframe.copy(deep=True)

    # Pivot the DataFrame to have metrics as rows and names as columns, and reorder columns
    pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
    ordered_columns = [
        x for x in ['FOdb', 'ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM']
        if x in pivot_df.columns
    ]
    pivot_df = pivot_df[ordered_columns]

    # Define the groups; any plotted name not listed here goes in the learned group
    engineered_embeddings = ['FOdb']

    # Reorder the metrics (reversed so the first metric is drawn at the top)
    metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
    pivot_df = pivot_df.reindex(metric_order)

    # Plotting
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)  # Increased figure size for better legend placement

    # Define bar width and positions
    bar_width = 0.2
    indices = np.arange(len(pivot_df))

    # Fixed bar color for each model name
    color_map = {
        # 'One-Hot': "#999999",
        'FOdb': "#E69F00",
        'ESM-2-650M': "#F0E442",
        'FusOn-pLM': "#FF69B4",
        'ProtT5-XL-U50': "#00ccff"  # light blue
    }

    # Plot bars for each model and collect (handle, label) pairs per legend
    # group, so labels can never drift out of step with their handles.
    engineered_entries = []
    learned_entries = []
    for i, name in enumerate(ordered_columns):
        bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width,
                       label=name, color=color_map[name])
        if name in engineered_embeddings:
            engineered_entries.append((bars[0], name))
        else:
            learned_entries.append((bars[0], name))

    # Add bold black asterisks next to the winning bars for each category (could be multiple)
    # for j, metric in enumerate(pivot_df.index):
    #     max_value = pivot_df.loc[metric].max()
    #     max_indices = pivot_df.loc[metric][pivot_df.loc[metric] == max_value].index
    #     for max_name in max_indices:
    #         max_index = list(pivot_df.columns).index(max_name)
    #         ax.text(max_value + 0.01, j + max_index * bar_width - bar_width / 4, '*',
    #                 color='black', fontsize=12, fontweight='bold', ha='center', va='center')

    # Set labels, ticks, and title
    ax.set_xlabel('Value', fontsize=44)
    # Center each tick on its group of bars; for four bars this equals the
    # previously hard-coded offset of 1.5 bar widths.
    ax.set_yticks(indices + bar_width * (len(ordered_columns) - 1) / 2)
    ax.set_xlim([0, 1])
    ax.set_yticklabels(pivot_df.index)
    ax.set_title(title, fontsize=44)

    # Setting font size for tick labels
    for label in ax.get_xticklabels():
        label.set_fontsize(32)
    for label in ax.get_yticklabels():
        label.set_fontsize(32)

    # Create two separate legends, anchored at the right edge of the figure
    legend_specs = (
        (engineered_entries, "Engineered Embeddings", 0.4),
        (learned_entries, "Learned Embeddings", 0.6),
    )
    for entries, legend_title, y_anchor in legend_specs:
        if not entries:
            continue
        # Reversed so the legend order matches the top-to-bottom bar order.
        handles = [handle for handle, _ in reversed(entries)]
        labels = [label for _, label in reversed(entries)]
        legend = fig.legend(
            handles,
            labels,
            loc='center left',
            bbox_to_anchor=(1, y_anchor),
            title=legend_title,
            title_fontsize=24)
        ax.add_artist(legend)
        _style_legend(legend)

    fig.tight_layout()  # Adjust layout to make room for the legends

    # Save the plot to a file
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def prepare_data_for_bar(results_dir, task, split, thresh=None):
    """Load one results CSV and reshape it to long format for plotting.

    Args:
        results_dir: Directory containing the results CSV.
        task: Task prefix of the CSV file name.
        split: Split name used in the CSV file name.
        thresh: Optional threshold that is part of the CSV file name.

    Returns:
        A ``(data, image_save_path)`` tuple. ``data`` has 'Name', 'Metric'
        and 'Value' columns, one row per (metric, model) pair, grouped by
        metric. ``image_save_path`` is the PNG path under
        ``{results_dir}/figures``.
    """
    stem = f"{task}_{split}FOs"
    if thresh is not None:
        stem = f"{stem}_{thresh}thresh"
    fname = f"{stem}_results.csv"
    image_save_path = f"{results_dir}/figures/{stem}_barchart.png"

    data = pd.read_csv(f"{results_dir}/{fname}")
    data = data.loc[data['Model Name'].isin(list(_DISPLAY_NAMES))]

    # Derive the repeat count from the rows actually present rather than
    # assuming all four models are in the file.
    model_names = data['Model Name'].tolist()
    n_models = len(model_names)

    metric_labels = []
    values = []
    for column, label in _METRIC_COLUMNS:
        metric_labels.extend([label] * n_models)
        values.extend(data[column].tolist())

    data = pd.DataFrame(data={
        'Name': model_names * len(_METRIC_COLUMNS),
        'Metric': metric_labels,
        'Value': values,
    })
    data['Name'] = data['Name'].map(_DISPLAY_NAMES)

    return data, image_save_path


def make_all_final_bar_charts(results_dir):
    """Write the bar chart and its source-data CSV for every benchmark task.

    Args:
        results_dir: Directory containing the results CSVs and a ``figures``
            subdirectory for the outputs.
    """
    # (task, threshold in the file name, chart title)
    charts = (
        ("formation", 0.83, "Puncta Propensity"),       # Puncta verification
        ("nucleus", None, "Nucleus Localization"),      # Nucleus verification
        ("cytoplasm", None, "Cytoplasm Localization"),  # Cytoplasm verification
    )
    for task, thresh, title in charts:
        data, image_save_path = prepare_data_for_bar(
            results_dir, task, "verification", thresh=thresh)

        # Store the plotted values, rounded to 3 decimals, next to the figure.
        data_cp = data.copy(deep=True)
        data_cp["Value"] = data_cp["Value"].round(3)
        data_cp.to_csv(image_save_path.replace(".png", "_source_data.csv"), index=False)

        make_final_bar(data, title, image_save_path)


def main():
    """Create the figures directory and generate all final bar charts."""
    # Read in the input data
    results_dir = "results/final"
    os.makedirs(f"{results_dir}/figures", exist_ok=True)
    make_all_final_bar_charts(results_dir)


if __name__ == '__main__':
    main()