|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as mpatches |
|
import seaborn as sns |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import matplotlib.colors as mcolors |
|
from fuson_plm.utils.visualizing import set_font |
|
|
|
# Single-row metric tables for the model labelled 'fo_puncta_ml_literature'.
# The "_thresh31" / "_thresh83" suffixes presumably denote decision thresholds
# of 0.31 and 0.83 -- TODO confirm. Neither frame is referenced by the code
# visible in this file.
fo_puncta_db_training_thresh31 = pd.DataFrame([
    {
        'Model Type': 'fo_puncta_ml',
        'Model Name': 'fo_puncta_ml_literature',
        'Model Epoch': np.nan,  # no training epoch applies to this entry
        'Accuracy': 0.81,
        'Precision': 0.78,
        'Recall': 0.98,
        'F1 Score': 0.87,
        'AUROC': 0.88,
        'AUPRC': 0.94,
    }
])

fo_puncta_db_verification_thresh83 = pd.DataFrame([
    {
        'Model Type': 'fo_puncta_ml',
        'Model Name': 'fo_puncta_ml_literature',
        'Model Epoch': np.nan,  # no training epoch applies to this entry
        'Accuracy': 0.79,
        'Precision': 0.81,
        'Recall': 0.89,
        'F1 Score': 0.85,
        'AUROC': 0.73,
        'AUPRC': 0.82,
    }
])
|
|
|
|
|
def lengthen_model_name(row):
    """Return the model name, suffixed with its epoch when appropriate.

    Names containing 'esm' or 'puncta' are returned unchanged; any other
    name gets '_e<epoch>' appended, using the row's 'Model Epoch' value.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Name' and
            'Model Epoch' entries.
    """
    model_name, model_epoch = row['Model Name'], row['Model Epoch']
    keeps_plain_name = any(tag in model_name for tag in ('esm', 'puncta'))
    return model_name if keeps_plain_name else f'{model_name}_e{model_epoch}'
|
|
|
|
|
def shorten_model_name(row):
    """Return a compact display label for the model described by ``row``.

    Names containing 'esm' map to 'ESM-2-650M'; a few other known model
    names map to fixed labels. Any other name is parsed as a checkpoint
    name that must contain 'snp_' or 'uniform_', a '<N>layers' token and a
    'mask...-<suffix>' segment; it becomes '<snp|uni>_<N>L_<suffix>_e<epoch>'.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Name' and
            'Model Epoch' entries.

    Raises:
        ValueError: if a checkpoint name lacks the parts described above.
    """
    name = row['Model Name']
    epoch = row['Model Epoch']

    if 'esm' in name:
        return 'ESM-2-650M'

    fixed_labels = {
        'fo_puncta_ml': 'FO-Puncta-ML',
        'fo_puncta_ml_literature': 'FO-Puncta-ML Lit',
        'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50',
    }
    if name in fixed_labels:
        return fixed_labels[name]

    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Without this branch prob_type would be unbound below and the caller
        # would get an opaque UnboundLocalError.
        raise ValueError(
            f"Cannot shorten model name {name!r}: expected 'snp_' or 'uniform_' in it"
        )

    # Token just before 'layers', e.g. '..._6layers...' -> '6'.
    layers = name.split('layers')[0].split('_')[-1]

    # Everything after the first '-' following 'mask'.
    mask_parts = name.split('mask')
    if len(mask_parts) < 2 or '-' not in mask_parts[1]:
        raise ValueError(
            f"Cannot shorten model name {name!r}: expected a 'mask...-<suffix>' segment"
        )
    dt = mask_parts[1].split('-', 1)[1]

    return f'{prob_type}_{layers}L_{dt}_e{epoch}'
|
|
|
def _style_legend(legend):
    """Enlarge a legend's label text and its patch/marker handles in place."""
    for text in legend.get_texts():
        text.set_fontsize(22)
    # matplotlib 3.7 renamed ``legendHandles`` to ``legend_handles`` and 3.9
    # removed the old name; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        if isinstance(handle, mpatches.Patch):
            handle.set_height(15)
            handle.set_width(20)
        elif hasattr(handle, '_sizes'):
            handle._sizes = [200]


