svincoff's picture
puncta benchmark
8d9d9da
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
import os
import matplotlib.colors as mcolors
from fuson_plm.utils.visualizing import set_font
# Fixed reference metrics for the 'fo_puncta_ml_literature' entry, laid out in
# the same column schema as the benchmark result tables (Model Type / Model
# Name / Model Epoch followed by the metric columns).
# NOTE(review): the thresh31 / thresh83 suffixes presumably name the decision
# threshold each set of numbers was reported at — confirm against the source.
_FO_PUNCTA_LIT_IDENTITY = {
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,  # no training epoch applies to a literature value
}

fo_puncta_db_training_thresh31 = pd.DataFrame(data={
    **_FO_PUNCTA_LIT_IDENTITY,
    'Accuracy': 0.81,
    'Precision': 0.78,
    'Recall': 0.98,
    'F1 Score': 0.87,
    'AUROC': 0.88,
    'AUPRC': 0.94,
})

fo_puncta_db_verification_thresh83 = pd.DataFrame(data={
    **_FO_PUNCTA_LIT_IDENTITY,
    'Accuracy': 0.79,
    'Precision': 0.81,
    'Recall': 0.89,
    'F1 Score': 0.85,
    'AUROC': 0.73,
    'AUPRC': 0.82,
})
# Method for lengthening the model name
def lengthen_model_name(row):
    """Return a model name that is unique across training checkpoints.

    Names containing 'esm' or 'puncta' are returned as-is. Every other name
    gets the epoch appended as ``<name>_e<epoch>``.

    Args:
        row: Mapping (e.g. a DataFrame row) with 'Model Name' and
            'Model Epoch' entries. Both are read, even when the epoch is
            not used.
    """
    model_name, model_epoch = row['Model Name'], row['Model Epoch']
    keeps_plain_name = any(tag in model_name for tag in ('esm', 'puncta'))
    return model_name if keeps_plain_name else f'{model_name}_e{model_epoch}'
# Method for shortening the model name for display
def shorten_model_name(row):
    """Return a short display label for a model.

    Baseline models map to fixed labels. Fine-tuned checkpoints whose name
    contains 'snp_' or 'uniform_' are abbreviated as
    ``<snp|uni>_<layers>L_<suffix>_e<epoch>``. Here ``<layers>`` is the token
    just before 'layers', and ``<suffix>`` is the text after the first '-'
    that follows 'mask'.

    Args:
        row: Mapping (e.g. a DataFrame row) with 'Model Name' and
            'Model Epoch' entries.

    Returns:
        The display label as a string.

    Raises:
        ValueError: if the name matches none of the recognised patterns.
    """
    name = row['Model Name']
    epoch = row['Model Epoch']
    if 'esm' in name:
        return 'ESM-2-650M'
    if name == 'fo_puncta_ml':
        return 'FO-Puncta-ML'
    if name == 'fo_puncta_ml_literature':
        return 'FO-Puncta-ML Lit'
    if name == "prot_t5_xl_half_uniref50_enc":
        return 'ProtT5-XL-U50'  # this is what they call it in the paper
    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Without this branch prob_type would be unbound below and the call
        # would die with an opaque UnboundLocalError.
        raise ValueError(f"Cannot shorten unrecognised model name: {name!r}")
    # e.g. '..._12layers_...' -> '12'
    layers = name.split('layers')[0].split('_')[-1]
    # Text after the first '-' that follows 'mask'
    dt = name.split('mask')[1].split('-', 1)[1]
    return f'{prob_type}_{layers}L_{dt}_e{epoch}'
def _get_legend_handles(legend):
    """Return the handle artists drawn in ``legend``.

    ``Legend.legendHandles`` was deprecated in matplotlib 3.7 (renamed to
    ``legend_handles``) and has since been removed. The new attribute is
    preferred, and the old one is only a fallback for older matplotlib.
    """
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    return handles


def _enlarge_legend(legend, fontsize=22):
    """Enlarge the entry text and the handle glyphs of ``legend`` in place."""
    for text in legend.get_texts():
        text.set_fontsize(fontsize)
    for handle in _get_legend_handles(legend):
        # Bar-chart legend handles are Rectangles; a generic Patch has no
        # set_height/set_width, so test for the concrete class.
        if isinstance(handle, mpatches.Rectangle):
            handle.set_height(15)
            handle.set_width(20)
        elif hasattr(handle, '_sizes'):
            # Scatter-style handles: enlarge the marker shown in the legend
            handle._sizes = [200]


