svincoff's picture
caid benchmark
bae913a
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
from matplotlib import font_manager
import matplotlib.patches as patches
from sklearn.metrics import roc_curve, auc, r2_score
from fuson_plm.utils.visualizing import set_font
# Literature-reported CAID-2 competition results used as reference rows in the
# plots below, stored once as (model name, AUROC, ranking position) triples so
# that the AUROC table and the ranking lookup cannot drift apart. The list holds
# the top-10 ranked models plus a few well-known baselines (IUPred3, AlphaFold).
_CAID2_REFERENCE = [
    ('Dispredict3', 0.838, 1),
    ('flDPnn2', 0.836, 2),
    ('flDPnn', 0.833, 3),
    ('flDPlr', 0.827, 4),
    ('flDPlr2', 0.821, 5),
    ('DisoPred', 0.821, 6),
    ('IDP-Fusion', 0.818, 7),
    ('ESpritz-D', 0.802, 8),
    ('DeepIDP-2L', 0.800, 9),
    ('disomine', 0.797, 10),
    ('DISOPRED3-diso', 0.692, 35),
    ('IUPred3', 0.755, 21),
    ('AlphaFold-rsa', 0.747, 24),
    ('AlphaFold-pLDDT', 0.695, 34),
]

# One row per reference model, in the same shape as a results table:
# 'Model Name', 'AUROC', 'Model Type' (always 'caid2_competition') and an
# all-NaN 'Model Epoch' (competition models have no training epoch).
caid2_winners = pd.DataFrame(data={
    'Model Name': [model for model, _, _ in _CAID2_REFERENCE],
    'AUROC': [score for _, score, _ in _CAID2_REFERENCE],
})
caid2_winners['Model Type'] = 'caid2_competition'
caid2_winners['Model Epoch'] = np.nan

# Model name -> ranking position, used to prefix legend entries.
caid2_model_rankings = {model: rank for model, _, rank in _CAID2_REFERENCE}
# Method for lengthening the model name
def lengthen_model_name(row):
    """Return the fully-qualified model name for one results row.

    ESM models, puncta models and CAID-2 competition entries keep their bare
    'Model Name'. Every other model gets its epoch appended as
    '<name>_e<epoch>'.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Type', 'Model Name'
            and 'Model Epoch' keys.
    """
    kind, name, epoch = row['Model Type'], row['Model Name'], row['Model Epoch']
    keeps_bare_name = ('esm' in name) or ('puncta' in name) or (kind == 'caid2_competition')
    if keeps_bare_name:
        return name
    return f'{name}_e{epoch}'
# Method for shortening the model name for display
def shorten_model_name(row):
    """Return a compact display name for one results row.

    ESM models display as 'ESM-2-650M' and CAID-2 competition entries keep
    their own name. Any other name is parsed as a trained-model name that
    contains 'snp_' or 'uniform_' plus the markers '<N>layers_<kqv>_' and
    'mask<rate>-<dt>'. The result is
    '<snp|uni>_<N>L_<kqv>_mask<rate>_<dt>_e<epoch>'. Here <dt> is everything
    after the first '-' that follows 'mask' (presumably a date/time tag).

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Type', 'Model Name'
            and 'Model Epoch' keys.

    Raises:
        ValueError: if a trained-model name contains neither 'snp_' nor
            'uniform_'.
    """
    model_type = row['Model Type']
    name = row['Model Name']
    epoch = row['Model Epoch']
    if 'esm' in name:
        return 'ESM-2-650M'
    if model_type == 'caid2_competition':
        return name
    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Previously this fell through and raised an opaque UnboundLocalError on prob_type.
        raise ValueError(
            f"Cannot infer probability type ('snp_' or 'uniform_') from model name {name!r}"
        )
    # Token immediately before 'layers', e.g. '..._2layers_...' -> '2'.
    layers = name.split('layers')[0].split('_')[-1]
    # Segment after the first 'mask' is '<rate>-<dt>'.
    mask_segment = name.split('mask')[1]
    maskrate = mask_segment.split('-', 1)[0]
    dt = mask_segment.split('-', 1)[1]
    # Token right after 'layers_', e.g. '2layers_qkv_...' -> 'qkv'.
    kqv_tag = name.split('layers_')[1].split('_')[0]
    return f'{prob_type}_{layers}L_{kqv_tag}_mask{maskrate}_{dt}_e{epoch}'
