import matplotlib.pyplot as plt |
import seaborn as sns |
import pandas as pd |
import numpy as np |
import os |
import matplotlib.colors as mcolors |
import matplotlib.patches as mpatches |
from matplotlib import font_manager |
import matplotlib.patches as patches |
from sklearn.metrics import roc_curve, auc, r2_score |
from fuson_plm.utils.visualizing import set_font |
# Hard-coded reference entries for the CAID2 comparison, kept as one aligned
# table so a model's name, AUROC and ranking number cannot drift apart.
# Each tuple is (model name, AUROC, ranking number).
_CAID2_REFERENCE_TABLE = [
    ('Dispredict3', 0.838, 1),
    ('flDPnn2', 0.836, 2),
    ('flDPnn', 0.833, 3),
    ('flDPlr', 0.827, 4),
    ('flDPlr2', 0.821, 5),
    ('DisoPred', 0.821, 6),
    ('IDP-Fusion', 0.818, 7),
    ('ESpritz-D', 0.802, 8),
    ('DeepIDP-2L', 0.800, 9),
    ('disomine', 0.797, 10),
    ('DISOPRED3-diso', 0.692, 35),
    ('IUPred3', 0.755, 21),
    ('AlphaFold-rsa', 0.747, 24),
    ('AlphaFold-pLDDT', 0.695, 34),
]

# One row per reference model; columns mirror the ones make_heatmap reads
# ('Model Name', 'AUROC', 'Model Type', 'Model Epoch').
caid2_winners = pd.DataFrame(data={
    'Model Name': [entry[0] for entry in _CAID2_REFERENCE_TABLE],
    'AUROC': [entry[1] for entry in _CAID2_REFERENCE_TABLE],
})
caid2_winners['Model Type'] = 'caid2_competition'
caid2_winners['Model Epoch'] = np.nan  # reference models carry no epoch

# Model name -> ranking number, used to prefix legend entries in make_auroc_curve.
caid2_model_rankings = {entry[0]: entry[2] for entry in _CAID2_REFERENCE_TABLE}
def lengthen_model_name(row):
    """
    Build the full display name for one results row.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Type', 'Model Name'
            and 'Model Epoch' entries.

    Returns:
        The model name unchanged when it contains 'esm' or 'puncta', or when
        the model type is 'caid2_competition'; otherwise '<name>_e<epoch>'.
    """
    kind, base_name, epoch_tag = row['Model Type'], row['Model Name'], row['Model Epoch']
    keep_unchanged = (
        'esm' in base_name
        or 'puncta' in base_name
        or kind == 'caid2_competition'
    )
    return base_name if keep_unchanged else f'{base_name}_e{epoch_tag}'
def shorten_model_name(row):
    """
    Build a compact display name for one results row.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Type', 'Model Name'
            and 'Model Epoch' entries.

    Returns:
        - 'ESM-2-650M' when the model name contains 'esm'.
        - The model name unchanged for 'caid2_competition' rows.
        - '<snp|uni>_<layers>L_<tag>_mask<rate>_<suffix>_e<epoch>' when the name
          contains 'snp_' or 'uniform_' plus the substrings '<layers>layers_<tag>_'
          and 'mask<rate>-<suffix>' that the parser slices on.
        - The model name unchanged when it does not follow that pattern.
          (Previously such names raised UnboundLocalError / IndexError.)
    """
    model_type = row['Model Type']
    name = row['Model Name']
    epoch = row['Model Epoch']

    if 'esm' in name:
        return 'ESM-2-650M'
    if model_type == 'caid2_competition':
        return name

    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Not a hyperparameter-encoded name: nothing to shorten.
        return name

    try:
        layers = name.split('layers')[0].split('_')[-1]
        kqv_tag = name.split('layers_')[1].split('_')[0]
        # Text after the first 'mask' is '<rate>-<suffix>'.
        maskrate, dt = name.split('mask')[1].split('-', 1)
    except (IndexError, ValueError):
        # A required marker ('layers_', 'mask' or the '-' after the mask rate) is missing.
        return name

    return f'{prob_type}_{layers}L_{kqv_tag}_mask{maskrate}_{dt}_e{epoch}'