def make_final_bar(dataframe, title, save_path):
    """Draw a grouped horizontal bar chart of metric values per model and save it.

    Args:
        dataframe: long-format table with 'Metric', 'Name' and 'Value'
            columns (as built by prepare_data_for_bar). Each (Metric, Name)
            pair must be unique because the table is pivoted. Only models in
            the known model order below are plotted.
        title: axes title.
        save_path: destination image path, written at 300 dpi.
    """
    set_font()
    df = dataframe.copy(deep=True)

    # One row per metric, one column per model, in a fixed model order.
    pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
    model_order = ['FOdb', 'ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM']
    ordered_columns = [x for x in model_order if x in pivot_df.columns]
    pivot_df = pivot_df[ordered_columns]

    engineered_embeddings = ['FOdb']
    deep_learning_embeddings = ['ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM']

    # Reversed because barh draws row 0 at the bottom; Accuracy ends up on top.
    metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
    pivot_df = pivot_df.reindex(metric_order)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    bar_width = 0.2
    indices = np.arange(len(pivot_df))

    color_map = {
        'FOdb': "#E69F00",
        'ESM-2-650M': "#F0E442",
        'FusOn-pLM': "#FF69B4",
        'ProtT5-XL-U50': "#00ccff",
    }
    colors = [color_map[col] for col in ordered_columns]

    # One barh call per model, offset within each metric group; keep the first
    # bar of each call as that model's legend handle.
    engineered_handles = []
    deep_learning_handles = []
    for i, (name, color) in enumerate(zip(pivot_df.columns, colors)):
        bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width,
                       label=name, color=color)
        if name in engineered_embeddings:
            engineered_handles.append(bars[0])
        else:
            deep_learning_handles.append(bars[0])

    ax.set_xlabel('Value', fontsize=44)
    # Centre each tick on its group of bars. For four models this is the
    # original 1.5 * bar_width; it stays centred when some models are absent.
    group_center = bar_width * (len(ordered_columns) - 1) / 2
    ax.set_yticks(indices + group_center)
    ax.set_xlim([0, 1])
    ax.set_yticklabels(pivot_df.index)
    ax.set_title(title, fontsize=44)
    ax.tick_params(axis='both', labelsize=32)

    # Two figure-level legends placed to the right of the axes.
    legend1 = legend2 = None
    if engineered_handles:
        legend1 = fig.legend(
            engineered_handles[::-1],
            [emb for emb in engineered_embeddings if emb in ordered_columns][::-1],
            loc='center left',
            bbox_to_anchor=(1, 0.4),
            title="Engineered Embeddings",
            title_fontsize=24)
    if deep_learning_handles:
        legend2 = fig.legend(
            deep_learning_handles[::-1],
            [emb for emb in deep_learning_embeddings if emb in ordered_columns][::-1],
            loc='center left',
            bbox_to_anchor=(1, 0.6),
            title="Learned Embeddings",
            title_fontsize=24)

    for legend in (legend1, legend2):
        if legend is None:
            continue
        # NOTE(review): fig.legend already registers the legend with the
        # figure; adding it to the axes as well is kept to preserve the
        # existing rendering -- confirm whether it is still needed.
        ax.add_artist(legend)
        _style_legend(legend)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # pyplot keeps every figure alive until closed; this function runs once
    # per chart, so release the figure explicitly.
    plt.close(fig)
|
|
|
def prepare_data_for_bar(results_dir, task, split, thresh=None):
    """Load a results CSV and reshape it into long format for make_final_bar.

    Reads '<results_dir>/<task>_<split>FOs_results.csv', or
    '<task>_<split>FOs_<thresh>thresh_results.csv' when ``thresh`` is given.
    Only the four plotted models are kept, in CSV row order.

    Args:
        results_dir: directory holding the results CSVs.
        task: task prefix of the CSV file name.
        split: split name used in the CSV file name.
        thresh: optional threshold embedded in the file name.

    Returns:
        (data, image_save_path): ``data`` has columns 'Name' (display name),
        'Metric' and 'Value', one row per (metric, model) pair, grouped by
        metric. ``image_save_path`` points into '<results_dir>/figures/'.
    """
    fname = f"{task}_{split}FOs_results.csv"
    if thresh is not None:
        fname = f"{task}_{split}FOs_{thresh}thresh_results.csv"
    image_save_path = results_dir + '/figures/' + fname.split('_results.csv')[0] + '_barchart.png'

    # Raw 'Model Name' -> display name; its keys also select which rows to keep.
    rename_dict = {
        'fo_puncta_ml': 'FOdb',
        'esm2_t33_650M_UR50D': 'ESM-2-650M',
        'best': 'FusOn-pLM',
        'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50',
    }

    data = pd.read_csv(f"{results_dir}/{fname}")
    data = data.loc[data['Model Name'].isin(list(rename_dict))]

    # Plotted metric label -> source CSV column.
    metric_columns = {
        'Accuracy': 'Accuracy',
        'Precision': 'Precision',
        'Recall': 'Recall',
        'F1': 'F1 Score',
        'AUROC': 'AUROC',
    }
    # Repeat each label once per model row actually present, instead of
    # assuming exactly four models (which broke when one was missing).
    n_models = len(data)
    long_data = pd.DataFrame(data={
        'Name': data['Model Name'].tolist() * len(metric_columns),
        'Metric': [label for label in metric_columns for _ in range(n_models)],
        'Value': [value
                  for column in metric_columns.values()
                  for value in data[column].tolist()],
    })
    long_data['Name'] = long_data['Name'].map(rename_dict)
    return long_data, image_save_path
|
|
|
def make_all_final_bar_charts(results_dir):
    """Export source data and draw the three verification-split bar charts.

    For each chart: load and reshape its results via prepare_data_for_bar,
    write a copy with values rounded to 3 decimals next to the image as
    '<image stem>_source_data.csv', then render the chart to the image path.
    """
    chart_specs = (
        # (task, threshold, chart title)
        ("formation", 0.83, "Puncta Propensity"),
        ("nucleus", None, "Nucleus Localization"),
        ("cytoplasm", None, "Cytoplasm Localization"),
    )
    for task, threshold, chart_title in chart_specs:
        plot_data, png_path = prepare_data_for_bar(results_dir, task, "verification", thresh=threshold)

        rounded = plot_data.copy(deep=True)
        rounded["Value"] = rounded["Value"].round(3)
        rounded.to_csv(png_path.replace(".png", "_source_data.csv"), index=False)

        make_final_bar(plot_data, chart_title, png_path)
|
|
|
def main():
    """Create the figures directory and render every final bar chart."""
    results_dir = "results/final"
    figures_dir = f"{results_dir}/figures"
    os.makedirs(figures_dir, exist_ok=True)
    make_all_final_bar_charts(results_dir)


if __name__ == '__main__':
    main()