def make_heatmap(df, results_dir='.', gold_standard_model_name="esm2_t33_650M_UR50D",split="test",thresh=None,ax=None):
    """Plot every model's AUROC relative to a gold-standard model as a heatmap and save it.

    Each row is one model. The cell colour is the model's AUROC minus the
    gold standard's AUROC, and the cell text is the raw AUROC (bold when it
    beats the gold standard). Rows are grouped by 'Model Type', with the gold
    standard pinned to the top. Tick labels get a gold box for the gold
    standard and a red box for models that beat it. Unless
    ``split == "benchmark"``, the literature-reported CAID-2 competition rows
    (``caid2_winners``) are appended first.

    Args:
        df: results table with at least 'Model Type', 'Model Name',
            'Model Epoch' (numeric or NaN) and 'AUROC' columns. It is not
            modified.
        results_dir: directory the PNG is written to.
        gold_standard_model_name: 'Full Model Name' (as built by
            ``lengthen_model_name``) of the reference row. Exactly one row
            must match.
        split: "testing"/"test", "training"/"train" or "benchmark". Selects
            the title and output file name.
        thresh: unused; kept for backward compatibility.
        ax: optional Axes to draw on. A new figure is created when None.
    """
    # Apply the project-wide plotting font.
    set_font()
    columns_to_compare = ['AUROC']
    # Work on a copy: with split == "benchmark" the 'Model Epoch' conversion
    # below used to write straight into the caller's DataFrame.
    df = df.copy()
    if split != "benchmark":
        # Append the literature-reported CAID-2 competition results as extra rows.
        df = pd.concat([df, caid2_winners], ignore_index=True)
    # Epochs become strings ('' when missing) so they can be embedded in model names.
    df['Model Epoch'] = df['Model Epoch'].apply(lambda x: str(int(x)) if not np.isnan(x) else '')
    df['Short Model Name'] = df.apply(shorten_model_name, axis=1)
    df['Full Model Name'] = df.apply(lengthen_model_name, axis=1)

    # Isolate the gold-standard row; .item() enforces exactly one match.
    gold_mask = df['Full Model Name'] == gold_standard_model_name
    gold_standard = df[gold_mask].reset_index(drop=True).iloc[0]
    gold_standard_short_model_name = df.loc[gold_mask, 'Short Model Name'].item()

    # Sort by model type (then AUROC, descending), with the gold standard pinned first.
    heatmap_data = df[['Model Type', 'Short Model Name', 'Full Model Name'] + columns_to_compare].copy()
    heatmap_data['is_gold_standard'] = (heatmap_data['Full Model Name'] == gold_standard_model_name).astype(int)
    heatmap_data = (
        heatmap_data
        .sort_values(by=['is_gold_standard', 'Model Type', 'AUROC'], ascending=[False, True, False])
        .reset_index(drop=True)
        .drop(columns=['is_gold_standard'])
    )
    # Keep the raw values for the cell text; the colours show differences.
    original_values = heatmap_data[columns_to_compare].copy()
    for col in columns_to_compare:
        heatmap_data[col] = heatmap_data[col] - gold_standard[col]
    # Diverging palette centred on 0 (see center=0 below).
    cmap = sns.color_palette("coolwarm", as_cmap=True)

    if ax is None:
        # Grow the figure height with the number of rows beyond 26.
        tallsize = max(8, 8 + .25*(len(heatmap_data) - 26))
        fig, ax = plt.subplots(1, 1, figsize=(8, tallsize), dpi=300)
    else:
        # Previously `fig` was never defined on this path, so fig.tight_layout() raised NameError.
        fig = ax.figure
    # Draw on the requested axes (previously seaborn drew on whatever axes was current).
    hm = sns.heatmap(heatmap_data.set_index('Short Model Name').drop(columns=['Model Type', 'Full Model Name']),
                     annot=False, fmt='', cmap=cmap, center=0, ax=ax,
                     cbar_kws={'label': 'Difference from Gold Standard'})
    # Explicitly set tick labels so they match the sorted rows.
    ax.set_yticklabels(heatmap_data['Short Model Name'], rotation=0, fontsize=12)
    ax.set_ylabel("Short Model Name", labelpad=20)

    # Write the raw value in each cell, bold when it beats the gold standard.
    for i in range(original_values.shape[0]):
        for j, col in enumerate(columns_to_compare):
            value = original_values.iloc[i, j]
            if value > gold_standard[col]:
                ax.text(j + 0.5, i + 0.5, f'{value:.3f}', ha='center', va='center', fontweight='bold', color='black')
            else:
                ax.text(j + 0.5, i + 0.5, f'{value:.3f}', ha='center', va='center', color='black')

    # Thick white separators between consecutive model-type groups.
    model_types = heatmap_data['Model Type'].values
    for i in range(1, len(model_types)):
        if model_types[i] != model_types[i - 1]:
            hm.axhline(i, color='white', linewidth=8)

    # Gold box on the gold-standard label; red box on models whose AUROC beats it.
    ytick_labels = ax.get_yticklabels()
    for ytick, model_name in enumerate(heatmap_data['Short Model Name']):
        if model_name == gold_standard_short_model_name:
            ytick_labels[ytick].set_bbox(dict(facecolor='gold', alpha=0.5, edgecolor='gold'))
        elif original_values.loc[ytick, 'AUROC'] > gold_standard['AUROC']:
            ytick_labels[ytick].set_bbox(dict(facecolor='red', alpha=0.3, edgecolor='red'))

    gold_patch = mpatches.Patch(color='gold', alpha=0.5, label='Gold Standard')
    red_patch = mpatches.Patch(color='red', alpha=0.5, label='Winner')
    ax.legend(handles=[gold_patch, red_patch], loc='best', bbox_to_anchor=(0, 0))

    # "test"/"train" are accepted as aliases so the default split="test" works.
    split_fname_dict = {
        "testing": "CAID2_test",
        "test": "CAID2_test",
        "training": "CAID2_train",
        "train": "CAID2_train",
        "benchmark": "FusionPDB_pLDDT_disorder"
    }
    split_title_dict = {
        "testing": "CAID-2 Disorder Prediction",
        "test": "CAID-2 Disorder Prediction",
        "training": "CAID-2 Disorder Prediction",
        "train": "CAID-2 Disorder Prediction",
        "benchmark": "FusionPDB_pLDDT Disorder Prediction"
    }
    ax.set_title(split_title_dict[split])

    # Put the colour-bar label on the right, rotated to read top-to-bottom.
    cbar = hm.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')
    cbar.ax.yaxis.set_ticks_position('right')
    cbar.set_label('Difference from Gold Standard', rotation=270, labelpad=20)

    # Leave extra room on the right for the colour-bar label.
    fig.tight_layout(rect=[0, 0, 0.95, 1])
    fig.savefig(f"{results_dir}/{split_fname_dict[split]}_heatmap_vs_{gold_standard_model_name}.png")
# Plot AUROC curve of ONE model of interest on its fusion pdb performance
def make_benchmark_auroc_curve(results_dir='.', seq_label_dict=None, path_to_results_of_interest='', model_alias=None):
    """Plot and save the ROC curve of one model on the FusionPDB pLDDT-disorder benchmark.

    Per-residue probabilities and labels from every sequence are concatenated
    into a single ROC curve. The figure is saved as
    '{results_dir}/FusionPDB_pLDDT_disorder_{model_alias}_AUROC_curve.png'.

    Args:
        results_dir: directory the PNG is written to.
        seq_label_dict: maps each sequence to its per-residue label string
            (e.g. '0011'). Every sequence in the results CSV must be a key.
        path_to_results_of_interest: CSV with 'sequence' and 'prob_1'
            columns; 'prob_1' holds comma-separated per-residue
            probabilities.
        model_alias: legend / file-name label. When None it is derived from
            the path, which must then contain 'trained_models/'.
    """
    if model_alias is None:
        # Only parse the path when an alias is actually needed. Previously it was
        # parsed unconditionally and raised IndexError for paths outside
        # 'trained_models/' even when the caller supplied an alias.
        # Presumed layout: trained_models/<type>/<name>/epoch<N>/<hyperparams>/<file>
        path_parts = path_to_results_of_interest.split('trained_models/')[1].split('/')
        model_type = path_parts[0]
        if len(path_parts) == 5:
            model_name = path_parts[1]
            model_epoch = path_parts[2].split('epoch')[1]
        else:
            model_name = path_parts[0]
            model_epoch = np.nan
        model_info = pd.DataFrame(data={
            'Model Type': [model_type], 'Model Name': [model_name], 'Model Epoch': [model_epoch]
        })
        model_alias = model_info.apply(shorten_model_name, axis=1).iloc[0]

    color_map = {model_alias: 'black'}
    method_results = {model_alias: path_to_results_of_interest}
    method_results = {k: v for k, v in method_results.items() if v not in [None, '']}
    set_font()
    plt.figure(figsize=(10, 6), dpi=300)

    # (method, fpr, tpr, auroc) tuples, sorted by AUROC before plotting.
    roc_data = []
    for method, path in method_results.items():
        df = pd.read_csv(path)  # columns: sequence, prob_1
        # Flatten all sequences' per-residue probabilities and labels into single vectors.
        prob_1 = [float(x) for x in ",".join(df['prob_1'].tolist()).split(",")]
        residue_labels = [int(x) for x in "".join(df['sequence'].apply(lambda s: seq_label_dict[s]).tolist())]
        sequences = "".join(df['sequence'].tolist())
        assert len(prob_1) == len(residue_labels) == len(sequences)
        fpr, tpr, _ = roc_curve(residue_labels, prob_1)
        roc_data.append((method, fpr, tpr, auc(fpr, tpr)))
    roc_data = sorted(roc_data, key=lambda x: x[3], reverse=True)

    # The model of interest gets a thicker, opaque line.
    for method, fpr, tpr, roc_auc in roc_data:
        if method == model_alias:
            plt.plot(fpr, tpr, color=color_map[method], lw=2, label=f'{method} ({roc_auc:0.3f})')
        else:
            plt.plot(fpr, tpr, color=color_map[method], lw=1, alpha=0.7, label=f'{method} ({roc_auc:0.3f})')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')

    # Legend outside the axes on the right, with the model of interest in bold.
    handles, legend_labels = plt.gca().get_legend_handles_labels()
    legend = plt.legend(handles, legend_labels, loc="center left", bbox_to_anchor=(1, 0.5))
    for text in legend.get_texts():
        if model_alias in text.get_text():
            text.set_fontweight('bold')
    plt.tight_layout()
    plt.savefig(f'{results_dir}/FusionPDB_pLDDT_disorder_{model_alias}_AUROC_curve.png')