def make_heatmap(df, results_dir='.', gold_standard_model_name="esm2_t33_650M_UR50D", split="test", thresh=None, ax=None):
    """
    Draw a one-column AUROC heatmap of every model relative to a gold-standard model
    and save it as a PNG in `results_dir`.

    Cell colour encodes (model AUROC - gold-standard AUROC); the cell text shows the
    raw AUROC, in bold when it exceeds the gold standard. The gold-standard tick label
    is boxed in gold and every model beating it is boxed in red.

    Args:
        df: DataFrame with 'Model Type', 'Model Name', 'Model Epoch' (numeric or NaN)
            and 'AUROC' columns. It is not modified.
        results_dir: directory the PNG is written to.
        gold_standard_model_name: full model name (see lengthen_model_name) of the
            reference row; exactly one row must match.
        split: 'testing', 'training' or 'benchmark'. The short forms 'test' and
            'train' are accepted as aliases (the default 'test' used to raise
            KeyError). For every split except 'benchmark' the hard-coded
            caid2_winners rows are appended to `df`.
        thresh: accepted for interface compatibility; not used by this function.
        ax: optional matplotlib Axes to draw on. When omitted a new figure is created.

    Raises:
        KeyError: if `split` is not one of the recognised values.
    """
    set_font()
    columns_to_compare = ['AUROC']

    split_fname_dict = {
        "testing": "CAID2_test",
        "training": "CAID2_train",
        "benchmark": "FusionPDB_pLDDT_disorder"
    }
    split_title_dict = {
        "testing": "CAID-2 Disorder Prediction",
        "training": "CAID-2 Disorder Prediction",
        "benchmark": "FusionPDB_pLDDT Disorder Prediction"
    }
    # Normalise the short aliases and resolve title/filename up front so that an
    # unknown split fails before any plotting work is done.
    split = {"test": "testing", "train": "training"}.get(split, split)
    plot_title = split_title_dict[split]
    file_prefix = split_fname_dict[split]

    if split != "benchmark":
        df = pd.concat([df, caid2_winners])
    else:
        # Work on a copy so the columns added below never leak into the caller's frame.
        df = df.copy()

    # Epoch becomes a string ('' for NaN) so it can be embedded in the model names.
    df['Model Epoch'] = df['Model Epoch'].apply(lambda x: str(int(x)) if not np.isnan(x) else '')
    df['Short Model Name'] = df.apply(shorten_model_name, axis=1)
    df['Full Model Name'] = df.apply(lengthen_model_name, axis=1)

    is_gold = df['Full Model Name'] == gold_standard_model_name
    gold_standard = df[is_gold].reset_index(drop=True).iloc[0]
    gold_standard_short_model_name = df[is_gold]['Short Model Name'].item()

    # Gold standard first, then grouped by model type, best AUROC first within a group.
    heatmap_data = df[['Model Type', 'Short Model Name', 'Full Model Name'] + columns_to_compare].copy()
    heatmap_data['is_gold_standard'] = (heatmap_data['Full Model Name'] == gold_standard_model_name).astype(int)
    heatmap_data = (
        heatmap_data
        .sort_values(by=['is_gold_standard', 'Model Type', 'AUROC'], ascending=[False, True, False])
        .reset_index(drop=True)
        .drop(columns=['is_gold_standard'])
    )

    # Keep the raw values for the cell annotations; the heatmap itself shows differences.
    original_values = heatmap_data[columns_to_compare].copy()
    for col in columns_to_compare:
        heatmap_data[col] = heatmap_data[col] - gold_standard[col]

    cmap = sns.color_palette("coolwarm", as_cmap=True)
    if ax is None:
        # Grow the figure by 0.25 in per row beyond 26 rows.
        tallsize = max(8, 8 + .25 * (len(heatmap_data) - 26))
        fig, ax = plt.subplots(1, 1, figsize=(8, tallsize), dpi=300)
    else:
        fig = ax.figure

    hm = sns.heatmap(
        heatmap_data.set_index('Short Model Name').drop(columns=['Model Type', 'Full Model Name']),
        annot=False, fmt='', cmap=cmap, center=0,
        cbar_kws={'label': 'Difference from Gold Standard'},
        ax=ax,
    )
    ax.set_yticklabels(heatmap_data['Short Model Name'], rotation=0, fontsize=12)
    ax.set_ylabel("Short Model Name", labelpad=20)

    # Annotate each cell with the raw metric; bold when it beats the gold standard.
    for i in range(original_values.shape[0]):
        for j in range(original_values.shape[1]):
            value = original_values.iloc[i, j]
            emphasis = {'fontweight': 'bold'} if value > gold_standard[columns_to_compare[j]] else {}
            ax.text(j + 0.5, i + 0.5, f'{value:.3f}', ha='center', va='center', color='black', **emphasis)

    # Thick white rule wherever the model type changes between consecutive rows.
    model_type_series = heatmap_data['Model Type'].values
    for i in range(1, len(model_type_series)):
        if model_type_series[i] != model_type_series[i - 1]:
            hm.axhline(i, color='white', linewidth=8)

    # Highlight tick labels: gold for the reference, red for models that beat it.
    tick_labels = ax.get_yticklabels()
    for ytick, model_name in enumerate(heatmap_data['Short Model Name']):
        if model_name == gold_standard_short_model_name:
            tick_labels[ytick].set_bbox(dict(facecolor='gold', alpha=0.5, edgecolor='gold'))
        elif original_values.loc[ytick, 'AUROC'] > gold_standard['AUROC']:
            tick_labels[ytick].set_bbox(dict(facecolor='red', alpha=0.3, edgecolor='red'))

    gold_patch = mpatches.Patch(color='gold', alpha=0.5, label='Gold Standard')
    red_patch = mpatches.Patch(color='red', alpha=0.5, label='Winner')
    ax.legend(handles=[gold_patch, red_patch], loc='best', bbox_to_anchor=(0, 0))

    ax.set_title(plot_title)

    # Put the colorbar label and ticks on the right-hand side.
    cbar = hm.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')
    cbar.ax.yaxis.set_ticks_position('right')
    cbar.set_label('Difference from Gold Standard', rotation=270, labelpad=20)

    fig.tight_layout(rect=[0, 0, 0.95, 1])
    fig.savefig(f"{results_dir}/{file_prefix}_heatmap_vs_{gold_standard_model_name}.png")
def make_benchmark_auroc_curve(results_dir='.', seq_label_dict=None, path_to_results_of_interest='', model_alias=None):
    """
    Plot the ROC curve of a single model on the fusion benchmark and save it as
    '<results_dir>/FusionPDB_pLDDT_disorder_<model_alias>_AUROC_curve.png'.

    Args:
        results_dir: directory the PNG is written to.
        seq_label_dict: mapping sequence -> label string with one digit per residue.
        path_to_results_of_interest: CSV with 'sequence' and 'prob_1' columns, where
            'prob_1' is a comma-separated string of scores. The path must contain
            'trained_models/'; the components after it are parsed for the model
            type, name and (for five-component paths) epoch.
        model_alias: legend/file name for the model. When None it is derived from
            the parsed path via shorten_model_name.
    """
    # Path layout after 'trained_models/': five components carry an explicit
    # name and 'epochN' folder; anything else is treated as '<name>/<hyperparams>/...'.
    path_parts = path_to_results_of_interest.split('trained_models/')[1].split('/')
    parsed_type = path_parts[0]
    parsed_epoch = np.nan
    if len(path_parts) == 5:
        parsed_name = path_parts[1]
        parsed_epoch = path_parts[2].split('epoch')[1]
        parsed_hyperparams = path_parts[3]
    else:
        parsed_name = path_parts[0]
        parsed_hyperparams = path_parts[1]  # parsed for parity with the path layout; not used below

    if model_alias is None:
        parsed_info = pd.DataFrame(data={
            'Model Type': [parsed_type], 'Model Name': [parsed_name], 'Model Epoch': [parsed_epoch]
        })
        model_alias = parsed_info.apply(shorten_model_name, axis=1).iloc[0]

    color_map = {model_alias: 'black'}
    # Drop entries without a usable path.
    method_results = {
        method: csv_path
        for method, csv_path in {model_alias: path_to_results_of_interest}.items()
        if csv_path not in [None, '']
    }

    set_font()
    plt.figure(figsize=(10, 6), dpi=300)

    roc_data = []
    for method, csv_path in method_results.items():
        preds = pd.read_csv(csv_path)
        joined_scores = ",".join(preds['prob_1'].tolist())
        preds['labels'] = preds['sequence'].apply(lambda x: seq_label_dict[x])
        joined_labels = "".join(preds['labels'].tolist())

        # Flatten every sequence into one residue-level vector of scores / labels.
        scores = [float(token) for token in joined_scores.split(",")]
        truth = [int(digit) for digit in list(joined_labels)]
        all_residues = "".join(preds['sequence'].tolist())
        assert len(scores) == len(truth) == len(all_residues)

        fpr, tpr, thresholds = roc_curve(truth, scores)
        roc_data.append((method, fpr, tpr, auc(fpr, tpr)))

    # Best AUROC first so the legend is ordered by performance.
    roc_data = sorted(roc_data, key=lambda item: item[3], reverse=True)

    for method, fpr, tpr, roc_auc in roc_data:
        line_style = dict(lw=2) if method == model_alias else dict(lw=1, alpha=0.7)
        plt.plot(fpr, tpr, color=color_map[method], label=f'{method} ({roc_auc:0.3f})', **line_style)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--')  # chance diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')

    handles, legend_labels = plt.gca().get_legend_handles_labels()
    legend = plt.legend(handles, legend_labels, loc="center left", bbox_to_anchor=(1, 0.5))
    for text in legend.get_texts():
        if model_alias in text.get_text():
            text.set_fontweight('bold')

    plt.tight_layout()
    plt.savefig(f'{results_dir}/FusionPDB_pLDDT_disorder_{model_alias}_AUROC_curve.png')
