# puncta benchmark
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
import os
import matplotlib.colors as mcolors
from fuson_plm.utils.visualizing import set_font
# Single-row reference table of metrics for the 'fo_puncta_ml_literature' model.
# NOTE(review): the variable name suggests these are training-set results at a
# 0.31 decision threshold -- confirm the source of the numbers.
# 'Model Epoch' is NaN because no epoch is recorded for this model.
fo_puncta_db_training_thresh31 = pd.DataFrame(data={
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,
    'Accuracy': 0.81,
    'Precision': 0.78,
    'Recall': 0.98,
    'F1 Score': 0.87,
    'AUROC': 0.88,
    'AUPRC': 0.94,
})
# Single-row reference table of metrics for the 'fo_puncta_ml_literature' model.
# NOTE(review): the variable name suggests these are verification-set results at
# a 0.83 decision threshold -- confirm the source of the numbers.
# 'Model Epoch' is NaN because no epoch is recorded for this model.
fo_puncta_db_verification_thresh83 = pd.DataFrame(data={
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,
    'Accuracy': 0.79,
    'Precision': 0.81,
    'Recall': 0.89,
    'F1 Score': 0.85,
    'AUROC': 0.73,
    'AUPRC': 0.82,
})
# Method for lengthening the model name
def lengthen_model_name(row):
    """Return the full identifier for the model described by ``row``.

    Args:
        row: Mapping (e.g. a DataFrame row) with 'Model Name' and
            'Model Epoch' entries.

    Returns:
        The model name unchanged when it contains 'esm' or 'puncta';
        otherwise the name with an ``_e<epoch>`` suffix appended.
    """
    model_name, model_epoch = row['Model Name'], row['Model Epoch']
    # Names containing either tag are returned without an epoch suffix.
    if any(tag in model_name for tag in ('esm', 'puncta')):
        return model_name
    return f'{model_name}_e{model_epoch}'
# Method for shortening the model name for display
def shorten_model_name(row):
    """Return a short display label for the model described by ``row``.

    Args:
        row: Mapping (e.g. a DataFrame row) with 'Model Name' and
            'Model Epoch' entries.

    Returns:
        A fixed label for the known baseline models, or
        ``<prob_type>_<layers>L_<dt>_e<epoch>`` for names containing
        'snp_' or 'uniform_', where the pieces are parsed out of the name.

    Raises:
        ValueError: If the name matches none of the recognised patterns.
        IndexError: If a 'snp_'/'uniform_' name lacks the
            ``mask<...>-<...>`` part the label is parsed from.
    """
    name = row['Model Name']
    epoch = row['Model Epoch']

    if 'esm' in name:
        return 'ESM-2-650M'

    fixed_labels = {
        'fo_puncta_ml': 'FO-Puncta-ML',
        'fo_puncta_ml_literature': 'FO-Puncta-ML Lit',
        # This is what they call it in the paper.
        'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50',
    }
    if name in fixed_labels:
        return fixed_labels[name]

    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Previously prob_type was left unbound here, which surfaced as an
        # UnboundLocalError (or an unrelated IndexError from the parsing below).
        raise ValueError(f'Unrecognised model name: {name!r}')

    # Token immediately before 'layers', e.g. '..._8layers_...' -> '8'.
    layers = name.split('layers')[0].split('_')[-1]
    # Everything after the first '-' that follows 'mask'.
    dt = name.split('mask')[1].split('-', 1)[1]
    return f'{prob_type}_{layers}L_{dt}_e{epoch}'
def _style_group_legend(legend):
    """Enlarge the text and the colour swatches of one figure legend."""
    for text in legend.get_texts():
        text.set_fontsize(22)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old attribute, so support both spellings.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        if isinstance(handle, mpatches.Rectangle):
            handle.set_height(15)
            handle.set_width(20)
        elif hasattr(handle, '_sizes'):
            handle._sizes = [200]  # Increase marker size in the legend


def _add_group_legend(fig, entries, anchor, title):
    """Add a styled legend for one embedding group; no-op when the group is empty.

    ``entries`` is a list of ``(label, handle)`` pairs in plotting order.
    """
    if not entries:
        return None
    # Bars are stacked bottom-to-top in plotting order, so reverse the entries
    # to list them top-to-bottom in the legend.
    labels, handles = zip(*reversed(entries))
    legend = fig.legend(
        list(handles),
        list(labels),
        loc='center left',
        bbox_to_anchor=anchor,
        title=title,
        title_fontsize=24)
    _style_group_legend(legend)
    return legend