# Plot AUROC curve of ONE model of interest with all the CAID models
def make_auroc_curve(results_dir='.', seq_label_dict=None, seq_ids_dict=None, path_to_results_of_interest='', model_alias=None, path_to_esm_results=None, with_rankings=False):
    """Plot ROC curves for one model of interest alongside the CAID-2 competition models, and save them.

    Every method's results CSV ('sequence' plus comma-separated per-residue
    'prob_1') is flattened into one ROC curve. Curves are drawn in descending
    AUROC order. For the model of interest, the legend AUROC is read from
    'caid_hyperparam_screen_test_metrics.csv' in the same directory as its
    results file, instead of the value recomputed here.

    Side outputs in ``results_dir``:
      * CAID_prediction_source_data.csv - merged per-method predictions with
        labels and ids;
      * CAID_fpr_tpr_source_data.csv - all ROC curve points.

    Args:
        results_dir: output directory.
        seq_label_dict: sequence -> per-residue label string (e.g. '0011').
        seq_ids_dict: sequence -> identifier, used in the source-data CSV.
        path_to_results_of_interest: results CSV of the model of interest.
        model_alias: legend / file-name label. When None it is derived from
            the path, which must then contain 'trained_models/'.
        path_to_esm_results: optional ESM-2-650M results CSV, drawn dashed.
        with_rankings: prefix CAID-2 models with their ranking position.
    """
    if model_alias is None:
        # Only parse the path when an alias is actually needed. Previously it was
        # parsed unconditionally and raised IndexError for paths outside
        # 'trained_models/' even when the caller supplied an alias.
        # Presumed layout: trained_models/<type>/<name>/epoch<N>/<hyperparams>/<file>
        path_parts = path_to_results_of_interest.split('trained_models/')[1].split('/')
        model_type = path_parts[0]
        if len(path_parts) == 5:
            model_name = path_parts[1]
            model_epoch = path_parts[2].split('epoch')[1]
        else:
            model_name = path_parts[0]
            model_epoch = np.nan
        model_info = pd.DataFrame(data={
            'Model Type': [model_type], 'Model Name': [model_name], 'Model Epoch': [model_epoch]
        })
        model_alias = model_info.apply(shorten_model_name, axis=1).iloc[0]

    color_map = {
        'Dispredict3': '#d62727',  # 1
        'flDPnn2': '#ff7f0f',  # 2
        'flDPnn': '#1f77b4',  # 3
        'flDPlr': '#bcbd21',  # 4
        'flDPlr2': '#16becf',  # 5
        'DisoPred': '#1f77b4',  # 6
        'IDP-Fusion': '#d62727',  # 7
        'ESpritz-D': '#8b564c',  # 8
        'DeepIDP-2L': '#e377c2',  # 9
        'disomine': '#e377c2',  # 10
        'DISOPRED3-diso': '#ff892d',
        'IUPred3': '#8b564c',
        'AlphaFold-rsa': '#2ba02b',
        'AlphaFold-pLDDT': '#ff892d',
        model_alias: 'black'
    }
    method_results = {'Dispredict3': 'processed_data/caid2_competition_results/Dispredict3_CAID-2_Disorder_NOX.csv',
                      'flDPnn2': 'processed_data/caid2_competition_results/flDPnn2_CAID-2_Disorder_NOX.csv',
                      'flDPnn': 'processed_data/caid2_competition_results/flDPnn_CAID-2_Disorder_NOX.csv',
                      'flDPlr': 'processed_data/caid2_competition_results/flDPtr_CAID-2_Disorder_NOX.csv',  # name doesn't match but this is what it is in raw download
                      'flDPlr2': 'processed_data/caid2_competition_results/flDPlr2_CAID-2_Disorder_NOX.csv',
                      'DisoPred': 'processed_data/caid2_competition_results/DisoPred_CAID-2_Disorder_NOX.csv',
                      'IDP-Fusion': 'processed_data/caid2_competition_results/IDP-Fusion_CAID-2_Disorder_NOX.csv',
                      'ESpritz-D': 'processed_data/caid2_competition_results/ESpritz-D_CAID-2_Disorder_NOX.csv',
                      'DeepIDP-2L': 'processed_data/caid2_competition_results/DeepIDP-2L_CAID-2_Disorder_NOX.csv',
                      'disomine': 'processed_data/caid2_competition_results/disomine_CAID-2_Disorder_NOX.csv',
                      'DISOPRED3-diso': 'processed_data/caid2_competition_results/DISOPRED3-diso_CAID-2_Disorder_NOX.csv',
                      'AlphaFold-rsa': 'processed_data/caid2_competition_results/AlphaFold-rsa_CAID-2_Disorder_NOX.csv',
                      'AlphaFold-pLDDT': 'processed_data/caid2_competition_results/AlphaFold-disorder_CAID-2_Disorder_NOX.csv',  # name doesn't match but this is what it is in raw download
                      'IUPred3': 'processed_data/caid2_competition_results/IUPred3_CAID-2_Disorder_NOX.csv',
                      model_alias: path_to_results_of_interest
                      }
    if path_to_esm_results is not None:
        method_results['ESM-2-650M'] = path_to_esm_results
        color_map['ESM-2-650M'] = 'black'
    method_results = {k: v for k, v in method_results.items() if v not in [None, '']}
    set_font()
    plt.figure(figsize=(12, 6), dpi=300)

    merged_preds = pd.DataFrame(data={'sequence': []})
    # Per-method ROC points, concatenated once after the loop; repeatedly
    # concatenating onto an empty frame was quadratic and triggers pandas'
    # empty-entry concat deprecation warning.
    tpr_fpr_frames = []
    roc_data = []
    for method, path in method_results.items():
        df = pd.read_csv(path)  # columns: sequence, prob_1
        merged_preds = pd.merge(merged_preds,
                                df.rename(columns={'prob_1': f"{method}_prob_1"})[['sequence', f"{method}_prob_1"]],
                                on=['sequence'], how='outer')
        # Flatten all sequences' per-residue probabilities and labels into single vectors.
        prob_1 = [float(x) for x in ",".join(df['prob_1'].tolist()).split(",")]
        residue_labels = [int(x) for x in "".join(df['sequence'].apply(lambda s: seq_label_dict[s]).tolist())]
        sequences = "".join(df['sequence'].tolist())
        assert len(prob_1) == len(residue_labels) == len(sequences)
        fpr, tpr, _ = roc_curve(residue_labels, prob_1)
        tpr_fpr_frames.append(pd.DataFrame(data={
            'model': [method]*len(fpr),
            'fpr': fpr, 'tpr': tpr
        }))
        roc_auc = auc(fpr, tpr)
        if method == model_alias:
            # Report the AUROC stored alongside the model's results instead of the recomputed one.
            path_to_og_metrics = path_to_results_of_interest.rsplit('/', 1)[0] + '/caid_hyperparam_screen_test_metrics.csv'
            og_metrics = pd.read_csv(path_to_og_metrics)
            roc_auc = og_metrics['AUROC'][0]
        roc_data.append((method, fpr, tpr, roc_auc))
    if tpr_fpr_frames:
        merged_tpr_fpr = pd.concat(tpr_fpr_frames)
    else:
        merged_tpr_fpr = pd.DataFrame(columns=['model', 'fpr', 'tpr'])

    # Save the merged predictions and ROC points as source data.
    merged_preds['labels'] = merged_preds['sequence'].apply(lambda x: seq_label_dict[x])
    merged_preds['labels'] = merged_preds['labels'].apply(lambda x: ",".join([str(y) for y in x]))
    merged_preds['ids'] = merged_preds['sequence'].apply(lambda x: seq_ids_dict[x])
    merged_preds.drop(columns=['sequence']).to_csv(f"{results_dir}/CAID_prediction_source_data.csv", index=False)
    merged_tpr_fpr.to_csv(f"{results_dir}/CAID_fpr_tpr_source_data.csv", index=False)

    roc_data = sorted(roc_data, key=lambda x: x[3], reverse=True)

    # Legend text per method, optionally prefixed with its CAID-2 ranking.
    display_labels = {method: method for method in method_results}
    if with_rankings:
        for method in display_labels:
            if method in caid2_model_rankings:
                display_labels[method] = f"{caid2_model_rankings[method]}. {method}"

    # ESM (when given) is dashed, the model of interest is thick, the rest are faint.
    for method, fpr, tpr, roc_auc in roc_data:
        if method == 'ESM-2-650M' and path_to_esm_results is not None:
            plt.plot(fpr, tpr, color=color_map[method], lw=2, linestyle='--', label=f'{display_labels[method]} ({roc_auc:0.3f})')
        elif method == model_alias:
            plt.plot(fpr, tpr, color=color_map[method], lw=2, label=f'{display_labels[method]} ({roc_auc:0.3f})')
        else:
            plt.plot(fpr, tpr, color=color_map[method], lw=1, alpha=0.7, label=f'{display_labels[method]} ({roc_auc:0.3f})')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate', fontsize=22)
    plt.ylabel('True Positive Rate', fontsize=22)
    plt.title('CAID2 Disorder NOX Dataset: ROC Curve', fontsize=22)

    # Legend outside the axes; bold the model of interest (and ESM when compared).
    handles, legend_labels = plt.gca().get_legend_handles_labels()
    legend = plt.legend(handles, legend_labels, loc="center left", bbox_to_anchor=(1.1, 0.5), fontsize=16)
    for text in legend.get_texts():
        if model_alias in text.get_text():
            text.set_fontweight('bold')
        elif (path_to_esm_results is not None) and "ESM-2-650M" in text.get_text():
            text.set_fontweight('bold')
    plt.tight_layout()
    figpath = f'{results_dir}/CAID2_{model_alias}_AUROC_curve.png'
    if path_to_esm_results is not None:
        figpath = f'{results_dir}/CAID2_{model_alias}_with_ESM_AUROC_curve.png'
    plt.savefig(figpath)