def make_auroc_curve(results_dir='.', seq_label_dict=None, seq_ids_dict=None, path_to_results_of_interest='', model_alias=None, path_to_esm_results=None, with_rankings=False):
    """
    Plot ROC curves of the model of interest against the hard-coded CAID2 reference
    methods (and optionally ESM-2-650M), and write the plot plus its source data.

    Files written to `results_dir`:
        - CAID_prediction_source_data.csv: per-sequence scores of every method, with
          comma-separated labels and sequence ids.
        - CAID_fpr_tpr_source_data.csv: the FPR/TPR points of every curve.
        - CAID2_<model_alias>_AUROC_curve.png, or
          CAID2_<model_alias>_with_ESM_AUROC_curve.png when ESM results are given.

    Args:
        results_dir: output directory.
        seq_label_dict: mapping sequence -> label string with one digit per residue.
        seq_ids_dict: mapping sequence -> id, written to the prediction source data.
        path_to_results_of_interest: CSV with 'sequence' and 'prob_1' columns for the
            model of interest. Must contain 'trained_models/'. A file named
            'caid_hyperparam_screen_test_metrics.csv' in the same directory supplies
            the AUROC shown in the legend for this model.
        model_alias: legend/file name of the model of interest; derived from the path
            via shorten_model_name when None.
        path_to_esm_results: optional CSV of ESM-2-650M scores, drawn dashed.
        with_rankings: when True, legend entries of methods listed in
            caid2_model_rankings are prefixed with '<rank>. '.
    """
    # Path layout after 'trained_models/': five components carry an explicit
    # name and 'epochN' folder; anything else is treated as '<name>/<hyperparams>/...'.
    path_parts = path_to_results_of_interest.split('trained_models/')[1].split('/')
    parsed_type = path_parts[0]
    parsed_epoch = np.nan
    if len(path_parts) == 5:
        parsed_name = path_parts[1]
        parsed_epoch = path_parts[2].split('epoch')[1]
        parsed_hyperparams = path_parts[3]
    else:
        parsed_name = path_parts[0]
        parsed_hyperparams = path_parts[1]  # parsed for parity with the path layout; not used below

    if model_alias is None:
        parsed_info = pd.DataFrame(data={
            'Model Type': [parsed_type], 'Model Name': [parsed_name], 'Model Epoch': [parsed_epoch]
        })
        model_alias = parsed_info.apply(shorten_model_name, axis=1).iloc[0]

    # Line colour per method; the model of interest is always black.
    color_map = {
        'Dispredict3': '#d62727',
        'flDPnn2': '#ff7f0f',
        'flDPnn': '#1f77b4',
        'flDPlr': '#bcbd21',
        'flDPlr2': '#16becf',
        'DisoPred': '#1f77b4',
        'IDP-Fusion': '#d62727',
        'ESpritz-D': '#8b564c',
        'DeepIDP-2L': '#e377c2',
        'disomine': '#e377c2',
        'DISOPRED3-diso': '#ff892d',
        'IUPred3': '#8b564c',
        'AlphaFold-rsa': '#2ba02b',
        'AlphaFold-pLDDT': '#ff892d',
        model_alias: 'black'
    }
    # Score file per method. Insertion order fixes the column order of the
    # prediction source-data CSV.
    # NOTE(review): the 'flDPlr' entry points at a 'flDPtr_...' file — confirm this is intended.
    method_results = {
        'Dispredict3': 'processed_data/caid2_competition_results/Dispredict3_CAID-2_Disorder_NOX.csv',
        'flDPnn2': 'processed_data/caid2_competition_results/flDPnn2_CAID-2_Disorder_NOX.csv',
        'flDPnn': 'processed_data/caid2_competition_results/flDPnn_CAID-2_Disorder_NOX.csv',
        'flDPlr': 'processed_data/caid2_competition_results/flDPtr_CAID-2_Disorder_NOX.csv',
        'flDPlr2': 'processed_data/caid2_competition_results/flDPlr2_CAID-2_Disorder_NOX.csv',
        'DisoPred': 'processed_data/caid2_competition_results/DisoPred_CAID-2_Disorder_NOX.csv',
        'IDP-Fusion': 'processed_data/caid2_competition_results/IDP-Fusion_CAID-2_Disorder_NOX.csv',
        'ESpritz-D': 'processed_data/caid2_competition_results/ESpritz-D_CAID-2_Disorder_NOX.csv',
        'DeepIDP-2L': 'processed_data/caid2_competition_results/DeepIDP-2L_CAID-2_Disorder_NOX.csv',
        'disomine': 'processed_data/caid2_competition_results/disomine_CAID-2_Disorder_NOX.csv',
        'DISOPRED3-diso': 'processed_data/caid2_competition_results/DISOPRED3-diso_CAID-2_Disorder_NOX.csv',
        'AlphaFold-rsa': 'processed_data/caid2_competition_results/AlphaFold-rsa_CAID-2_Disorder_NOX.csv',
        'AlphaFold-pLDDT': 'processed_data/caid2_competition_results/AlphaFold-disorder_CAID-2_Disorder_NOX.csv',
        'IUPred3': 'processed_data/caid2_competition_results/IUPred3_CAID-2_Disorder_NOX.csv',
        model_alias: path_to_results_of_interest
    }
    esm_given = path_to_esm_results is not None
    if esm_given:
        method_results['ESM-2-650M'] = path_to_esm_results
        color_map['ESM-2-650M'] = 'black'
    # Drop entries without a usable path.
    method_results = {
        method: csv_path for method, csv_path in method_results.items()
        if csv_path not in [None, '']
    }

    set_font()
    plt.figure(figsize=(12, 6), dpi=300)

    merged_preds = pd.DataFrame(data={'sequence': []})
    merged_tpr_fpr = pd.DataFrame(data={'model': [], 'fpr': [], 'tpr': []})
    roc_data = []
    for method, csv_path in method_results.items():
        preds = pd.read_csv(csv_path)

        # Accumulate a wide table: one '<method>_prob_1' column per method.
        score_column = f"{method}_prob_1"
        renamed = preds.rename(columns={'prob_1': score_column})
        merged_preds = pd.merge(merged_preds, renamed[['sequence', score_column]], on=['sequence'], how='outer')

        joined_scores = ",".join(preds['prob_1'].tolist())
        preds['labels'] = preds['sequence'].apply(lambda x: seq_label_dict[x])
        joined_labels = "".join(preds['labels'].tolist())

        # Flatten every sequence into one residue-level vector of scores / labels.
        scores = [float(token) for token in joined_scores.split(",")]
        truth = [int(digit) for digit in list(joined_labels)]
        all_residues = "".join(preds['sequence'].tolist())
        assert len(scores) == len(truth) == len(all_residues)

        fpr, tpr, thresholds = roc_curve(truth, scores)
        curve_points = pd.DataFrame(data={
            'model': [method] * len(fpr),
            'fpr': fpr, 'tpr': tpr
        })
        merged_tpr_fpr = pd.concat([merged_tpr_fpr, curve_points])

        roc_auc = auc(fpr, tpr)
        if method == model_alias:
            # For the model of interest, report the AUROC stored next to its
            # score file instead of the value recomputed here.
            path_to_og_metrics = path_to_results_of_interest.rsplit('/', 1)[0] + '/caid_hyperparam_screen_test_metrics.csv'
            og_metrics = pd.read_csv(path_to_og_metrics)
            roc_auc = og_metrics['AUROC'][0]
        roc_data.append((method, fpr, tpr, roc_auc))

    # Source data behind the figure.
    merged_preds['labels'] = merged_preds['sequence'].apply(lambda x: seq_label_dict[x])
    merged_preds['labels'] = merged_preds['labels'].apply(lambda x: ",".join([str(y) for y in x]))
    merged_preds['ids'] = merged_preds['sequence'].apply(lambda x: seq_ids_dict[x])
    merged_preds.drop(columns={'sequence'}).to_csv(f"{results_dir}/CAID_prediction_source_data.csv", index=False)
    merged_tpr_fpr.to_csv(f"{results_dir}/CAID_fpr_tpr_source_data.csv", index=False)

    # Best AUROC first so the legend is ordered by performance.
    roc_data = sorted(roc_data, key=lambda item: item[3], reverse=True)

    # Legend names, optionally prefixed with the ranking number.
    display_names = {
        method: (
            f"{caid2_model_rankings[method]}. {method}"
            if with_rankings and method in caid2_model_rankings
            else method
        )
        for method in method_results
    }

    for method, fpr, tpr, roc_auc in roc_data:
        if method == 'ESM-2-650M' and esm_given:
            line_style = dict(lw=2, linestyle='--')
        elif method == model_alias:
            line_style = dict(lw=2)
        else:
            line_style = dict(lw=1, alpha=0.7)
        plt.plot(fpr, tpr, color=color_map[method], label=f'{display_names[method]} ({roc_auc:0.3f})', **line_style)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--')  # chance diagonal
    plt.xlabel('False Positive Rate', fontsize=22)
    plt.ylabel('True Positive Rate', fontsize=22)
    plt.title('CAID2 Disorder NOX Dataset: ROC Curve', fontsize=22)

    handles, legend_labels = plt.gca().get_legend_handles_labels()
    legend = plt.legend(handles, legend_labels, loc="center left", bbox_to_anchor=(1.1, 0.5), fontsize=16)
    # Bold the model of interest and, when plotted, ESM-2-650M.
    for text in legend.get_texts():
        entry = text.get_text()
        if (model_alias in entry) or (esm_given and "ESM-2-650M" in entry):
            text.set_fontweight('bold')

    plt.tight_layout()
    if esm_given:
        figpath = f'{results_dir}/CAID2_{model_alias}_with_ESM_AUROC_curve.png'
    else:
        figpath = f'{results_dir}/CAID2_{model_alias}_AUROC_curve.png'
    plt.savefig(figpath)