def make_final_bar(dataframe, title, save_path):
    """Draw, save and show a grouped horizontal bar chart of metrics per model.

    Args:
        dataframe: Long-format DataFrame with 'Name', 'Metric' and 'Value'
            columns (the shape built by ``prepare_data_for_bar``). It is
            copied, not modified.
        title: Title drawn above the axes.
        save_path: Output image path passed to ``savefig``.
    """
    set_font()
    df = dataframe.copy(deep=True)
    # Metrics become rows and models become columns, in a fixed display order
    # restricted to the models actually present.
    pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
    ordered_columns = [x for x in ['FOdb', 'ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM'] if x in pivot_df.columns]
    pivot_df = pivot_df[ordered_columns]
    # Models listed here go in the "Engineered" legend; all others are "Learned"
    engineered_embeddings = ['FOdb']
    # Reversed because barh draws the first row at the bottom; this puts
    # Accuracy at the top.
    metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
    pivot_df = pivot_df.reindex(metric_order)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    bar_width = 0.2
    indices = np.arange(len(pivot_df))
    color_map = {
        'FOdb': "#E69F00",
        'ESM-2-650M': "#F0E442",
        'FusOn-pLM': "#FF69B4",
        'ProtT5-XL-U50': "#00ccff",  # light blue
    }
    colors = [color_map[col] for col in ordered_columns]

    # One bar series per model, offset within each metric group. Handles and
    # labels are collected together so each legend pairs them correctly.
    engineered_handles, engineered_labels = [], []
    learned_handles, learned_labels = [], []
    for i, (name, color) in enumerate(zip(pivot_df.columns, colors)):
        bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width, label=name, color=color)
        if name in engineered_embeddings:
            engineered_handles.append(bars[0])
            engineered_labels.append(name)
        else:
            learned_handles.append(bars[0])
            learned_labels.append(name)

    ax.set_xlabel('Value', fontsize=44)
    # Centre each tick on its group of bars, whatever the number of models
    # (for 4 models this is the original 1.5 * bar_width).
    group_center = bar_width * (len(ordered_columns) - 1) / 2
    ax.set_yticks(indices + group_center)
    ax.set_xlim([0, 1])
    ax.set_yticklabels(pivot_df.index)
    ax.set_title(title, fontsize=44)
    ax.tick_params(axis='both', labelsize=32)

    # Two figure-level legends placed right of the axes, one per model group.
    # Within each legend the order is reversed to match the bar stacking.
    legends = []
    if engineered_handles:
        legends.append(fig.legend(
            engineered_handles[::-1],
            engineered_labels[::-1],
            loc='center left',
            bbox_to_anchor=(1, 0.4),
            title="Engineered Embeddings",
            title_fontsize=24))
    if learned_handles:
        legends.append(fig.legend(
            learned_handles[::-1],
            learned_labels[::-1],
            loc='center left',
            bbox_to_anchor=(1, 0.6),
            title="Learned Embeddings",
            title_fontsize=24))
    for legend in legends:
        # NOTE(review): the legends already belong to the figure. Adding them to
        # the axes as well is kept to preserve the existing layout behaviour.
        ax.add_artist(legend)
        _enlarge_legend(legend)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
def prepare_data_for_bar(results_dir, task, split, thresh=None):
    """Load a benchmark results CSV and reshape it into long format for plotting.

    Reads ``{results_dir}/{task}_{split}FOs_results.csv``. When ``thresh`` is
    given, it reads ``{results_dir}/{task}_{split}FOs_{thresh}thresh_results.csv``
    instead. Only the four compared models are kept, and the result has one
    row per (model, metric).

    Args:
        results_dir: Directory containing the results CSVs.
        task: Task prefix of the results file (e.g. "formation", "nucleus").
        split: Split name placed before "FOs" in the file name.
        thresh: Optional threshold that selects the thresholded results file.

    Returns:
        ``(data, image_save_path)``. ``data`` has columns Name / Metric /
        Value, grouped metric by metric, with the CSV's model order inside
        each metric and model names mapped to display names.
        ``image_save_path`` is ``{results_dir}/figures/<file stem>_barchart.png``.
    """
    fname = f"{task}_{split}FOs_results.csv"
    if thresh is not None:
        fname = f"{task}_{split}FOs_{thresh}thresh_results.csv"
    image_save_path = results_dir + '/figures/' + fname.split('_results.csv')[0] + '_barchart.png'

    # Raw model name -> display name. The keys double as the row filter.
    rename_dict = {'fo_puncta_ml': 'FOdb',
                   'esm2_t33_650M_UR50D': 'ESM-2-650M',
                   'best': 'FusOn-pLM',
                   'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50'}
    # Display metric name -> CSV column it is read from
    metric_columns = {'Accuracy': 'Accuracy',
                      'Precision': 'Precision',
                      'Recall': 'Recall',
                      'F1': 'F1 Score',
                      'AUROC': 'AUROC'}

    results = pd.read_csv(f"{results_dir}/{fname}")
    results = results.loc[results['Model Name'].isin(list(rename_dict))]

    # Build the long format from the rows actually present. The previous
    # hard-coded 20-entry Metric list broke whenever a model was missing.
    names = results['Model Name'].tolist()
    data = pd.DataFrame(data={
        'Name': names * len(metric_columns),
        'Metric': [metric for metric in metric_columns for _ in names],
        'Value': [value
                  for column in metric_columns.values()
                  for value in results[column].tolist()],
    })
    data['Name'] = data['Name'].map(rename_dict)
    return data, image_save_path
def make_all_final_bar_charts(results_dir):
    """Export source data and draw the verification bar chart for each task.

    For every task the long-format data is written, with values rounded to 3
    decimals, to ``<chart path minus .png>_source_data.csv``. The unrounded
    data is then plotted to the chart path.
    """
    chart_specs = (
        # (task, thresh passed to prepare_data_for_bar, chart title)
        ("formation", 0.83, "Puncta Propensity"),
        ("nucleus", None, "Nucleus Localization"),
        ("cytoplasm", None, "Cytoplasm Localization"),
    )
    for task, threshold, chart_title in chart_specs:
        plot_data, chart_path = prepare_data_for_bar(results_dir, task, "verification", thresh=threshold)
        rounded = plot_data.copy(deep=True)
        rounded["Value"] = rounded["Value"].round(3)
        rounded.to_csv(chart_path.replace(".png", "_source_data.csv"), index=False)
        make_final_bar(plot_data, chart_title, chart_path)
def main():
    """Ensure ``results/final/figures`` exists, then render every final bar chart."""
    final_results_dir = "results/final"
    os.makedirs(f"{final_results_dir}/figures", exist_ok=True)
    make_all_final_bar_charts(final_results_dir)


if __name__ == '__main__':
    main()