def plot_disorder_content_scatter(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_scatter.png'):
    """Scatter sequence length against disorder fraction for the train, test and fusion sets, using TRUE labels.

    Each element of the three label collections is a per-residue digit string
    such as '11110000' ('1' marks a disordered residue). A sequence's disorder
    fraction is the sum of its digits divided by its length. The figure is
    written to ``savepath``.
    """
    def summarize(label_strings):
        # Returns (lengths, disorder fractions) for one dataset, in input order.
        sizes, fractions = [], []
        for label_string in label_strings:
            digits = [int(ch) for ch in label_string]
            sizes.append(len(digits))
            fractions.append(sum(digits) / len(digits))
        return sizes, fractions

    train_lengths, train_frac_disorder = summarize(train_labels)
    test_lengths, test_frac_disorder = summarize(test_labels)
    benchmark_lengths, benchmark_frac_disorder = summarize(benchmark_labels)

    set_font()
    # (x, y, colour, legend label), drawn in this order.
    point_sets = [
        (train_lengths, train_frac_disorder, '#0072B2', 'Train'),
        (test_lengths, test_frac_disorder, '#E69F00', 'Test'),
        (benchmark_lengths, benchmark_frac_disorder, 'purple', 'Fusion'),
    ]
    fig, ax = plt.subplots(figsize=(10, 6))
    for xs, ys, colour, legend_name in point_sets:
        ax.scatter(xs, ys, color=colour, label=legend_name, alpha=0.7)

    ax.set_xlabel('Length')
    ax.set_ylabel('Fraction of Disorder')
    ax.set_title('Length vs. Fraction of Disorder for Train, Test, and Benchmark Datasets')
    ax.legend()
    plt.tight_layout()
    plt.savefig(savepath)
def plot_disorder_content_hist(labels, ids, title="data", color="black", savepath='splits/disorder_content_histograms.png'):
    """Plot a histogram of per-sequence percent disorder for ONE dataset (TRUE labels) and save it.

    Each element of ``labels`` is a per-residue digit string such as
    '11110000' ('1' marks a disordered residue). Its percent disorder is
    100 * sum(digits) / length. Dashed and solid vertical lines mark the mean
    and median. The per-sequence values (rounded to 3 decimals) are also
    written to '<savepath without extension>_source_data.csv'.

    Args:
        labels: iterable of per-residue label strings.
        ids: sequence identifiers, aligned with ``labels``.
        title: plot title.
        color: histogram bar colour.
        savepath: output image path.
    """
    set_font()
    frac_disorder = []
    for vec in labels:
        veclist = [int(x) for x in vec]
        frac_disorder.append(100*sum(veclist)/len(veclist))  # stored as a percent

    source_data = pd.DataFrame(data={
        'ID': ids,
        'Percent_Disordered': frac_disorder
    })
    source_data['Percent_Disordered'] = source_data['Percent_Disordered'].round(3)
    # Build the CSV name from the extension-less path. str.replace(".png", ...)
    # left non-.png paths unchanged (so the figure later overwrote the CSV) and
    # also rewrote any '.png' appearing in directory names.
    source_data_path = os.path.splitext(savepath)[0] + "_source_data.csv"
    source_data.to_csv(source_data_path, index=False)

    fig, ax = plt.subplots(1, 1, figsize=(20, 12))
    title_fontsize = 70
    axislabel_fontsize = 70
    tick_fontsize = 50
    ax.hist(frac_disorder, bins=20, color=color, alpha=0.7)
    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlabel('% Disordered', fontsize=axislabel_fontsize)
    ax.set_ylabel('Count', fontsize=axislabel_fontsize)
    ax.grid(True)
    ax.set_axisbelow(True)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    # Mark the mean (dashed) and median (solid) percent disorder.
    mean_coverage = np.mean(frac_disorder)
    median_coverage = np.median(frac_disorder)
    ax.axvline(mean_coverage, color='black', linestyle='--', linewidth=2, label=f'Mean: {mean_coverage:.1f}%')
    ax.axvline(median_coverage, color='black', linestyle='-', linewidth=2, label=f'Median: {median_coverage:.1f}%')
    ax.legend(fontsize=50, title_fontsize=50)
    plt.tight_layout()
    plt.savefig(savepath)
def plot_group_disorder_content_hist(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_histograms.png',orient='horizontal'):
    """Plot side-by-side disorder-fraction histograms for the train, test and fusion sets (TRUE labels).

    Each element of the label collections is a per-residue digit string such
    as '11110000' ('1' marks a disordered residue). In the horizontal (1x3)
    layout every panel gets an x label and only the first a y label. In the
    vertical (3x1) layout every panel gets a y label and only the bottom one
    an x label.

    Args:
        train_labels, test_labels, benchmark_labels: iterables of label strings.
        savepath: output image path.
        orient: 'horizontal' or 'vertical'.

    Raises:
        ValueError: if ``orient`` is not 'horizontal' or 'vertical'
            (previously this crashed later with NameError on ``axes``).
    """
    if orient not in ('horizontal', 'vertical'):
        raise ValueError(f"orient must be 'horizontal' or 'vertical', got {orient!r}")

    def disorder_fractions(label_vectors):
        # Fraction of '1' residues per sequence, in input order.
        fractions = []
        for vec in label_vectors:
            veclist = [int(x) for x in vec]
            fractions.append(sum(veclist)/len(veclist))
        return fractions

    # (fractions, panel title, colour), in panel order.
    panels = [
        (disorder_fractions(train_labels), 'CAID2 Train', '#0072B2'),
        (disorder_fractions(test_labels), 'CAID2 Test', '#E69F00'),
        (disorder_fractions(benchmark_labels), 'Fusion Oncoproteins', 'mediumpurple'),
    ]

    set_font()
    if orient == 'horizontal':
        fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=False)
    else:
        fig, axes = plt.subplots(3, 1, figsize=(5, 15), sharey=False)

    title_fontsize = 26
    axislabel_fontsize = 26
    tick_fontsize = 16
    last_panel = len(panels) - 1
    for i, (ax, (fractions, panel_title, panel_color)) in enumerate(zip(axes, panels)):
        ax.hist(fractions, bins=20, color=panel_color, alpha=0.7)
        ax.set_title(panel_title, fontsize=title_fontsize)
        if orient == 'horizontal' or i == last_panel:
            ax.set_xlabel('Fraction of Disorder', fontsize=axislabel_fontsize)
        if orient == 'vertical' or i == 0:
            ax.set_ylabel('Frequency', fontsize=axislabel_fontsize)
        ax.grid(True)
        ax.set_axisbelow(True)
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    plt.tight_layout()
    plt.savefig(savepath)
def categorize_plddt(values):
    """Count pLDDT scores per confidence band.

    Bands (in key order): '<= 50', '50-70' for (50, 70], '70-90' for (70, 90]
    and '> 90'. A value that satisfies none of the comparisons (e.g. NaN) is
    not counted anywhere. ``values`` is traversed once per band.
    """
    bands = (
        ("<= 50", lambda v: v <= 50),
        ("50-70", lambda v: 50 < v <= 70),
        ("70-90", lambda v: 70 < v <= 90),
        ("> 90", lambda v: v > 90),
    )
    return {band: sum(1 for v in values if in_band(v)) for band, in_band in bands}
def plot_fusion_sequence_pLDDT_left_to_right(fusion_structure_data, fusiongene, save_path=''):
    """Plot each amino acid of each structure of one fusion as a bar coloured by its pLDDT, and save it.

    There is one row per structure of ``fusiongene``, sorted by
    'Fusion_Length'. To the right of each row, a circle is coloured by the
    structure's 'Fusion_pLDDT'. A black box spans the last
    'top_hg_UniProt_fus_indices' entry and the first
    'top_tg_UniProt_fus_indices' entry (the breakpoint region) when at least
    one of them is present. The plotted data are also written to
    '{save_path}/plddt_sequence_{fusiongene}_source_data.csv'.

    Args:
        fusion_structure_data: table with 'FusionGene', 'seq_id',
            'Fusion_Length', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs'
            (comma-separated floats) and the 'top_hg_*' / 'top_tg_*'
            alignment columns. The '*_fus_indices' columns hold
            comma-separated ints or a missing value.
        fusiongene: 'FusionGene' value to plot (e.g. 'EWSR1::FLI1').
        save_path: output directory.
    """
    set_font()
    category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"}

    def get_color(pLDDT):
        # Colour of the pLDDT band containing pLDDT.
        if pLDDT > 90:
            return category_colors["> 90"]
        elif pLDDT > 70:
            return category_colors["70-90"]
        elif pLDDT > 50:
            return category_colors["50-70"]
        else:
            return category_colors["<= 50"]

    def parse_indices(value):
        # Comma-separated ints -> list. Anything else (NaN, including a numpy
        # float64 when the whole column is missing) means "no alignment".
        if isinstance(value, str):
            return [int(x) for x in value.split(',')]
        return None

    df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy()
    df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')])
    df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True)
    n_rows = len(df_of_interest)

    # Create exactly one figure; the original first opened a 10x6 figure and then
    # a second one for <3 rows, leaving an empty figure that plt.show() displayed.
    fig, ax = plt.subplots(figsize=(10, 2) if n_rows < 3 else (10, 6))

    df_of_interest['Fusion_AA_colors'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [get_color(plddt) for plddt in x])
    df_of_interest['Fusion_pLDDT_color'] = df_of_interest['Fusion_pLDDT'].apply(get_color)
    # Save just the columns needed to reproduce the plot.
    df_of_interest[['FusionGene', 'seq_id', 'Fusion_Length', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs', 'Fusion_AA_colors', 'Fusion_pLDDT_color',
                    'top_hg_UniProtID', 'top_hg_UniProt_isoform', 'top_hg_UniProt_fus_indices',
                    'top_tg_UniProtID', 'top_tg_UniProt_isoform', 'top_tg_UniProt_fus_indices']].to_csv(f"{save_path}/plddt_sequence_{fusiongene}_source_data.csv", index=False)

    # One row of per-residue bars per structure, centred at y = row index.
    for idx, row in df_of_interest.iterrows():
        pLDDT_values = row['Fusion_AA_pLDDTs']
        ax.bar(range(len(pLDDT_values)),
               [0.7] * len(pLDDT_values), color=row['Fusion_AA_colors'], edgecolor='none',
               bottom=idx - 0.7 / 2)

    if n_rows > 10:
        markersize = 10
    elif n_rows > 5:
        markersize = 16
    else:
        markersize = 12
    marker_x = 1.02 * df_of_interest['Fusion_Length'].max()

    for idx, row in df_of_interest.iterrows():
        # Mean-pLDDT marker, read from this row directly. The original looked it up
        # by a '<length>AAs' label, which collides when two structures share a length.
        ax.plot(marker_x, idx, marker='o', color="black", markersize=markersize,
                markerfacecolor=row['Fusion_pLDDT_color'], markeredgewidth=2)
        # Breakpoint box - only when at least one partner has aligned indices.
        hg_indices = parse_indices(row['top_hg_UniProt_fus_indices'])
        tg_indices = parse_indices(row['top_tg_UniProt_fus_indices'])
        print(hg_indices, tg_indices)
        if (hg_indices is not None) and (tg_indices is not None):
            box_start = min(hg_indices[-1], tg_indices[0])
            box_end = max(hg_indices[-1], tg_indices[0])
        elif hg_indices is not None:
            box_start, box_end = hg_indices[-1], hg_indices[-1]
        elif tg_indices is not None:
            box_start, box_end = tg_indices[0], tg_indices[0]
        else:
            # Nothing to mark. Previously this reused the previous row's box
            # (or raised NameError on the first row).
            continue
        print(f"box indices for structure {idx}, fusion gene {fusiongene}", box_start, box_end)
        rect = patches.Rectangle((box_start, idx - 0.7 / 2), box_end - box_start, 0.7, linewidth=2, edgecolor='black', facecolor='none')
        ax.add_patch(rect)

    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.set_ylim(-0.5, n_rows - 0.5)  # reduce white space at top
    ax.set_xlim(left=0)  # bars flush with the left edge
    ax.set_xlabel("Amino Acid Sequence (ordered)", fontsize=14)
    ax.tick_params(axis='x', labelsize=30)
    ax.set_title(f"{fusiongene} pLDDT Distribution by Amino Acid Sequence", fontsize=16)
    fig.tight_layout()

    fusiongene_savename = fusiongene.replace("::", "-")
    fig.savefig(f"{save_path}/plddt_sequence_{fusiongene_savename}.png", dpi=300)
    plt.show()