def plot_disorder_content_scatter(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_scatter.png'):
    """
    Scatter sequence length against fraction of disordered residues for the train,
    test and fusion benchmark sets (from the TRUE labels) and save the figure.

    Each labels argument is an iterable of per-residue label strings such as
    ['11110000', '0001110', ...]; the fraction is the mean of the digits.
    """
    def _lengths_and_fractions(label_vectors):
        # Per sequence: number of residues and share of residues labelled 1.
        sizes, fractions = [], []
        for vec in label_vectors:
            digits = [int(ch) for ch in vec]
            sizes.append(len(digits))
            fractions.append(sum(digits) / len(digits))
        return sizes, fractions

    train_lengths, train_frac_disorder = _lengths_and_fractions(train_labels)
    test_lengths, test_frac_disorder = _lengths_and_fractions(test_labels)
    benchmark_lengths, benchmark_frac_disorder = _lengths_and_fractions(benchmark_labels)

    set_font()

    fig, ax = plt.subplots(figsize=(10, 6))
    series = [
        (train_lengths, train_frac_disorder, '#0072B2', 'Train'),
        (test_lengths, test_frac_disorder, '#E69F00', 'Test'),
        (benchmark_lengths, benchmark_frac_disorder, 'purple', 'Fusion'),
    ]
    for xs, ys, point_color, legend_name in series:
        ax.scatter(xs, ys, color=point_color, label=legend_name, alpha=0.7)

    ax.set_xlabel('Length')
    ax.set_ylabel('Fraction of Disorder')
    ax.set_title('Length vs. Fraction of Disorder for Train, Test, and Benchmark Datasets')
    ax.legend()

    plt.tight_layout()
    plt.savefig(savepath)
def plot_disorder_content_hist(labels, ids, title="data", color="black", savepath='splits/disorder_content_histograms.png'):
    """
    Plot a histogram of percent-disordered residues for one dataset and save it,
    together with a CSV of the underlying values.

    Args:
        labels: iterable of per-residue label strings, e.g. ['11110000', '0001110', ...].
        ids: identifiers aligned with `labels`; written to the 'ID' column of the CSV.
        title: figure title.
        color: bar colour.
        savepath: PNG path. The source data goes to the same path with '.png'
            replaced by '_source_data.csv'.
    """
    set_font()

    # Percent of residues labelled 1, per sequence.
    frac_disorder = []
    for vec in labels:
        digits = [int(ch) for ch in vec]
        frac_disorder.append(100 * sum(digits) / len(digits))

    # Source data (rounded to 3 decimals) is written before plotting.
    source_data = pd.DataFrame(data={
        'ID': ids,
        'Percent_Disordered': frac_disorder
    })
    source_data['Percent_Disordered'] = source_data['Percent_Disordered'].round(3)
    source_data.to_csv(savepath.replace(".png", "_source_data.csv"), index=False)

    title_fontsize, axislabel_fontsize, tick_fontsize = 70, 70, 50
    fig, ax = plt.subplots(1, 1, figsize=(20, 12))

    ax.hist(frac_disorder, bins=20, color=color, alpha=0.7)
    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlabel('% Disordered', fontsize=axislabel_fontsize)
    ax.set_ylabel('Count', fontsize=axislabel_fontsize)
    ax.grid(True)
    ax.set_axisbelow(True)  # keep the grid behind the bars
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    # Dashed line at the mean, solid line at the median.
    mean_coverage = np.mean(frac_disorder)
    median_coverage = np.median(frac_disorder)
    ax.axvline(mean_coverage, color='black', linestyle='--', linewidth=2, label=f'Mean: {mean_coverage:.1f}%')
    ax.axvline(median_coverage, color='black', linestyle='-', linewidth=2, label=f'Median: {median_coverage:.1f}%')
    ax.legend(fontsize=50, title_fontsize=50)

    plt.tight_layout()
    plt.savefig(savepath)
def plot_group_disorder_content_hist(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_histograms.png', orient='horizontal'):
    """
    Plot three histograms of the fraction of disordered residues (train, test and
    fusion benchmark sets, from the TRUE labels) in one figure and save it.

    Args:
        train_labels, test_labels, benchmark_labels: iterables of per-residue label
            strings, e.g. ['11110000', '0001110', ...].
        savepath: PNG path.
        orient: 'horizontal' (1x3 grid; every panel gets an x label, only the first
            a y label) or 'vertical' (3x1 grid; every panel gets a y label, only the
            last an x label).

    Raises:
        ValueError: if `orient` is neither 'horizontal' nor 'vertical'
            (previously this surfaced as a NameError on `axes`).
    """
    if orient not in ('horizontal', 'vertical'):
        raise ValueError(f"orient must be 'horizontal' or 'vertical', got {orient!r}")

    def _fraction_disordered(label_vectors):
        # Share of residues labelled 1, per sequence.
        fractions = []
        for vec in label_vectors:
            digits = [int(ch) for ch in vec]
            fractions.append(sum(digits) / len(digits))
        return fractions

    panels = [
        ('CAID2 Train', _fraction_disordered(train_labels), '#0072B2'),
        ('CAID2 Test', _fraction_disordered(test_labels), '#E69F00'),
        ('Fusion Oncoproteins', _fraction_disordered(benchmark_labels), 'mediumpurple'),
    ]

    set_font()

    if orient == 'horizontal':
        fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=False)
    else:
        fig, axes = plt.subplots(3, 1, figsize=(5, 15), sharey=False)

    title_fontsize = 26
    axislabel_fontsize = 26
    tick_fontsize = 16

    last_panel = len(panels) - 1
    for position, (panel_title, fractions, bar_color) in enumerate(panels):
        axis = axes[position]
        axis.hist(fractions, bins=20, color=bar_color, alpha=0.7)
        axis.set_title(panel_title, fontsize=title_fontsize)
        # Label only the outer edges of the grid: x on every panel of a row layout
        # or on the bottom panel of a column layout; y on every panel of a column
        # layout or on the left-most panel of a row layout.
        if orient == 'horizontal' or position == last_panel:
            axis.set_xlabel('Fraction of Disorder', fontsize=axislabel_fontsize)
        if orient == 'vertical' or position == 0:
            axis.set_ylabel('Frequency', fontsize=axislabel_fontsize)
        axis.grid(True)
        axis.set_axisbelow(True)  # keep the grid behind the bars
        axis.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    plt.tight_layout()
    plt.savefig(savepath)
def categorize_plddt(values):
    """
    Count how many pLDDT values fall into each confidence band.

    Returns a dict with the keys "<= 50", "50-70", "70-90" and "> 90" (in that
    order). Band edges are inclusive on the upper side: 50 counts as "<= 50",
    70 as "50-70" and 90 as "70-90".
    """
    band_tests = [
        ("<= 50", lambda v: v <= 50),
        ("50-70", lambda v: 50 < v <= 70),
        ("70-90", lambda v: 70 < v <= 90),
        ("> 90", lambda v: v > 90),
    ]
    categories = {}
    for band, in_band in band_tests:
        # One pass over `values` per band, in the order listed above.
        categories[band] = len([v for v in values if in_band(v)])
    return categories
def plot_fusion_sequence_pLDDT_left_to_right(fusion_structure_data, fusiongene, save_path=''):
    """
    Plot every structure of one fusion gene as a horizontal strip in which each amino
    acid is a bar coloured by its pLDDT band, save the figure and its source data.

    Per structure (rows sorted by 'Fusion_Length', shortest at the bottom):
        - one coloured bar per residue,
        - a circle right of the longest strip, filled with the band colour of the
          structure's average pLDDT ('Fusion_pLDDT'),
        - a black box between the last head-gene index and the first tail-gene
          index when at least one of them is available.

    Args:
        fusion_structure_data: DataFrame with the columns 'FusionGene', 'seq_id',
            'Fusion_Length', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs' (comma-separated
            string) and the 'top_hg_*' / 'top_tg_*' UniProt columns written to the
            source-data CSV. The '*_fus_indices' columns hold comma-separated
            integers or NaN.
        fusiongene: value of 'FusionGene' to plot.
        save_path: directory for 'plddt_sequence_<fusiongene>_source_data.csv' and
            'plddt_sequence_<fusiongene with "::" replaced by "-">.png'.

    Changes relative to the previous implementation:
        - only one figure is created (a second one used to be opened, and the first
          leaked, when fewer than three structures were plotted);
        - rows without any junction index no longer crash or reuse the previous
          row's box;
        - the marker colour is taken from the row itself rather than from a dict
          keyed on '<length>AAs', which collided for structures of equal length.
    """
    set_font()

    df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy()
    df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')])
    df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True)

    category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"}

    def get_color(pLDDT):
        # Same band edges as categorize_plddt: upper edge inclusive.
        if pLDDT > 90:
            return category_colors["> 90"]
        elif pLDDT > 70:
            return category_colors["70-90"]
        elif pLDDT > 50:
            return category_colors["50-70"]
        else:
            return category_colors["<= 50"]

    def parse_indices(raw):
        # Missing values arrive as float NaN; anything else is a comma-separated string.
        if isinstance(raw, float):
            return None
        return [int(x) for x in raw.split(',')]

    n_structures = len(df_of_interest)
    # Use a flatter canvas when there are only one or two strips.
    fig, ax = plt.subplots(figsize=(10, 2) if n_structures < 3 else (10, 6))

    df_of_interest['Fusion_AA_colors'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [get_color(plddt) for plddt in x])
    df_of_interest['Fusion_pLDDT_color'] = df_of_interest['Fusion_pLDDT'].apply(get_color)
    df_of_interest[['FusionGene','seq_id','Fusion_Length','Fusion_pLDDT','Fusion_AA_pLDDTs','Fusion_AA_colors','Fusion_pLDDT_color',
                    'top_hg_UniProtID','top_hg_UniProt_isoform','top_hg_UniProt_fus_indices',
                    'top_tg_UniProtID','top_tg_UniProt_isoform','top_tg_UniProt_fus_indices']].to_csv(f"{save_path}/plddt_sequence_{fusiongene}_source_data.csv",index=False)

    strip_height = 0.7

    # Pass 1: the per-residue strips.
    for idx, row in df_of_interest.iterrows():
        pLDDT_values = row['Fusion_AA_pLDDTs']
        ax.bar(range(len(pLDDT_values)),
               [strip_height] * len(pLDDT_values), color=row['Fusion_AA_colors'], edgecolor='none',
               bottom=idx - strip_height / 2)

    if n_structures > 10:
        markersize = 10
    elif n_structures > 5:
        markersize = 16
    else:
        markersize = 12
    # All markers share one x position just right of the longest strip.
    marker_x = 1.02 * df_of_interest['Fusion_Length'].max()

    # Pass 2: average-pLDDT markers and junction boxes.
    for idx, row in df_of_interest.iterrows():
        ax.plot(marker_x, idx, marker='o', color="black", markersize=markersize,
                markerfacecolor=row['Fusion_pLDDT_color'], markeredgewidth=2)

        hg_indices = parse_indices(row['top_hg_UniProt_fus_indices'])
        tg_indices = parse_indices(row['top_tg_UniProt_fus_indices'])
        print(hg_indices, tg_indices)

        if (hg_indices is not None) and (tg_indices is not None):
            box_start = min(hg_indices[-1], tg_indices[0])
            box_end = max(hg_indices[-1], tg_indices[0])
        elif hg_indices is not None:
            box_start, box_end = hg_indices[-1], hg_indices[-1]
        elif tg_indices is not None:
            box_start, box_end = tg_indices[0], tg_indices[0]
        else:
            # No junction information for this structure: nothing to outline.
            continue

        print(f"box indices for structure {idx}, fusion gene {fusiongene}", box_start, box_end)
        rect = patches.Rectangle((box_start, idx - strip_height / 2), box_end - box_start, strip_height,
                                 linewidth=2, edgecolor='black', facecolor='none')
        ax.add_patch(rect)

    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.set_ylim(-0.5, n_structures - 0.5)
    ax.set_xlim(left=0)
    ax.set_xlabel("Amino Acid Sequence (ordered)", fontsize=14)
    ax.tick_params(axis='x', labelsize=30)
    plt.title(f"{fusiongene} pLDDT Distribution by Amino Acid Sequence", fontsize=16)
    plt.tight_layout()

    fusiongene_savename = fusiongene.replace("::", "-")
    plt.savefig(f"{save_path}/plddt_sequence_{fusiongene_savename}.png", dpi=300)
    plt.show()
