import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
import os
import matplotlib.colors as mcolors
from fuson_plm.utils.visualizing import set_font
fo_puncta_db_training_thresh31 = pd.DataFrame(data={
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,
    'Accuracy': 0.81,
    'Precision': 0.78,
    'Recall': 0.98,
    'F1 Score': 0.87,
    'AUROC': 0.88,
    'AUPRC': 0.94
})
fo_puncta_db_verification_thresh83 = pd.DataFrame(data={
    'Model Type': ['fo_puncta_ml'],
    'Model Name': ['fo_puncta_ml_literature'],
    'Model Epoch': np.nan,
    'Accuracy': 0.79,
    'Precision': 0.81,
    'Recall': 0.89,
    'F1 Score': 0.85,
    'AUROC': 0.73,
    'AUPRC': 0.82
})
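# NOTE: the two DataFrames above hold literature-reported metrics for the FOdb puncta
# classifier ('fo_puncta_ml_literature'); the decision thresholds (0.31 for training,
# 0.83 for verification) are inferred from the variable names. They are kept here for
# reference and are not used elsewhere in this script.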
# Method for lengthening the model name
def lengthen_model_name(row):
    name = row['Model Name']
    epoch = row['Model Epoch']
    if 'esm' in name:
        return name
    if 'puncta' in name:
        return name
    return f'{name}_e{epoch}'
# Method for shortening the model name for display
def shorten_model_name(row):
    name = row['Model Name']
    epoch = row['Model Epoch']
    if 'esm' in name:
        return 'ESM-2-650M'
    if name == 'fo_puncta_ml':
        return 'FO-Puncta-ML'
    if name == 'fo_puncta_ml_literature':
        return 'FO-Puncta-ML Lit'
    if name == "prot_t5_xl_half_uniref50_enc":
        return 'ProtT5-XL-U50'  # this is what they call it in the paper
    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    layers = name.split('layers')[0].split('_')[-1]
    dt = name.split('mask')[1].split('-', 1)[1]
    return f'{prob_type}_{layers}L_{dt}_e{epoch}'
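# Illustration of how the two name helpers behave. The fusion-model name below is a
# hypothetical example of the '<prob>_<N>layers_mask<pct>-<type>' pattern the parsing
# assumes, not a real checkpoint name:
#   lengthen_model_name({'Model Name': 'snp_6layers_mask15-random', 'Model Epoch': 3})
#       -> 'snp_6layers_mask15-random_e3'
#   shorten_model_name({'Model Name': 'snp_6layers_mask15-random', 'Model Epoch': 3})
#       -> 'snp_6L_random_e3'
#   shorten_model_name({'Model Name': 'esm2_t33_650M_UR50D', 'Model Epoch': np.nan})
#       -> 'ESM-2-650M'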
def make_final_bar(dataframe, title, save_path):
    set_font()
    df = dataframe.copy(deep=True)
    # Pivot the DataFrame to have metrics as rows and names as columns, and reorder columns
    pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
    ordered_columns = [x for x in ['FOdb', 'ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM'] if x in pivot_df.columns]
    pivot_df = pivot_df[ordered_columns]
    # Define the groups
    engineered_embeddings = ['FOdb']
    deep_learning_embeddings = ['ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM']
    # Reorder the metrics
    metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
    pivot_df = pivot_df.reindex(metric_order)
    # Plotting
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)  # Increased figure size for better legend placement
    # Define bar width and positions
    bar_width = 0.2
    indices = np.arange(len(pivot_df))
    # Use a colorblind-friendly color scheme
    color_map = {
        #'One-Hot': "#999999",
        'FOdb': "#E69F00",
        'ESM-2-650M': "#F0E442",
        'FusOn-pLM': "#FF69B4",
        'ProtT5-XL-U50': "#00ccff"  # light blue
    }
    colors = [color_map[col] for col in ordered_columns]
    # Plot bars for each category and add them to appropriate legend groups
    engineered_handles = []
    deep_learning_handles = []
    for i, (name, color) in enumerate(zip(pivot_df.columns, colors)):
        bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width, label=name, color=color)
        if name in engineered_embeddings:
            engineered_handles.append(bars[0])
        else:
            deep_learning_handles.append(bars[0])
    # Add bold black asterisks next to the winning bars for each category (could be multiple)
    #for j, metric in enumerate(pivot_df.index):
    #    max_value = pivot_df.loc[metric].max()
    #    max_indices = pivot_df.loc[metric][pivot_df.loc[metric] == max_value].index
    #    for max_name in max_indices:
    #        max_index = list(pivot_df.columns).index(max_name)
    #        ax.text(max_value + 0.01, j + max_index * bar_width - bar_width / 4, '*',
    #                color='black', fontsize=12, fontweight='bold', ha='center', va='center')
    # Set labels, ticks, and title
    plt.xlabel('Value', fontsize=44)  # Adjusted font size
    ax.set_yticks(indices + bar_width * 1.5)
    ax.set_xlim([0, 1])
    ax.set_yticklabels(pivot_df.index)
    ax.tick_params(axis='x')  # tick label font sizes are set explicitly below
    ax.set_title(title, fontsize=44)  # Adjusted font size
    # Setting font size for tick labels
    for label in plt.gca().get_xticklabels():
        label.set_fontsize(32)  # Adjusted font size
    for label in plt.gca().get_yticklabels():
        label.set_fontsize(32)  # Adjusted font size
    # Create two separate legends
    if engineered_handles:
        legend1 = fig.legend(
            engineered_handles[::-1],
            [emb for emb in engineered_embeddings if emb in ordered_columns][::-1],
            loc='center left',
            bbox_to_anchor=(1, 0.4),
            title="Engineered Embeddings",
            title_fontsize=24)  # Adjusted font size
    if deep_learning_handles:
        legend2 = fig.legend(
            deep_learning_handles[::-1],
            [emb for emb in deep_learning_embeddings if emb in ordered_columns][::-1],
            loc='center left',
            bbox_to_anchor=(1, 0.6),
            title="Learned Embeddings",
            title_fontsize=24)  # Adjusted font size
    # Adjust legend text size
    if engineered_handles:
        ax.add_artist(legend1)
        for text in legend1.get_texts():
            text.set_fontsize(22)  # Adjusted font size
        for handle in legend1.legendHandles:
            if isinstance(handle, mpatches.Patch):
                handle.set_height(15)  # Adjust height
                handle.set_width(20)  # Adjust width
            elif hasattr(handle, '_sizes'):
                handle._sizes = [200]  # Increase marker size in the legend
    if deep_learning_handles:
        ax.add_artist(legend2)
        for text in legend2.get_texts():
            text.set_fontsize(22)  # Adjusted font size
        for handle in legend2.legendHandles:
            if isinstance(handle, mpatches.Patch):
                handle.set_height(15)  # Adjust height
                handle.set_width(20)  # Adjust width
            elif hasattr(handle, '_sizes'):
                handle._sizes = [200]  # Increase marker size in the legend
    plt.tight_layout()  # Adjust layout to make room for the legends
    # Save the plot to a file
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
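# make_final_bar expects a long-format DataFrame with 'Name', 'Metric', and 'Value'
# columns, where 'Name' holds the display names used in color_map / ordered_columns.
# A minimal hypothetical call (values and output path are placeholders, not real results):
#   toy = pd.DataFrame({
#       'Name': ['FOdb', 'FusOn-pLM'] * 5,
#       'Metric': sum([[m, m] for m in ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']], []),
#       'Value': [0.70, 0.80, 0.72, 0.82, 0.74, 0.84, 0.73, 0.83, 0.75, 0.85]})
#   make_final_bar(toy, "Toy Example", "toy_barchart.png")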
def prepare_data_for_bar(results_dir, task, split, thresh=None):
    fname = f"{task}_{split}FOs_results.csv"
    if thresh is not None:
        fname = f"{task}_{split}FOs_{thresh}thresh_results.csv"
    image_save_path = results_dir + '/figures/' + fname.split('_results.csv')[0] + '_barchart.png'
    data = pd.read_csv(f"{results_dir}/{fname}")
    data = data.loc[
        data['Model Name'].isin(['best',
                                 'fo_puncta_ml',
                                 'esm2_t33_650M_UR50D',
                                 'prot_t5_xl_half_uniref50_enc'])
    ]
    data = pd.DataFrame(data={
        'Name': data['Model Name'].tolist() * 5,
        'Metric': ['Accuracy', 'Accuracy', 'Accuracy', 'Accuracy',
                   'Precision', 'Precision', 'Precision', 'Precision',
                   'Recall', 'Recall', 'Recall', 'Recall',
                   'F1', 'F1', 'F1', 'F1',
                   'AUROC', 'AUROC', 'AUROC', 'AUROC'],
        'Value': data['Accuracy'].tolist() + data['Precision'].tolist() + data['Recall'].tolist() + data['F1 Score'].tolist() + data['AUROC'].tolist()
    })
    rename_dict = {'fo_puncta_ml': 'FOdb',
                   'esm2_t33_650M_UR50D': 'ESM-2-650M',
                   'best': 'FusOn-pLM',
                   'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50'}
    data['Name'] = data['Name'].map(rename_dict)
    return data, image_save_path
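# prepare_data_for_bar assumes '{results_dir}/{task}_{split}FOs[_{thresh}thresh]_results.csv'
# contains one row per model with at least the columns 'Model Name', 'Accuracy',
# 'Precision', 'Recall', 'F1 Score', and 'AUROC'. The hard-coded 'Metric' column above
# also assumes exactly four of the listed models are present in the file; with a different
# number of matching rows the DataFrame construction would raise a length-mismatch error.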
def make_all_final_bar_charts(results_dir):
    # Puncta verification
    data, image_save_path = prepare_data_for_bar(results_dir, "formation", "verification", thresh=0.83)
    data_cp = data.copy(deep=True)
    data_cp["Value"] = data_cp["Value"].round(3)
    data_cp.to_csv(image_save_path.replace(".png", "_source_data.csv"), index=False)
    make_final_bar(data, "Puncta Propensity", image_save_path)
    # Nucleus verification
    data, image_save_path = prepare_data_for_bar(results_dir, "nucleus", "verification", thresh=None)
    data_cp = data.copy(deep=True)
    data_cp["Value"] = data_cp["Value"].round(3)
    data_cp.to_csv(image_save_path.replace(".png", "_source_data.csv"), index=False)
    make_final_bar(data, "Nucleus Localization", image_save_path)
    # Cytoplasm verification
    data, image_save_path = prepare_data_for_bar(results_dir, "cytoplasm", "verification", thresh=None)
    data_cp = data.copy(deep=True)
    data_cp["Value"] = data_cp["Value"].round(3)
    data_cp.to_csv(image_save_path.replace(".png", "_source_data.csv"), index=False)
    make_final_bar(data, "Cytoplasm Localization", image_save_path)
def main():
    # Read in the input data
    results_dir = "results/final"
    os.makedirs(f"{results_dir}/figures", exist_ok=True)
    make_all_final_bar_charts(results_dir)

if __name__ == '__main__':
    main()