def plot_favorite_fusion_pLDDT_distribution(fusion_structure_data, fusiongene, save_path=''):
    """Plot a horizontal stacked bar of pLDDT-band counts for each structure of one fusion, and save it.

    There is one bar per structure of ``fusiongene``, sorted by
    'Fusion_Length'. Each bar stacks the number of residues in the bands
    '<= 50', '50-70', '70-90' and '> 90' (see ``categorize_plddt``). A square
    after each bar is coloured by the structure's 'Fusion_pLDDT' band. The
    figure is saved to '{save_path}/plddt_dist_{fusiongene with "::" -> "-"}.png'.

    Args:
        fusion_structure_data: table with 'FusionGene', 'Fusion_Length',
            'Fusion_pLDDT' and 'Fusion_AA_pLDDTs' (comma-separated floats).
        fusiongene: 'FusionGene' value to plot.
        save_path: output directory.
    """
    set_font()
    df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy()
    df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')])
    df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True)
    n_rows = len(df_of_interest)

    categories = ["<= 50", "50-70", "70-90", "> 90"]
    category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"}

    def band_color(pLDDT):
        # Colour of the pLDDT band containing pLDDT.
        if pLDDT > 90:
            return category_colors["> 90"]
        elif pLDDT > 70:
            return category_colors["70-90"]
        elif pLDDT > 50:
            return category_colors["50-70"]
        return category_colors["<= 50"]

    # Band counts per structure, in row order. Rows are addressed by position;
    # the original keyed dicts and the y axis by a '<length>AAs' label, which
    # silently merged structures of equal length.
    per_structure = [categorize_plddt(values) for values in df_of_interest['Fusion_AA_pLDDTs']]
    counts = {cat: [structure_counts[cat] for structure_counts in per_structure] for cat in categories}
    positions = np.arange(n_rows)

    # Create exactly one figure (the original opened a throwaway 10x6 one first for <3 rows).
    fig, ax = plt.subplots(figsize=(10, 2) if n_rows < 3 else (10, 6))

    bottom = np.zeros(n_rows)
    for cat in categories:
        ax.barh(positions, counts[cat], label=cat, color=category_colors[cat], left=bottom)
        bottom += counts[cat]  # next segment starts where this one ends

    if n_rows > 10:
        markersize = 10
    elif n_rows > 5:
        markersize = 16
    else:
        markersize = 12
    marker_offset = .02 * df_of_interest['Fusion_Length'].max()
    # Square just past the end of each bar, coloured by the mean pLDDT band.
    for idx, avg_plddt_value in enumerate(df_of_interest['Fusion_pLDDT']):
        ax.plot(bottom[idx] + marker_offset, idx, marker='s', color="black", markersize=markersize,
                markerfacecolor=band_color(avg_plddt_value), markeredgewidth=2)

    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', left=False, labelleft=False)
    fig.tight_layout()
    fusiongene_savename = fusiongene.replace("::", "-")
    fig.savefig(f"{save_path}/plddt_dist_{fusiongene_savename}.png", dpi=300)