def plot_favorite_fusion_pLDDT_distribution(fusion_structure_data, fusiongene, save_path=''):
    """
    Make a stacked horizontal bar chart of the pLDDT band counts of every structure of
    one fusion gene and save it as
    '<save_path>/plddt_dist_<fusiongene with "::" replaced by "-">.png'.

    Each bar is one structure (labelled '<Fusion_Length>AAs', shortest first) split
    into the four bands returned by categorize_plddt. A square right of each bar is
    filled with the band colour of the structure's average pLDDT ('Fusion_pLDDT').

    Args:
        fusion_structure_data: DataFrame with 'FusionGene', 'Fusion_Length',
            'Fusion_pLDDT' and 'Fusion_AA_pLDDTs' (comma-separated string) columns.
        fusiongene: value of 'FusionGene' to plot.
        save_path: output directory.

    Only one figure is created now (the previous code opened, and leaked, a second
    one when fewer than three structures were plotted) and the band counts are
    computed once instead of twice.
    """
    set_font()

    df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy()
    df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')])
    df_of_interest['Label'] = df_of_interest['Fusion_Length'].astype(str) + 'AAs'
    df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True)

    # NOTE(review): both dicts are keyed on the length label, so structures of
    # equal length collapse into a single bar — confirm that is acceptable.
    data_dict = dict(zip(df_of_interest['Label'], df_of_interest['Fusion_AA_pLDDTs']))
    average_plddt = dict(zip(df_of_interest['Label'], df_of_interest['Fusion_pLDDT']))

    categories = ["<= 50", "50-70", "70-90", "> 90"]
    category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"}

    categorized_data = {structure: categorize_plddt(plddt_values) for structure, plddt_values in data_dict.items()}
    labels = list(categorized_data.keys())
    counts = {cat: [categorized_data[structure][cat] for structure in labels] for cat in categories}

    # Use a flatter canvas when there are only one or two bars.
    fig, ax = plt.subplots(figsize=(10, 2) if len(data_dict) < 3 else (10, 6))

    # Stack the bands left to right; `bottom` tracks the running right edge of each bar.
    bottom = np.zeros(len(labels))
    for cat in categories:
        ax.barh(labels, counts[cat], label=cat, color=category_colors[cat], left=bottom.copy())
        bottom += counts[cat]

    n_structures = len(df_of_interest)
    if n_structures > 10:
        markersize = 10
    elif n_structures > 5:
        markersize = 16
    else:
        markersize = 12

    for idx, label in enumerate(labels):
        avg_plddt_value = average_plddt[label]
        if avg_plddt_value > 90:
            color = category_colors["> 90"]
        elif avg_plddt_value > 70:
            color = category_colors["70-90"]
        elif avg_plddt_value > 50:
            color = category_colors["50-70"]
        else:
            color = category_colors["<= 50"]
        # Marker sits slightly past the end of its bar (2% of the longest length).
        ax.plot(bottom[idx] + .02 * max(df_of_interest['Fusion_Length']), idx, marker='s', color="black",
                markersize=markersize, markerfacecolor=color, markeredgewidth=2)

    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', left=False, labelleft=False)
    plt.tight_layout()

    fusiongene_savename = fusiongene.replace("::", "-")
    plt.savefig(f"{save_path}/plddt_dist_{fusiongene_savename}.png", dpi=300)