def make_final_bar(dataframe, title, save_path):
    """Draw a grouped horizontal bar chart of benchmark metrics and save it.

    Args:
        dataframe: Long-format DataFrame with 'Name', 'Metric' and 'Value'
            columns, one row per (Metric, Name) pair. Only the names 'FOdb',
            'ProtT5-XL-U50', 'ESM-2-650M' and 'FusOn-pLM' are plotted.
        title: Title drawn above the axes.
        save_path: Destination passed to ``plt.savefig``.
    """
    df = dataframe.copy(deep=True)

    # Metrics become rows and model names become columns, in a fixed order.
    pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
    ordered_columns = [x for x in ['FOdb', 'ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM'] if x in pivot_df.columns]
    pivot_df = pivot_df[ordered_columns]

    # Legend groups.
    engineered_embeddings = ['FOdb']

    # Reversed because barh puts the first row at the bottom of the axes.
    metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
    pivot_df = pivot_df.reindex(metric_order)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

    bar_width = 0.2
    indices = np.arange(len(pivot_df))

    # Fixed colour per model.
    color_map = {
        'FOdb': "#E69F00",
        'ESM-2-650M': "#F0E442",
        'FusOn-pLM': "#FF69B4",
        'ProtT5-XL-U50': "#00ccff",  # light blue
    }

    # Plot one bar series per model and remember its handle for the legend
    # group it belongs to.
    engineered_entries = []
    deep_learning_entries = []
    for i, name in enumerate(ordered_columns):
        bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width,
                       label=name, color=color_map[name])
        if name in engineered_embeddings:
            engineered_entries.append((name, bars))
        else:
            deep_learning_entries.append((name, bars))

    # Labels, ticks, and title.
    ax.set_xlabel('Value', fontsize=44)
    # Centre each tick on its group of bars; this equals bar_width * 1.5 when
    # all four models are present.
    ax.set_yticks(indices + bar_width * (len(ordered_columns) - 1) / 2)
    ax.set_yticklabels(pivot_df.index)
    ax.set_xlim([0, 1])
    ax.set_title(title, fontsize=44)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(32)

    # Two separate legends, placed to the right of the axes.
    _add_group_legend(fig, engineered_entries, (1, 0.4), "Engineered Embeddings")
    _add_group_legend(fig, deep_learning_entries, (1, 0.6), "Learned Embeddings")

    plt.tight_layout()  # Adjust layout to make room for the legends

    # bbox_inches='tight' keeps the legends outside the axes in the saved image.
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
def _display_name_for_bar(model_name):
    """Return the bar-chart label for a results row, or None to leave the row out."""
    if not isinstance(model_name, str):
        return None
    exact_names = {
        'fo_puncta_ml': 'FOdb',
        'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50',
        # NOTE(review): 'best' is assumed to be the FusOn-pLM row -- confirm.
        'best': 'FusOn-pLM',
    }
    if model_name in exact_names:
        return exact_names[model_name]
    # Same rule shorten_model_name uses: any name containing 'esm' is ESM-2-650M.
    if 'esm' in model_name:
        return 'ESM-2-650M'
    return None


def prepare_data_for_bar(results_dir, task, split, thresh=None):
    """Load one results CSV and reshape it into long format for plotting.

    Args:
        results_dir: Directory containing the results CSV.
        task: Task prefix of the file name (e.g. "formation").
        split: Split name used in the file name (e.g. "verification").
        thresh: Optional threshold; when given it is embedded in the file
            name as ``<thresh>thresh``.

    Returns:
        A tuple ``(data, image_save_path)``. ``data`` has columns 'Name',
        'Metric' and 'Value', with one row per plotted model and metric,
        grouped by metric. ``image_save_path`` is the PNG path under
        ``<results_dir>/figures/``.
    """
    fname = f"{task}_{split}FOs_results.csv"
    if thresh is not None:
        fname = f"{task}_{split}FOs_{thresh}thresh_results.csv"
    image_save_path = results_dir + '/figures/' + fname.split('_results.csv')[0] + '_barchart.png'

    results = pd.read_csv(f"{results_dir}/{fname}")

    # Keep only the rows that correspond to a plotted model.
    display_names = results['Model Name'].map(_display_name_for_bar)
    keep = display_names.notna()
    results = results.loc[keep].reset_index(drop=True)
    names = display_names.loc[keep].tolist()

    # (label on the chart, column in the CSV)
    metric_columns = [
        ('Accuracy', 'Accuracy'),
        ('Precision', 'Precision'),
        ('Recall', 'Recall'),
        ('F1', 'F1 Score'),
        ('AUROC', 'AUROC'),
    ]

    # Build the long table from however many models were kept, rather than
    # assuming exactly four.
    data = pd.DataFrame(data={
        'Name': names * len(metric_columns),
        'Metric': [metric for metric, _ in metric_columns for _name in names],
        'Value': [value for _, column in metric_columns for value in results[column].tolist()],
    })

    return data, image_save_path
def make_all_final_bar_charts(results_dir):
    """Render the verification-split bar chart for every benchmark task.

    Args:
        results_dir: Directory holding the per-task results CSVs; it is passed
            to ``prepare_data_for_bar``, which also derives each image path.
    """
    # (task, threshold, chart title)
    charts = [
        ("formation", 0.83, "Puncta Propensity"),
        ("nucleus", None, "Nucleus Localization"),
        ("cytoplasm", None, "Cytoplasm Localization"),
    ]
    # NOTE(review): each chart previously also built a copy of the data rounded
    # to 3 decimals that was never used; if it was meant to be written out as
    # source data, add that step here.
    for task, thresh, title in charts:
        data, image_save_path = prepare_data_for_bar(results_dir, task, "verification", thresh=thresh)
        make_final_bar(data, title, image_save_path)
def main():
# Read in the input data
if __name__ == '__main__':