def make_all_favorite_fusion_pLDDT_plots(favorite_fusions, left_to_right=True,
                                         save_path='processed_data/figures/fusion_disorder'):
    """Make a pLDDT figure for every fusion gene in ``favorite_fusions``.

    Loads the cleaned FusionPDB structure table and attaches each fusion's
    ``seq_id``, looked up by amino-acid sequence in ``fuson_db.csv``. It then
    left-joins the SwissProt top-alignment table on ``seq_id`` and hands the
    merged table to the chosen plotting helper once per requested fusion.

    Args:
        favorite_fusions: iterable of fusion-gene names (e.g. "EWSR1::FLI1").
        left_to_right: if True, plot with
            ``plot_fusion_sequence_pLDDT_left_to_right``; otherwise use
            ``plot_favorite_fusion_pLDDT_distribution``.
        save_path: output directory passed to the plotting helper. It is
            created if it does not already exist.
    """
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fuson_db = pd.read_csv("../../data/fuson_db.csv")

    # Sequences absent from fuson_db get a NaN seq_id and therefore no alignment columns.
    seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    fusion_structure_data['seq_id'] = fusion_structure_data['Fusion_Seq'].map(seq_id_dict)
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        swissprot_top_alignments,
        on="seq_id",
        how="left",
    )

    # The choice of plot style does not depend on the fusion, so decide once.
    plot_fn = (plot_fusion_sequence_pLDDT_left_to_right if left_to_right
               else plot_favorite_fusion_pLDDT_distribution)
    os.makedirs(save_path, exist_ok=True)
    for fusiongene in favorite_fusions:
        plot_fn(fusion_structure_data, fusiongene, save_path=save_path)