def make_all_favorite_fusion_pLDDT_plots(favorite_fusions,left_to_right=True):
    """
    Produce one pLDDT figure per fusion gene in `favorite_fusions`.

    Structure data is read from the FusionPDB CSV, annotated with the 'seq_id' of
    each 'Fusion_Seq' (looked up in fuson_db.csv) and left-joined with the SwissProt
    top-alignment table on 'seq_id'.

    Args:
        favorite_fusions: iterable of 'FusionGene' values.
        left_to_right: True draws the per-residue strip plot, False the stacked
            band-count chart. Figures go to 'processed_data/figures/fusion_disorder'.
    """
    structures = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fuson_db = pd.read_csv("../../data/fuson_db.csv")

    # Attach seq_id via the amino-acid sequence, then the alignment columns.
    sequence_to_id = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    structures['seq_id'] = structures['Fusion_Seq'].map(sequence_to_id)
    structures = pd.merge(structures, top_alignments, on="seq_id", how="left")

    for fusion_name in favorite_fusions:
        draw = (
            plot_fusion_sequence_pLDDT_left_to_right
            if left_to_right
            else plot_favorite_fusion_pLDDT_distribution
        )
        draw(structures, fusion_name, save_path='processed_data/figures/fusion_disorder')
def prep_data_for_ht_disorder_comparison():
    """
    Build one table pairing each fusion structure with the structures of its head
    and tail proteins, and derive per-residue disorder label strings for all three.

    Steps:
        1. Right-join the fusion head/tail database onto the FusionPDB structure
           table by sequence.
        2. Expand comma-separated 'hgUniProt' / 'tgUniProt' into one row per
           head/tail combination and keep only combinations where both IDs have a
           structure in heads_tails_structural_data.csv.
        3. Inner-join head and tail structure data (columns renamed to
           'hg_*' / 'tg_*') and drop rows without per-residue pLDDTs.
        4. Add 'hg_label', 'tg_label' and 'fusion_label': strings with one
           character per residue produced by apply_plddt_thresh.

    Returns:
        The merged DataFrame with a fresh 0..n-1 index.

    Fusion structures with no entry in the head/tail database come out of the right
    join with NaN UniProt IDs; they are dropped before splitting (they could never
    pass the structure filter, and calling `.split` on NaN used to raise).
    The previously loaded but unused 'fusion_heads_and_tails.csv' is no longer read.
    """
    def _labels_from_plddts(plddt_strings):
        # 'p1,p2,...' -> one label character per residue, concatenated.
        return plddt_strings.apply(
            lambda cell: ''.join(apply_plddt_thresh(float(value)) for value in cell.split(','))
        )

    ht_structure_data = pd.read_csv('processed_data/fusionpdb/heads_tails_structural_data.csv')
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    all_hts_with_structures = ht_structure_data['UniProtID'].unique().tolist()

    fuson_ht_db = pd.read_csv('../../data/blast/fuson_ht_db.csv')[['seq_id','aa_seq','fusiongenes','hgUniProt','tgUniProt']]

    # Keep every fusion structure; head/tail annotation is attached where available.
    merge = pd.merge(
        fuson_ht_db.rename(columns={'aa_seq':'Fusion_Seq'}),
        fusion_structure_data[['FusionGID', 'Fusion_Seq','Fusion_pLDDT','Fusion_AA_pLDDTs']],
        on='Fusion_Seq',
        how='right'
    )
    merge = merge.dropna(subset=['hgUniProt', 'tgUniProt'])

    # One row per (head ID, tail ID) combination.
    merge['hgUniProt'] = merge['hgUniProt'].apply(lambda x: x.split(','))
    merge['tgUniProt'] = merge['tgUniProt'].apply(lambda x: x.split(','))
    merge = merge.explode('hgUniProt')
    merge = merge.explode('tgUniProt')

    merge = merge.loc[
        merge['hgUniProt'].isin(all_hts_with_structures) &
        merge['tgUniProt'].isin(all_hts_with_structures)
    ].reset_index(drop=True)

    # Attach head, then tail, structure data under prefixed column names.
    for prefix in ('hg', 'tg'):
        merge = pd.merge(
            merge,
            ht_structure_data.rename(columns={
                'UniProtID': f'{prefix}UniProt',
                'Avg pLDDT': f'{prefix}_pLDDT',
                'All pLDDTs': f'{prefix}_AA_pLDDTs',
                'Seq': f'{prefix}_seq'}),
            on=f'{prefix}UniProt',
            how='inner'
        )

    merge = merge.loc[merge['hg_AA_pLDDTs'].notna()]
    merge = merge.loc[merge['tg_AA_pLDDTs'].notna()].reset_index(drop=True)

    merge['hg_label'] = _labels_from_plddts(merge['hg_AA_pLDDTs'])
    merge['tg_label'] = _labels_from_plddts(merge['tg_AA_pLDDTs'])
    merge['fusion_label'] = _labels_from_plddts(merge['Fusion_AA_pLDDTs'])

    return merge