def _split_uniprot_ids(value):
    """Split a comma-separated UniProt ID string into a list; pass non-strings (NaN) through."""
    return value.split(',') if isinstance(value, str) else value


def _plddt_string_to_label(plddt_string):
    """Turn a comma-separated per-residue pLDDT string into a per-residue label string.

    Each value is parsed with ``float`` and mapped to one character by
    ``apply_plddt_thresh``. The result has one character per residue.
    """
    return ''.join(apply_plddt_thresh(float(y)) for y in plddt_string.split(','))


def prep_data_for_ht_disorder_comparison():
    """Build a table pairing fusion oncoproteins with their head and tail proteins.

    Steps:
      1. Right-join ``fuson_ht_db`` onto the FusionPDB structure table by
         sequence, so every fusion with structure data is kept.
      2. Explode the comma-separated ``hgUniProt`` / ``tgUniProt`` columns so
         that each row is one (fusion, head ID, tail ID) combination.
      3. Keep only combinations where both IDs appear in the head/tail
         structure table.
      4. Inner-join that table twice: once for the head (``hg_*`` columns)
         and once for the tail (``tg_*`` columns).
      5. Drop rows missing per-residue pLDDTs for the head, tail or fusion.
      6. Add ``hg_label``, ``tg_label`` and ``fusion_label``. These are
         per-residue '0'/'1' strings produced by ``apply_plddt_thresh``.

    Returns:
        pandas.DataFrame with a fresh RangeIndex.
    """
    ht_structure_data = pd.read_csv('processed_data/fusionpdb/heads_tails_structural_data.csv')
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    all_hts_with_structures = ht_structure_data['UniProtID'].unique().tolist()
    fuson_ht_db = pd.read_csv('../../data/blast/fuson_ht_db.csv')[['seq_id', 'aa_seq', 'fusiongenes', 'hgUniProt', 'tgUniProt']]

    merge = pd.merge(
        fuson_ht_db.rename(columns={'aa_seq': 'Fusion_Seq'}),
        fusion_structure_data[['FusionGID', 'Fusion_Seq', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs']],
        on='Fusion_Seq',
        how='right',
    )

    # The right merge leaves NaN head/tail IDs for fusions not in fuson_ht_db.
    # Calling str.split on those would raise, so NaN is passed through here and
    # then removed by the isin() filter below.
    merge['hgUniProt'] = merge['hgUniProt'].apply(_split_uniprot_ids)
    merge['tgUniProt'] = merge['tgUniProt'].apply(_split_uniprot_ids)
    merge = merge.explode('hgUniProt').explode('tgUniProt')
    merge = merge.loc[
        merge['hgUniProt'].isin(all_hts_with_structures)
        & merge['tgUniProt'].isin(all_hts_with_structures)
    ].reset_index(drop=True)

    # Attach head-gene, then tail-gene, structure data under prefixed column names.
    merge = pd.merge(
        merge,
        ht_structure_data.rename(columns={
            'UniProtID': 'hgUniProt',
            'Avg pLDDT': 'hg_pLDDT',
            'All pLDDTs': 'hg_AA_pLDDTs',
            'Seq': 'hg_seq',
        }),
        on='hgUniProt',
        how='inner',
    )
    merge = pd.merge(
        merge,
        ht_structure_data.rename(columns={
            'UniProtID': 'tgUniProt',
            'Avg pLDDT': 'tg_pLDDT',
            'All pLDDTs': 'tg_AA_pLDDTs',
            'Seq': 'tg_seq',
        }),
        on='tgUniProt',
        how='inner',
    )

    # Labels need per-residue pLDDTs for all three proteins.
    merge = merge.loc[
        merge['hg_AA_pLDDTs'].notna()
        & merge['tg_AA_pLDDTs'].notna()
        & merge['Fusion_AA_pLDDTs'].notna()
    ].reset_index(drop=True)

    # Finally, calculate the per-residue labels.
    merge['hg_label'] = merge['hg_AA_pLDDTs'].apply(_plddt_string_to_label)
    merge['tg_label'] = merge['tg_AA_pLDDTs'].apply(_plddt_string_to_label)
    merge['fusion_label'] = merge['Fusion_AA_pLDDTs'].apply(_plddt_string_to_label)
    return merge
def apply_plddt_thresh(y, threshold=68.8):
    """Map one per-residue pLDDT value to a single-character binary label.

    Args:
        y: the pLDDT value, compared numerically against ``threshold``.
        threshold: values strictly below it are labeled '1'. The default,
            68.8, is the cutoff this module has always used.
            NOTE(review): presumably a pLDDT disorder cutoff (low pLDDT ->
            '1' = disordered). Confirm where 68.8 comes from.

    Returns:
        '1' if ``y < threshold``, else '0'. NaN compares False, so it maps to '0'.
    """
    return '1' if y < threshold else '0'
def plot_fusion_stats_boxplots(data, save_path="fusion_disorder_boxplots.png"):
    """Draw one pink box plot per column of ``data`` and save it as an image.

    Args:
        data: DataFrame-like object. Each column becomes one box labeled with
            the column name, and NaNs are dropped per column before plotting.
            The title reports ``n=len(data)``.
        save_path: output image path, written at 300 dpi.
    """
    set_font()
    fig = plt.figure(figsize=(6, 5))
    # for ones that are 100% disordered, AUROC was NaN, so drop these
    box = plt.boxplot([data[col].dropna() for col in data.columns], labels=data.columns, patch_artist=True)
    for patch in box['boxes']:
        patch.set_facecolor('#ff68b4')
        patch.set_edgecolor('#ff68b4')
    for median in box['medians']:
        median.set_color('black')

    plt.title(f"Per-Residue Disorder (n={len(data)})", fontsize=22)
    plt.xticks(rotation=20, fontsize=22)
    plt.yticks(fontsize=22)
    plt.tight_layout()

    # Save before show(). With an interactive backend the figure is gone once
    # the window is closed, and a later savefig() writes a blank image.
    plt.savefig(save_path, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def plot_fusion_frac_disorder_r2(actual_values, predicted_values, save_path="fusion_pred_disorder_r2.png"):
    """Scatter predicted against reference disorder fractions, annotate R^2, and save.

    Args:
        actual_values: reference values, plotted on the x-axis (labeled
            AlphaFold-pLDDT).
        predicted_values: predictions, plotted on the y-axis (labeled
            FusOn-pLM-Diso). Must be the same length as ``actual_values``.
        save_path: output image path, written at 300 dpi.

    The R^2 shown is ``sklearn.metrics.r2_score(actual_values, predicted_values)``.
    A dashed y = x line spans the range of ``actual_values``.
    """
    set_font()
    fig = plt.figure(figsize=(6, 6))
    r2 = r2_score(actual_values, predicted_values)

    lo, hi = min(actual_values), max(actual_values)
    plt.scatter(actual_values, predicted_values, alpha=0.5, label="Predictions", color="#ff68b4")
    plt.plot([lo, hi], [lo, hi], 'k--', label='Ideal Fit')
    # Position is in data coordinates. NOTE(review): assumes values are
    # percentages (roughly 0-100), consistent with the "% Disordered" title.
    plt.text(0, 92, f"$R^2$={r2:.2f}", fontsize=32)

    plt.xlabel('AlphaFold-pLDDT', size=32)
    plt.ylabel('FusOn-pLM-Diso', size=32)
    plt.title(f"% Disordered (n={len(actual_values)})", size=32)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(prop={'size': 16})
    plt.tight_layout()

    # Save before show(). With an interactive backend, saving after the
    # window closes produces a blank image.
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close(fig)
def main():
    """Regenerate the disorder-benchmark figures. All paths are relative to the working directory.

    Steps:
      1. Call ``make_auroc_curve`` with the FusOn-pLM and ESM-2 test-set
         probabilities. Labels and IDs come from ``splits/test_df.csv``, and
         ``results_dir`` is "results/final".
      2. Plot disorder-content histograms for the CAID2 test split and for
         the head, tail and fusion proteins returned by
         ``prep_data_for_ht_disorder_comparison`` (each deduplicated by
         sequence).
      3. Make left-to-right pLDDT plots for four well-known fusion oncoproteins.
    """
    set_font()
    output_dir = "results/final"  # use "results/test" for a test run

    # Test-split sequence -> ID and sequence -> label lookups for the AUROC curve.
    test_df = pd.read_csv('splits/test_df.csv')
    seq_ids_dict = dict(zip(test_df['Sequence'], test_df['IDs']))
    seq_label_dict = dict(zip(test_df['Sequence'], test_df['Label']))
    make_auroc_curve(results_dir=output_dir,
                     seq_label_dict=seq_label_dict,
                     seq_ids_dict=seq_ids_dict,
                     path_to_results_of_interest="trained_models/fuson_plm/best/caid_hyperparam_screen_test_probs.csv",
                     model_alias="FusOn-pLM",
                     path_to_esm_results="trained_models/esm2_t33_650M_UR50D/best/caid_hyperparam_screen_test_probs.csv",
                     with_rankings=True)

    # CAID2 test split.
    caid2_test_data = pd.read_csv("splits/splits.csv")
    caid2_test_data = caid2_test_data.loc[caid2_test_data['Split'] == 'Test']

    # Fusions, heads, and tails, each deduplicated by its own sequence.
    fusion_ht_data = prep_data_for_ht_disorder_comparison()
    os.makedirs("processed_data/figures", exist_ok=True)
    head_data = fusion_ht_data.drop_duplicates(['hg_seq']).reset_index(drop=True)
    tail_data = fusion_ht_data.drop_duplicates(['tg_seq']).reset_index(drop=True)
    fusion_data = fusion_ht_data.drop_duplicates(['Fusion_Seq']).reset_index(drop=True)

    plt.rc('text', usetex=False)
    math_part = r"$n$"
    hist_dir = "processed_data/figures/histograms"
    os.makedirs(hist_dir, exist_ok=True)
    # (title prefix, labels, ids, bar color, output file name)
    histogram_specs = [
        ("CAID2 Disorder-NOX", caid2_test_data['Label'].tolist(), caid2_test_data['IDs'].tolist(),
         "black", "disorder_nox_histogram.png"),
        ("Head Proteins", head_data['hg_label'].tolist(), head_data['hgUniProt'].tolist(),
         "#df8385", "heads_histogram.png"),
        ("Tail Proteins", tail_data['tg_label'].tolist(), tail_data['tgUniProt'].tolist(),
         "#6ea4da", "tails_histogram.png"),
        ("Fusion Oncoproteins", fusion_data['fusion_label'].tolist(), fusion_data['seq_id'].tolist(),
         "mediumpurple", "fusions_histogram.png"),
    ]
    for title_prefix, labels, ids, color, filename in histogram_specs:
        plot_disorder_content_hist(labels, ids,
                                   title=f"{title_prefix} ({math_part}={len(labels):,})",
                                   color=color,
                                   savepath=f"{hist_dir}/{filename}")

    os.makedirs("processed_data/figures/fusion_disorder", exist_ok=True)
    make_all_favorite_fusion_pLDDT_plots([
        "EWSR1::FLI1",
        "PAX3::FOXO1",
        "EML4::ALK",
        "SS18::SSX1"],
        left_to_right=True)
# Run the figure-generation pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()