def apply_plddt_thresh(y, thresh=68.8):
    """
    Convert one pLDDT value into a label character.

    Args:
        y: pLDDT value.
        thresh: cut-off; defaults to 68.8, the value this module previously had
            hard-coded.

    Returns:
        '1' when `y` is strictly below `thresh`, otherwise '0' (this includes
        `y == thresh` and NaN, for which the comparison is False).
    """
    return '1' if y < thresh else '0'
def plot_fusion_stats_boxplots(data, save_path="fusion_disorder_boxplots.png"):
    """
    Draw one box plot per column of `data`, save the figure and then display it.

    Args:
        data: DataFrame; each column becomes one box (NaNs dropped per column) and
            the column names become the tick labels. The title reports
            n = len(data).
        save_path: PNG path, written at 300 dpi.

    The figure is saved BEFORE plt.show(): with a blocking or interactive backend
    show() can leave pyplot without the figure, so saving afterwards may write an
    empty image.
    """
    set_font()
    plt.figure(figsize=(6, 5))

    box = plt.boxplot([data[col].dropna() for col in data.columns], labels=data.columns, patch_artist=True)

    # Pink boxes with black median lines.
    for patch in box['boxes']:
        patch.set_facecolor('#ff68b4')
        patch.set_edgecolor('#ff68b4')
    for median in box['medians']:
        median.set_color('black')

    plt.title(f"Per-Residue Disorder (n={len(data)})", fontsize=22)
    plt.xticks(rotation=20, fontsize=22)
    plt.yticks(fontsize=22)
    plt.tight_layout()

    plt.savefig(save_path, dpi=300)
    plt.show()
def plot_fusion_frac_disorder_r2(actual_values, predicted_values, save_path="fusion_pred_disorder_r2.png"):
    """
    Scatter predicted against actual values, annotate the R^2 score, save the figure
    and then display it.

    Args:
        actual_values: sequence plotted on the x axis (labelled 'AlphaFold-pLDDT').
        predicted_values: sequence plotted on the y axis (labelled 'FusOn-pLM-Diso').
        save_path: PNG path, written at 300 dpi.

    The dashed 'Ideal Fit' line is y = x across the range of `actual_values`.
    The figure is saved BEFORE plt.show(): with a blocking or interactive backend
    show() can leave pyplot without the figure, so saving afterwards may write an
    empty image.
    """
    set_font()
    plt.figure(figsize=(6, 6))

    r2 = r2_score(actual_values, predicted_values)

    plt.scatter(actual_values, predicted_values, alpha=0.5, label="Predictions", color="#ff68b4")
    lowest, highest = min(actual_values), max(actual_values)
    plt.plot([lowest, highest], [lowest, highest], 'k--', label='Ideal Fit')
    # Fixed data coordinates; the axis titles describe the values as percentages.
    plt.text(0, 92, f"$R^2$={r2:.2f}", fontsize=32)

    plt.xlabel('AlphaFold-pLDDT', size=32)
    plt.ylabel('FusOn-pLM-Diso', size=32)
    plt.title(f"% Disordered (n={len(actual_values)})", size=32)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(prop={'size': 16})
    plt.tight_layout()

    plt.savefig(save_path, dpi=300)
    plt.show()
def main():
    """
    Regenerate the figures of this module, in order:
        1. the CAID2 ROC comparison for FusOn-pLM (with ESM-2-650M and rankings),
        2. percent-disordered histograms for the CAID2 test set and for the unique
           head, tail and fusion proteins,
        3. per-residue pLDDT strip plots for four selected fusion genes.

    All input and output locations are hard-coded relative paths.
    """
    set_font()
    output_dir = "results/final"

    # Lookup tables keyed on sequence for the CAID2 test set.
    test_df = pd.read_csv('splits/test_df.csv')
    seq_ids_dict = dict(zip(test_df['Sequence'], test_df['IDs']))
    seq_label_dict = dict(zip(test_df['Sequence'], test_df['Label']))

    best_caid_model_results = pd.read_csv(f"{output_dir}/best_caid_model_results.csv")  # loaded but not used below

    make_auroc_curve(
        results_dir=output_dir,
        seq_label_dict=seq_label_dict,
        seq_ids_dict=seq_ids_dict,
        path_to_results_of_interest="trained_models/fuson_plm/best/caid_hyperparam_screen_test_probs.csv",
        model_alias="FusOn-pLM",
        path_to_esm_results="trained_models/esm2_t33_650M_UR50D/best/caid_hyperparam_screen_test_probs.csv",
        with_rankings=True,
    )

    # CAID2 test split.
    all_splits = pd.read_csv("splits/splits.csv")
    caid2_test_data = all_splits.loc[all_splits['Split'] == 'Test']
    caid2_test_labels = caid2_test_data['Label'].tolist()
    caid2_test_ids = caid2_test_data['IDs'].tolist()

    # Head / tail / fusion tables, each de-duplicated on its own sequence column.
    fusion_ht_data = prep_data_for_ht_disorder_comparison()
    os.makedirs("processed_data/figures", exist_ok=True)

    def _unique_by(sequence_column):
        return fusion_ht_data.drop_duplicates([sequence_column]).reset_index(drop=True)

    head_data = _unique_by('hg_seq')
    head_labels, head_ids = head_data['hg_label'].tolist(), head_data['hgUniProt'].tolist()
    tail_data = _unique_by('tg_seq')
    tail_labels, tail_ids = tail_data['tg_label'].tolist(), tail_data['tgUniProt'].tolist()
    fusion_data = _unique_by('Fusion_Seq')
    fusion_labels, fusion_ids = fusion_data['fusion_label'].tolist(), fusion_data['seq_id'].tolist()

    plt.rc('text', usetex=False)
    math_part = r"$n$"  # italic n in the histogram titles

    os.makedirs("processed_data/figures/histograms", exist_ok=True)
    histogram_jobs = [
        (caid2_test_labels, caid2_test_ids,
         f"CAID2 Disorder-NOX ({math_part}={len(caid2_test_labels):,})", "black",
         'processed_data/figures/histograms/disorder_nox_histogram.png'),
        (head_labels, head_ids,
         f"Head Proteins ({math_part}={len(head_labels):,})", "#df8385",
         'processed_data/figures/histograms/heads_histogram.png'),
        (tail_labels, tail_ids,
         f"Tail Proteins ({math_part}={len(tail_labels):,})", "#6ea4da",
         'processed_data/figures/histograms/tails_histogram.png'),
        (fusion_labels, fusion_ids,
         f"Fusion Oncoproteins ({math_part}={len(fusion_labels):,})", "mediumpurple",
         'processed_data/figures/histograms/fusions_histogram.png'),
    ]
    for job_labels, job_ids, job_title, job_color, job_path in histogram_jobs:
        plot_disorder_content_hist(job_labels, job_ids, title=job_title, color=job_color, savepath=job_path)

    os.makedirs("processed_data/figures/fusion_disorder", exist_ok=True)
    make_all_favorite_fusion_pLDDT_plots(
        ["EWSR1::FLI1", "PAX3::FOXO1", "EML4::ALK", "SS18::SSX1"],
        left_to_right=True,
    )


if __name__ == "__main__":
    main()