import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import numpy as np import os import matplotlib.colors as mcolors import matplotlib.patches as mpatches from matplotlib import font_manager import matplotlib.patches as patches from sklearn.metrics import roc_curve, auc, r2_score from fuson_plm.utils.visualizing import set_font global caid2_winners, caid2_model_rankings caid2_winners = pd.DataFrame(data= { 'Model Name': ['Dispredict3','flDPnn2','flDPnn','flDPlr','flDPlr2','DisoPred', 'IDP-Fusion','ESpritz-D','DeepIDP-2L','disomine','DISOPRED3-diso','IUPred3', 'AlphaFold-rsa','AlphaFold-pLDDT'], # do the top 6 models, and IUPred because it's well-known 'AUROC': [0.838,0.836,0.833,0.827,0.821,0.821, 0.818,0.802,0.800,0.797,0.692,0.755,0.747,0.695], }) caid2_winners['Model Type'] = ['caid2_competition']*len(caid2_winners) caid2_winners['Model Epoch'] = [np.nan]*len(caid2_winners) caid2_model_rankings = { 'Dispredict3': 1, 'flDPnn2': 2, 'flDPnn': 3, 'flDPlr': 4, 'flDPlr2': 5, 'DisoPred': 6, 'IDP-Fusion': 7, 'ESpritz-D': 8, 'DeepIDP-2L': 9, 'disomine': 10, 'DISOPRED3-diso': 35, 'IUPred3': 21, 'AlphaFold-rsa': 24, 'AlphaFold-pLDDT': 34 } # Method for lengthening the model name def lengthen_model_name(row): model_type = row['Model Type'] name = row['Model Name'] epoch = row['Model Epoch'] if 'esm' in name: return name if 'puncta' in name: return name if model_type=='caid2_competition': return name return f'{name}_e{epoch}' # Method for shortening the model name for display def shorten_model_name(row): model_type = row['Model Type'] name = row['Model Name'] epoch = row['Model Epoch'] if 'esm' in name: return 'ESM-2-650M' if model_type=='caid2_competition': return name if 'snp_' in name: prob_type = 'snp' elif 'uniform_' in name: prob_type = 'uni' layers = name.split('layers')[0].split('_')[-1] maskrate = name.split('mask')[1].split('-', 1)[0] kqv_tag = name.split('layers_')[1].split('_')[0] dt = name.split('mask')[1].split('-', 1)[1] return 
f'{prob_type}_{layers}L_{kqv_tag}_mask{maskrate}_{dt}_e{epoch}' def make_heatmap(df, results_dir='.', gold_standard_model_name="esm2_t33_650M_UR50D",split="test",thresh=None,ax=None): # Set font to Ubuntu set_font() # Declare columns to compare: metrics columns_to_compare = ['AUROC'] # Define the literature-reported values for CAID competition winners - only IF the split is not "benchmark" if not(split=="benchmark"): df = pd.concat([df,caid2_winners]) # Create Short Model Name and Full Model Name columns for later use df['Model Epoch'] = df['Model Epoch'].apply(lambda x: str(int(x)) if not(np.isnan(x)) else '') df['Short Model Name'] = df.apply(lambda row: shorten_model_name(row),axis=1) df['Full Model Name'] = df.apply(lambda row: lengthen_model_name(row), axis=1) # Isolate gold standard row for later comparison gold_standard = df[df['Full Model Name'] == gold_standard_model_name].reset_index(drop=True).iloc[0] gold_standard_short_model_name = df[df['Full Model Name'] == gold_standard_model_name]['Short Model Name'].item() # Create a new dataframe for the heatmap; sort by model type and place gold standard on top heatmap_data = df[['Model Type','Short Model Name','Full Model Name'] + columns_to_compare].copy() heatmap_data['is_gold_standard'] = (heatmap_data['Full Model Name'] == gold_standard_model_name).astype(int) heatmap_data = heatmap_data.sort_values(by=['is_gold_standard','Model Type','AUROC'], ascending=[False,True,False]).reset_index(drop=True).drop(columns=['is_gold_standard']) # Save the original values before calculating differences so we can use them for annotation original_values = heatmap_data[columns_to_compare].copy() # Calculate differences from the gold standard for col in columns_to_compare: heatmap_data[col] = heatmap_data[col] - gold_standard[col] # Create a color map where values equal to 0 are white, above are red, and below are blue cmap = sns.color_palette("coolwarm", as_cmap=True) # other option is diverging_palette(220, 20, 
as_cmap=True) ### Make the plot # can plot on a bigger plot, or make it an individual plot if ax is None: tallsize = max(8, 8 +.25*(len(heatmap_data)-26)) fig, ax = plt.subplots(1, 1, figsize=(8, tallsize), dpi=300) # Plot the heatmap with original values as annotations hm = sns.heatmap(heatmap_data.set_index('Short Model Name').drop(columns=['Model Type','Full Model Name']), annot=False, fmt='', cmap=cmap, center=0, cbar_kws={'label': 'Difference from Gold Standard'}) # Explicitly set tick labels to prevent them from being messed up ax.set_yticklabels(heatmap_data['Short Model Name'], rotation=0, fontsize=12) # Add padding to the y-axis label ax.set_ylabel("Short Model Name", labelpad=20) # Increase the labelpad value to add more padding # Bold any values values that exceed the gold standard for i in range(original_values.shape[0]): for j in range(original_values.shape[1]): value = original_values.iloc[i, j] if value > gold_standard[columns_to_compare[j]]: ax.text(j + 0.5, i + 0.5, f'{value:.3f}', ha='center', va='center', fontweight='bold', color='black') else: ax.text(j + 0.5, i + 0.5, f'{value:.3f}', ha='center', va='center', color='black') # Add horizontal lines between different model types model_type_series = heatmap_data['Model Type'].values last_index = 0 labels_positions = [] # To store the positions for labels for i in range(1, len(model_type_series)): if model_type_series[i] != model_type_series[i - 1]: hm.axhline(i, color='white', linewidth=8) # Draw a thick white line between groups labels_positions.append((last_index + i) / 2) # Store the midpoint for labeling last_index = i # Add label for the last group labels_positions.append((last_index + len(model_type_series)) / 2) # Italic and bold models that win AUROC; apply yellow coloring to gold standard model for ytick, model_name in enumerate(heatmap_data['Short Model Name']): if model_name == gold_standard_short_model_name: # color yellow label = ax.get_yticklabels()[ytick] #label.set_color('gold') 
label.set_bbox(dict(facecolor='gold', alpha=0.5, edgecolor='gold')) if model_name != gold_standard_short_model_name: auroc_value = original_values.loc[ytick, 'AUROC'] # Apply bold and italic for wins on either AUROC or F1 Score if (auroc_value > gold_standard['AUROC']): label = ax.get_yticklabels()[ytick] #label.set_style('italic') #label.set_weight('bold') label.set_bbox(dict(facecolor='red', alpha=0.3, edgecolor='red')) # Make legend gold_patch = mpatches.Patch(color='gold', alpha=0.5, label='Gold Standard') red_patch = mpatches.Patch(color='red', alpha=0.5, label='Winner') plt.legend(handles=[gold_patch, red_patch], loc='best', bbox_to_anchor=(0, 0)) # You can change loc to position the legend split_fname_dict = { "testing": "CAID2_test", "training": "CAID2_train", "benchmark": "FusionPDB_pLDDT_disorder" } split_title_dict = { "testing": "CAID-2 Disorder Prediction", "training": "CAID-2 Disorder Prediction", "benchmark": "FusionPDB_pLDDT Disorder Prediction" } ax.set_title(split_title_dict[split]) # Rotate the color bar label cbar = hm.collections[0].colorbar cbar.ax.yaxis.set_label_position('right') cbar.ax.yaxis.set_ticks_position('right') cbar.set_label('Difference from Gold Standard', rotation=270, labelpad=20) # Rotate 270 degrees and add some padding # Set tight layout using fig fig.tight_layout(rect=[0, 0, 0.95, 1]) # Add extra padding on the right side to fit the label plt.savefig(f"{results_dir}/{split_fname_dict[split]}_heatmap_vs_{gold_standard_model_name}.png") # Plot AUROC curve of ONE model of interest on its fusion pdb performance def make_benchmark_auroc_curve(results_dir='.', seq_label_dict=None, path_to_results_of_interest='', model_alias=None): # Isolate the information for the model we'll be plotting benchmark_model = path_to_results_of_interest.split('trained_models/')[1].split('/') benchmark_model_type = benchmark_model[0] benchmark_model_epoch = np.nan benchmark_model_hyperparams = None if len(benchmark_model)==5: benchmark_model_name = 
benchmark_model[1] benchmark_model_epoch = benchmark_model[2].split('epoch')[1] benchmark_model_hyperparams = benchmark_model[3] else: benchmark_model_name = benchmark_model[0] benchmark_model_hyperparams = benchmark_model[1] benchmark_model_info = pd.DataFrame(data={ 'Model Type': [benchmark_model_type], 'Model Name': [benchmark_model_name], 'Model Epoch': [benchmark_model_epoch] }) if model_alias is None: model_alias = benchmark_model_info.apply(lambda row: shorten_model_name(row),axis=1).iloc[0] color_map = { model_alias: 'black' } method_results = {model_alias: path_to_results_of_interest} method_results = {k:v for k,v in method_results.items() if v not in [None, '']} set_font() plt.figure(figsize=(10,6),dpi=300) # To store AUROC values and corresponding labels for sorting roc_data = [] # Read each result file and plot the metrics for method, path in method_results.items(): df = pd.read_csv(path) # columns = prob_1,labels # Extract probabilities and labels prob_1 = ",".join(df['prob_1'].tolist()) df['labels'] = df['sequence'].apply(lambda x: seq_label_dict[x]) labels = "".join(df['labels'].tolist()) prob_1 = [float(x) for x in prob_1.split(",")] labels = [int(x) for x in list(labels)] sequences = "".join(df['sequence'].tolist()) assert len(prob_1)==len(labels)==len(sequences) # Compute ROC curve and ROC area fpr, tpr, thresholds = roc_curve(labels, prob_1) roc_auc = auc(fpr, tpr) # Store data for sorting later roc_data.append((method, fpr, tpr, roc_auc)) # Sort the methods by AUROC values roc_data = sorted(roc_data, key=lambda x: x[3], reverse=True) # Plot sorted ROC curves for method, fpr, tpr, roc_auc in roc_data: if method == model_alias: plt.plot(fpr, tpr, color=color_map[method], lw=2, label=f'{method} ({roc_auc:0.3f})') else: plt.plot(fpr, tpr, color=color_map[method], lw=1, alpha=0.7, label=f'{method} ({roc_auc:0.3f})') # Set other stylistic elements plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, 
linestyle='--') plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Receiver Operating Characteristic (ROC) Curve') # After plotting the ROC curves, customize the legend handles, labels = plt.gca().get_legend_handles_labels() # Create the legend first legend = plt.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5)) # Iterate through the legend's text labels for text in legend.get_texts(): if model_alias in text.get_text(): text.set_fontweight('bold') # Bold the alias model plt.tight_layout() plt.savefig(f'{results_dir}/FusionPDB_pLDDT_disorder_{model_alias}_AUROC_curve.png') # Plot AUROC curve of ONE model of interest with all the CAID models def make_auroc_curve(results_dir='.', seq_label_dict=None, seq_ids_dict=None, path_to_results_of_interest='', model_alias=None, path_to_esm_results=None, with_rankings=False): # Isolate the information for the model we'll be plotting benchmark_model = path_to_results_of_interest.split('trained_models/')[1].split('/') benchmark_model_type = benchmark_model[0] benchmark_model_epoch = np.nan benchmark_model_hyperparams = None if len(benchmark_model)==5: benchmark_model_name = benchmark_model[1] benchmark_model_epoch = benchmark_model[2].split('epoch')[1] benchmark_model_hyperparams = benchmark_model[3] else: benchmark_model_name = benchmark_model[0] benchmark_model_hyperparams = benchmark_model[1] benchmark_model_info = pd.DataFrame(data={ 'Model Type': [benchmark_model_type], 'Model Name': [benchmark_model_name], 'Model Epoch': [benchmark_model_epoch] }) if model_alias is None: model_alias = benchmark_model_info.apply(lambda row: shorten_model_name(row),axis=1).iloc[0] color_map = { 'Dispredict3': '#d62727', #1 'flDPnn2': '#ff7f0f', #2 'flDPnn': '#1f77b4', #3 'flDPlr': '#bcbd21', #4 'flDPlr2': '#16becf', #5 'DisoPred': '#1f77b4', #6 'IDP-Fusion': '#d62727', #7 'ESpritz-D': '#8b564c', #8 'DeepIDP-2L': '#e377c2', #9 'disomine': '#e377c2', #10 'DISOPRED3-diso': '#ff892d', 'IUPred3': 
'#8b564c', 'AlphaFold-rsa': '#2ba02b', 'AlphaFold-pLDDT': '#ff892d', model_alias: 'black' } method_results = {'Dispredict3': 'processed_data/caid2_competition_results/Dispredict3_CAID-2_Disorder_NOX.csv', 'flDPnn2': 'processed_data/caid2_competition_results/flDPnn2_CAID-2_Disorder_NOX.csv', 'flDPnn': 'processed_data/caid2_competition_results/flDPnn_CAID-2_Disorder_NOX.csv', 'flDPlr': 'processed_data/caid2_competition_results/flDPtr_CAID-2_Disorder_NOX.csv', # name doesn't match but this is what it is in raw download 'flDPlr2': 'processed_data/caid2_competition_results/flDPlr2_CAID-2_Disorder_NOX.csv', 'DisoPred': 'processed_data/caid2_competition_results/DisoPred_CAID-2_Disorder_NOX.csv', 'IDP-Fusion': 'processed_data/caid2_competition_results/IDP-Fusion_CAID-2_Disorder_NOX.csv', 'ESpritz-D': 'processed_data/caid2_competition_results/ESpritz-D_CAID-2_Disorder_NOX.csv', 'DeepIDP-2L': 'processed_data/caid2_competition_results/DeepIDP-2L_CAID-2_Disorder_NOX.csv', 'disomine': 'processed_data/caid2_competition_results/disomine_CAID-2_Disorder_NOX.csv', 'DISOPRED3-diso': 'processed_data/caid2_competition_results/DISOPRED3-diso_CAID-2_Disorder_NOX.csv', 'AlphaFold-rsa': 'processed_data/caid2_competition_results/AlphaFold-rsa_CAID-2_Disorder_NOX.csv', 'AlphaFold-pLDDT': 'processed_data/caid2_competition_results/AlphaFold-disorder_CAID-2_Disorder_NOX.csv', # name doesn't match but this is what it is in raw download 'IUPred3': 'processed_data/caid2_competition_results/IUPred3_CAID-2_Disorder_NOX.csv', model_alias: path_to_results_of_interest } if path_to_esm_results is not None: method_results['ESM-2-650M'] = path_to_esm_results color_map['ESM-2-650M'] = 'black' method_results = {k:v for k,v in method_results.items() if v not in [None, '']} set_font() plt.figure(figsize=(12,6),dpi=300) # To store AUROC values and corresponding labels for sorting merged_preds = pd.DataFrame(data={'sequence':[]}) merged_tpr_fpr = pd.DataFrame(data={'model': [],'fpr':[],'tpr':[]}) roc_data = [] 
# Read each result file and plot the metrics for method, path in method_results.items(): df = pd.read_csv(path) # columns = prob_1,labels merged_preds = pd.merge(merged_preds, df.rename(columns={'prob_1':f"{method}_prob_1"})[['sequence',f"{method}_prob_1",]], on=['sequence'],how='outer') # Extract probabilities and labels prob_1 = ",".join(df['prob_1'].tolist()) df['labels'] = df['sequence'].apply(lambda x: seq_label_dict[x]) labels = "".join(df['labels'].tolist()) prob_1 = [float(x) for x in prob_1.split(",")] labels = [int(x) for x in list(labels)] sequences = "".join(df['sequence'].tolist()) assert len(prob_1)==len(labels)==len(sequences) # Compute ROC curve and ROC area fpr, tpr, thresholds = roc_curve(labels, prob_1) new_tpr_fpr = pd.DataFrame(data={ 'model': [method]*len(fpr), 'fpr': fpr, 'tpr': tpr }) merged_tpr_fpr = pd.concat([merged_tpr_fpr,new_tpr_fpr]) roc_auc = auc(fpr, tpr) if method==model_alias: path_to_og_metrics = path_to_results_of_interest.rsplit('/',1)[0]+'/caid_hyperparam_screen_test_metrics.csv' og_metrics = pd.read_csv(path_to_og_metrics) roc_auc = og_metrics['AUROC'][0] # Store data for sorting later roc_data.append((method, fpr, tpr, roc_auc)) # Save the merged dataframe as source data merged_preds['labels'] = merged_preds['sequence'].apply(lambda x: seq_label_dict[x]) merged_preds['labels'] = merged_preds['labels'].apply(lambda x: ",".join([str(y) for y in x])) merged_preds['ids'] = merged_preds['sequence'].apply(lambda x: seq_ids_dict[x]) merged_preds.drop(columns={'sequence'}).to_csv(f"{results_dir}/CAID_prediction_source_data.csv",index=False) merged_tpr_fpr.to_csv(f"{results_dir}/CAID_fpr_tpr_source_data.csv",index=False) # Sort the methods by AUROC values roc_data = sorted(roc_data, key=lambda x: x[3], reverse=True) # figure out the labels labels = {method: method for method in method_results} if with_rankings: for method in labels: if method in caid2_model_rankings: labels[method] = f"{caid2_model_rankings[method]}. 
{method}" # Plot sorted ROC curves for method, fpr, tpr, roc_auc in roc_data: if method=='ESM-2-650M' and path_to_esm_results is not None: plt.plot(fpr, tpr, color=color_map[method], lw=2, linestyle='--', label=f'{labels[method]} ({roc_auc:0.3f})') elif method == model_alias: plt.plot(fpr, tpr, color=color_map[method], lw=2, label=f'{labels[method]} ({roc_auc:0.3f})') else: plt.plot(fpr, tpr, color=color_map[method], lw=1, alpha=0.7, label=f'{labels[method]} ({roc_auc:0.3f})') # Set other stylistic elements plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xticks(fontsize=20) plt.yticks(fontsize=20) plt.plot([0, 1], [0, 1], color='darkgrey', lw=2, linestyle='--') plt.xlabel('False Positive Rate', fontsize=22) plt.ylabel('True Positive Rate', fontsize=22) plt.title('CAID2 Disorder NOX Dataset: ROC Curve', fontsize=22) # After plotting the ROC curves, customize the legend handles, labels = plt.gca().get_legend_handles_labels() # Create the legend first legend = plt.legend(handles, labels, loc="center left", bbox_to_anchor=(1.1, 0.5), fontsize=16) # Iterate through the legend's text labels for text in legend.get_texts(): if model_alias in text.get_text(): text.set_fontweight('bold') # Bold the alias model elif (path_to_esm_results is not None) and "ESM-2-650M" in text.get_text(): text.set_fontweight('bold') # Bold ESM if we're comparing to it plt.tight_layout() figpath = f'{results_dir}/CAID2_{model_alias}_AUROC_curve.png' if path_to_esm_results is not None: figpath = f'{results_dir}/CAID2_{model_alias}_with_ESM_AUROC_curve.png' plt.savefig(figpath) def plot_disorder_content_scatter(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_scatter.png'): """ Compare disorder content between the train, test, and fusion benchmark sets based on the TRUE labels. Each labels vector should have ['11110000','0001110',...] format. 
""" # Get train disorder distribution train_lengths = [] train_frac_disorder = [] for vec in train_labels: veclist = [int(x) for x in vec] train_lengths.append(len(veclist)) train_frac_disorder.append(sum(veclist)/len(veclist)) # Get test disorder distribution test_lengths = [] test_frac_disorder = [] for vec in test_labels: veclist = [int(x) for x in vec] test_lengths.append(len(veclist)) test_frac_disorder.append(sum(veclist)/len(veclist)) # Get benchmark disorder distribution benchmark_lengths = [] benchmark_frac_disorder = [] for vec in benchmark_labels: veclist = [int(x) for x in vec] benchmark_lengths.append(len(veclist)) benchmark_frac_disorder.append(sum(veclist)/len(veclist)) # make a plot set_font() color_map = { 'train': '#0072B2', 'test': '#E69F00', 'fusion': 'purple' } # Plotting fig, ax = plt.subplots(figsize=(10, 6)) ax.scatter(train_lengths, train_frac_disorder, color=color_map['train'], label='Train', alpha=0.7) ax.scatter(test_lengths, test_frac_disorder, color=color_map['test'], label='Test', alpha=0.7) ax.scatter(benchmark_lengths, benchmark_frac_disorder, color=color_map['fusion'], label='Fusion', alpha=0.7) # Labels and title ax.set_xlabel('Length') ax.set_ylabel('Fraction of Disorder') ax.set_title('Length vs. Fraction of Disorder for Train, Test, and Benchmark Datasets') ax.legend() plt.tight_layout() plt.savefig(savepath) def plot_disorder_content_hist(labels, ids, title="data", color="black", savepath='splits/disorder_content_histograms.png'): """ Compare disorder content between the train, test, and fusion benchmark sets based on the TRUE labels. Each labels vector should have ['11110000','0001110',...] format. 
""" set_font() # Get disorder distribution lengths = [] frac_disorder = [] for vec in labels: veclist = [int(x) for x in vec] lengths.append(len(veclist)) frac_disorder.append(100*sum(veclist)/len(veclist)) # make it a percent, i like this better # save the source data source_data = pd.DataFrame(data={ 'ID': ids, 'Percent_Disordered': frac_disorder }) source_data['Percent_Disordered'] = source_data['Percent_Disordered'].round(3) source_data.to_csv(savepath.replace(".png","_source_data.csv"),index=False) fig, ax = plt.subplots(1, 1, figsize=(20, 12)) # Plot histogram for train data title_fontsize = 70 axislabel_fontsize = 70 tick_fontsize = 50 ax.hist(frac_disorder, bins=20, color=color, alpha=0.7) ax.set_title(title, fontsize=title_fontsize) ax.set_xlabel('% Disordered', fontsize=axislabel_fontsize) ax.set_ylabel('Count', fontsize=axislabel_fontsize) ax.grid(True) ax.set_axisbelow(True) ax.tick_params(axis='both', which='major', labelsize=tick_fontsize) # Calculate the mean and median of the percent coverage mean_coverage = np.mean(frac_disorder) median_coverage = np.median(frac_disorder) # Add vertical line for the mean ax.axvline(mean_coverage, color='black', linestyle='--', linewidth=2, label=f'Mean: {mean_coverage:.1f}%') # Add vertical line for the median ax.axvline(median_coverage, color='black', linestyle='-', linewidth=2, label=f'Median: {median_coverage:.1f}%') ax.legend(fontsize=50, title_fontsize=50) plt.tight_layout() plt.savefig(savepath) def plot_group_disorder_content_hist(train_labels, test_labels, benchmark_labels, savepath='splits/disorder_content_histograms.png',orient='horizontal'): """ Compare disorder content between the train, test, and fusion benchmark sets based on the TRUE labels. Each labels vector should have ['11110000','0001110',...] format. 
""" # Get train disorder distribution train_lengths = [] train_frac_disorder = [] for vec in train_labels: veclist = [int(x) for x in vec] train_lengths.append(len(veclist)) train_frac_disorder.append(sum(veclist)/len(veclist)) # Get test disorder distribution test_lengths = [] test_frac_disorder = [] for vec in test_labels: veclist = [int(x) for x in vec] test_lengths.append(len(veclist)) test_frac_disorder.append(sum(veclist)/len(veclist)) # Get benchmark disorder distribution benchmark_lengths = [] benchmark_frac_disorder = [] for vec in benchmark_labels: veclist = [int(x) for x in vec] benchmark_lengths.append(len(veclist)) benchmark_frac_disorder.append(sum(veclist)/len(veclist)) # make a plot set_font() color_map = { 'train': '#0072B2', 'test': '#E69F00', 'fusion': 'mediumpurple' } # Create a 1x3 subplot (1 row, 3 columns) or 3x1 if orient=='horizontal': fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=False) if orient=='vertical': fig, axes = plt.subplots(3, 1, figsize=(5, 15), sharey=False) # Plot histogram for train data title_fontsize = 26 axislabel_fontsize = 26 tick_fontsize = 16 axes[0].hist(train_frac_disorder, bins=20, color=color_map['train'], alpha=0.7) axes[0].set_title('CAID2 Train', fontsize=title_fontsize) if orient=="horizontal": axes[0].set_xlabel('Fraction of Disorder', fontsize=axislabel_fontsize) axes[0].set_ylabel('Frequency', fontsize=axislabel_fontsize) axes[0].grid(True) axes[0].set_axisbelow(True) axes[0].tick_params(axis='both', which='major', labelsize=tick_fontsize) # Plot histogram for test data axes[1].hist(test_frac_disorder, bins=20, color=color_map['test'], alpha=0.7) axes[1].set_title('CAID2 Test',fontsize=title_fontsize) if orient=="horizontal": axes[1].set_xlabel('Fraction of Disorder', fontsize=axislabel_fontsize) if orient=="vertical": axes[1].set_ylabel('Frequency', fontsize=axislabel_fontsize) axes[1].grid(True) axes[1].set_axisbelow(True) axes[1].tick_params(axis='both', which='major', labelsize=tick_fontsize) # 
Plot histogram for benchmark (fusion) data axes[2].hist(benchmark_frac_disorder, bins=20, color=color_map['fusion'], alpha=0.7) axes[2].set_title('Fusion Oncoproteins',fontsize=title_fontsize) axes[2].set_xlabel('Fraction of Disorder', fontsize=axislabel_fontsize) if orient=="vertical": axes[2].set_ylabel('Frequency', fontsize=axislabel_fontsize) axes[2].grid(True) axes[2].set_axisbelow(True) axes[2].tick_params(axis='both', which='major', labelsize=tick_fontsize) plt.tight_layout() plt.savefig(savepath) def categorize_plddt(values): categories = { "<= 50": sum(1 for x in values if x <= 50), "50-70": sum(1 for x in values if 50 < x <= 70), "70-90": sum(1 for x in values if 70 < x <= 90), "> 90": sum(1 for x in values if x > 90) } return categories def plot_fusion_sequence_pLDDT_left_to_right(fusion_structure_data, fusiongene, save_path=''): """ Plot each amino acid in the sequence as a separate colored bar based on pLDDT values. """ set_font() # Filter for specific fusion data and preprocess df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy() df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')]) df_of_interest['Label'] = df_of_interest['Fusion_Length'].astype(str) + 'AAs' # Sort data by Fusion_Length df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True) # Define colors for each pLDDT range category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"} # Helper function to get color based on pLDDT def get_color(pLDDT): if pLDDT > 90: return category_colors["> 90"] elif pLDDT > 70: return category_colors["70-90"] elif pLDDT > 50: return category_colors["50-70"] else: return category_colors["<= 50"] # Start plotting each sequence with colored bars fig, ax = plt.subplots(figsize=(10, 6)) if len(df_of_interest)<3: fig, ax = plt.subplots(figsize=(10, 2)) average_plddt = 
dict(zip(df_of_interest['Label'], df_of_interest['Fusion_pLDDT'])) df_of_interest['Fusion_AA_colors'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [get_color(plddt) for plddt in x]) df_of_interest['Fusion_pLDDT_color'] = df_of_interest['Fusion_pLDDT'].apply(lambda plddt: get_color(plddt)) # just save the columns needed for the plot df_of_interest[['FusionGene','seq_id','Fusion_Length','Fusion_pLDDT','Fusion_AA_pLDDTs','Fusion_AA_colors','Fusion_pLDDT_color', 'top_hg_UniProtID','top_hg_UniProt_isoform','top_hg_UniProt_fus_indices', 'top_tg_UniProtID','top_tg_UniProt_isoform','top_tg_UniProt_fus_indices']].to_csv(f"{save_path}/plddt_sequence_{fusiongene}_source_data.csv",index=False) for idx, row in df_of_interest.iterrows(): pLDDT_values = row['Fusion_AA_pLDDTs'] colors = [get_color(plddt) for plddt in pLDDT_values] # Plot each amino acid in the sequence with the respective color ax.bar(range(len(pLDDT_values)), [0.7] * len(pLDDT_values), color=colors, edgecolor='none', bottom=idx - 0.7 / 2) # Centering each row at idx labels = df_of_interest['Label'].tolist() # Annotate each bar with the Fusion_pLDDT value on the right, colored by PLDDT category for idx, label in enumerate(labels): avg_plddt_value = average_plddt[label] # Determine color based on the PLDDT category if avg_plddt_value > 90: color = '#004ecb' elif avg_plddt_value > 70: color = "#60c1e8" elif avg_plddt_value > 50: color = '#f8d514' else: color = '#f27842' # Annotate with the determined color if len(df_of_interest)>10: markersize = 10 elif len(df_of_interest)>5: markersize = 16 else: markersize=12 ax.plot(1.02*max(df_of_interest['Fusion_Length']), idx, marker='o', color="black", markersize=markersize, markerfacecolor=color, markeredgewidth=2) # Add breakpoint box - make sure we actually HAVE one of each hg_indices, tg_indices = None, None if not(type(df_of_interest['top_hg_UniProt_fus_indices'][idx])==float): hg_indices = [int(x) for x in 
df_of_interest['top_hg_UniProt_fus_indices'][idx].split(',')] if not(type(df_of_interest['top_tg_UniProt_fus_indices'][idx])==float): tg_indices = [int(x) for x in df_of_interest['top_tg_UniProt_fus_indices'][idx].split(',')] print(hg_indices, tg_indices) if (hg_indices is not None) and (tg_indices is not None): box_start = min(hg_indices[-1],tg_indices[0]) box_end = max(hg_indices[-1],tg_indices[0]) elif hg_indices is not None: box_start, box_end = hg_indices[-1], hg_indices[-1] elif tg_indices is not None: box_start, box_end = tg_indices[0], tg_indices[0] print(f"box indices for structure {idx}, fusion gene {fusiongene}", box_start, box_end) # Plot the rectangle, making it slightly larger than the rest of the bar rect = patches.Rectangle((box_start, idx - 0.7 / 2), box_end-box_start, 0.7, linewidth=2, edgecolor='black', facecolor='none') ax.add_patch(rect) # Customize plot ax.set_yticks([]) # Hide y-axis ticks ax.set_yticklabels([]) # Hide y-axis labels ax.set_ylim(-0.5, len(df_of_interest) - 0.5) # reduce white space at top ax.set_xlabel("Amino Acid Sequence (ordered)", fontsize=14) # Customize x-axis for labeling ax.set_xlim(left=0) # Start x-axis at 0 to make bars flush left ax.set_xlabel("Amino Acid Sequence (ordered)", fontsize=14) ax.tick_params(axis='x', labelsize=30) plt.title(f"{fusiongene} pLDDT Distribution by Amino Acid Sequence", fontsize=16) plt.tight_layout() # Save figure fusiongene_savename = fusiongene.replace("::","-") plt.savefig(f"{save_path}/plddt_sequence_{fusiongene_savename}.png", dpi=300) plt.show() def plot_favorite_fusion_pLDDT_distribution(fusion_structure_data, fusiongene, save_path=''): """ Make a stacked bar chart of the pLDDT distribution """ set_font() # Filter for EWSR1::FLI1 fusion data and preprocess df_of_interest = fusion_structure_data[fusion_structure_data['FusionGene'] == fusiongene].copy() df_of_interest['Fusion_AA_pLDDTs'] = df_of_interest['Fusion_AA_pLDDTs'].apply(lambda x: [float(i) for i in x.split(',')]) 
df_of_interest['Label'] = df_of_interest['Fusion_Length'].astype(str) + 'AAs' # Sort data by Fusion_Length df_of_interest = df_of_interest.sort_values(by='Fusion_Length', ascending=True).reset_index(drop=True) # Convert to dictionary format data_dict = dict(zip(df_of_interest['Label'], df_of_interest['Fusion_AA_pLDDTs'])) average_plddt = dict(zip(df_of_interest['Label'], df_of_interest['Fusion_pLDDT'])) # Categorize each structure categorized_data = {structure: categorize_plddt(plddt_values) for structure, plddt_values in data_dict.items()} # Extract counts for each category labels = list(categorized_data.keys()) categories = ["<= 50", "50-70", "70-90", "> 90"] counts = {cat: [categorized_data[structure][cat] for structure in labels] for cat in categories} # Define colors for each category category_colors = {"<= 50": "#f27842", "50-70": "#f8d514", "70-90": "#60c1e8", "> 90": "#004ecb"} # Re-categorize PLDDT values for the bar chart categorized_data = {structure: categorize_plddt(plddt_values) for structure, plddt_values in data_dict.items()} labels = list(categorized_data.keys()) counts = {cat: [categorized_data[structure][cat] for structure in labels] for cat in categories} # Plotting the horizontal stacked bar chart with annotations for 'Fusion_pLDDT' values fig, ax = plt.subplots(figsize=(10, 6)) if len(data_dict)<3: fig, ax = plt.subplots(figsize=(10, 2)) bottom = np.zeros(len(labels)) # Stack each category horizontally for cat in categories: ax.barh(labels, counts[cat], label=cat, color=category_colors[cat], left=bottom) bottom += counts[cat] # Update the left position for the next stack # Annotate each bar with the Fusion_pLDDT value on the right, colored by PLDDT category for idx, label in enumerate(labels): avg_plddt_value = average_plddt[label] # Determine color based on the PLDDT category if avg_plddt_value > 90: color = '#004ecb' elif avg_plddt_value > 70: color = "#60c1e8" elif avg_plddt_value > 50: color = '#f8d514' else: color = '#f27842' # Annotate 
with the determined color #ax.text(bottom[idx] + 1, idx, f"{avg_plddt_value:.2f}", va='center', ha='left', color="black", fontsize=18, fontweight='bold') if len(df_of_interest)>10: markersize = 10 elif len(df_of_interest)>5: markersize = 16 else: markersize=12 ax.plot(bottom[idx] + .02*max(df_of_interest['Fusion_Length']), idx, marker='s', color="black", markersize=markersize, markerfacecolor=color, markeredgewidth=2) # Add labels and legend #ax.set_xlim([0,max(df_of_interest['Fusion_Length'])*1.0]) #ax.set_ylabel("Structures") # Save original ticks before changing label size #ax.tick_params(axis='x', labelsize=16) #original_xticks = ax.get_xticks() # Set ticks explicitly to avoid automatic adjustment #ax.set_xticks(original_xticks) #ax.set_xlabel("Length",fontsize=40) ax.tick_params(axis='x', labelsize=30) #ax.tick_params(axis='y', labelsize=16) ax.tick_params(axis='y', left=False, labelleft=False) #ax.set_title(f"{fusiongene} pLDDT Distribution") #ax.legend(title="pLDDT Ranges", fontsize=16, bbox_to_anchor=(1, 1), title_fontsize=16) plt.tight_layout() fusiongene_savename = fusiongene.replace("::","-") plt.savefig(f"{save_path}/plddt_dist_{fusiongene_savename}.png",dpi=300) def make_all_favorite_fusion_pLDDT_plots(favorite_fusions,left_to_right=True): fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv') swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv") fuson_db = pd.read_csv("../../data/fuson_db.csv") seq_id_dict = dict(zip(fuson_db['aa_seq'],fuson_db['seq_id'])) fusion_structure_data['seq_id'] = fusion_structure_data['Fusion_Seq'].map(seq_id_dict) fusion_structure_data = pd.merge( fusion_structure_data, swissprot_top_alignments, on="seq_id", how="left" ) for x in favorite_fusions: if left_to_right: plot_fusion_sequence_pLDDT_left_to_right(fusion_structure_data, x, save_path='processed_data/figures/fusion_disorder') else: 
def _plddt_csv_to_label(plddt_csv):
    """Turn a comma-separated string of pLDDT scores into a label string.

    Each score is parsed as a float and mapped to one character by
    ``apply_plddt_thresh``; the characters are concatenated in order.
    """
    return ''.join(apply_plddt_thresh(float(score)) for score in plddt_csv.split(','))


def prep_data_for_ht_disorder_comparison():
    """Assemble fusion / head / tail pLDDT data with per-residue label strings.

    Steps:
      1. Read the head/tail structure table, the FusionPDB structure table and
         ``fuson_ht_db.csv`` (all paths are relative to the working directory).
      2. Join the fusion structures to ``fuson_ht_db`` on the fusion sequence.
      3. Split the comma-separated ``hgUniProt`` / ``tgUniProt`` IDs and expand
         to one row per (head ID, tail ID) combination.
      4. Keep only combinations where both IDs appear in the head/tail
         structure table, then attach that table's columns once for the head
         (``hg_*``) and once for the tail (``tg_*``).
      5. Drop rows lacking per-residue pLDDTs for either partner and add
         ``hg_label``, ``tg_label`` and ``fusion_label``.

    Returns:
        pandas.DataFrame with a fresh RangeIndex, containing the merged columns
        plus the three label columns.
    """
    ht_structure_data = pd.read_csv('processed_data/fusionpdb/heads_tails_structural_data.csv')
    fusion_structure_data = pd.read_csv(
        'processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv'
    )
    all_hts_with_structures = ht_structure_data['UniProtID'].unique().tolist()

    fuson_ht_db = pd.read_csv('../../data/blast/fuson_ht_db.csv')[
        ['seq_id', 'aa_seq', 'fusiongenes', 'hgUniProt', 'tgUniProt']
    ]

    merged = pd.merge(
        fuson_ht_db.rename(columns={'aa_seq': 'Fusion_Seq'}),
        fusion_structure_data[['FusionGID', 'Fusion_Seq', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs']],
        on='Fusion_Seq',
        how='right',
    )

    # The right join keeps structures with no fuson_ht_db entry; their partner
    # IDs are NaN (a float), on which .split(',') would raise. Such rows could
    # never pass the structure filter below, so drop them up front.
    merged = merged.dropna(subset=['hgUniProt', 'tgUniProt'])

    # One row per (head ID, tail ID) combination.
    merged['hgUniProt'] = merged['hgUniProt'].apply(lambda ids: ids.split(','))
    merged['tgUniProt'] = merged['tgUniProt'].apply(lambda ids: ids.split(','))
    # Reset in between so the second explode operates on a unique index.
    merged = merged.explode('hgUniProt').reset_index(drop=True)
    merged = merged.explode('tgUniProt')

    has_head_structure = merged['hgUniProt'].isin(all_hts_with_structures)
    has_tail_structure = merged['tgUniProt'].isin(all_hts_with_structures)
    merged = merged.loc[has_head_structure & has_tail_structure].reset_index(drop=True)

    # Attach the structure columns twice: first for the head, then the tail.
    for prefix in ('hg', 'tg'):
        partner_structures = ht_structure_data.rename(columns={
            'UniProtID': f'{prefix}UniProt',
            'Avg pLDDT': f'{prefix}_pLDDT',
            'All pLDDTs': f'{prefix}_AA_pLDDTs',
            'Seq': f'{prefix}_seq',
        })
        merged = pd.merge(merged, partner_structures, on=f'{prefix}UniProt', how='inner')

    has_both_plddts = merged['hg_AA_pLDDTs'].notna() & merged['tg_AA_pLDDTs'].notna()
    merged = merged.loc[has_both_plddts].reset_index(drop=True)

    # Finally, calculate the per-residue label strings.
    merged['hg_label'] = merged['hg_AA_pLDDTs'].apply(_plddt_csv_to_label)
    merged['tg_label'] = merged['tg_AA_pLDDTs'].apply(_plddt_csv_to_label)
    merged['fusion_label'] = merged['Fusion_AA_pLDDTs'].apply(_plddt_csv_to_label)

    return merged
def apply_plddt_thresh(y, thresh=68.8):
    """Binarise a single pLDDT score into a one-character label.

    Args:
        y: Numeric pLDDT score.
        thresh: Cutoff; scores strictly below it are labelled '1'. Defaults to
            68.8, the value that used to be hard-coded here.

    Returns:
        '1' if ``y < thresh``, otherwise '0'. A NaN score compares false and
        therefore yields '0'.

    NOTE(review): presumably '1' marks a disordered residue -- confirm against
    the label convention used by the callers.
    """
    return '1' if y < thresh else '0'


def plot_fusion_stats_boxplots(data, save_path="fusion_disorder_boxplots.png"):
    """Draw one pink box plot per column of ``data``, save it, then show it.

    Args:
        data: pandas DataFrame; every column becomes one box, labelled with the
            column name. The title reports ``len(data)`` as n.
        save_path: Output image path (written at 300 dpi).
    """
    set_font()

    plt.figure(figsize=(6, 5))
    # For entries that are 100% disordered AUROC was NaN, so drop NaNs per column.
    # NOTE(review): newer matplotlib releases rename `labels` to `tick_labels`;
    # kept as-is because the installed version is not visible from here.
    box = plt.boxplot(
        [data[col].dropna() for col in data.columns],
        labels=data.columns,
        patch_artist=True,
    )

    for patch in box['boxes']:
        patch.set_facecolor('#ff68b4')
        patch.set_edgecolor('#ff68b4')
    for median in box['medians']:
        median.set_color('black')

    plt.title(f"Per-Residue Disorder (n={len(data)})", fontsize=22)
    plt.xticks(rotation=20, fontsize=22)
    plt.yticks(fontsize=22)
    plt.tight_layout()

    # Save BEFORE showing: once a blocking show() returns the figure has been
    # closed, and a later savefig() would write a new, empty figure.
    plt.savefig(save_path, dpi=300)
    plt.show()


def plot_fusion_frac_disorder_r2(actual_values, predicted_values, save_path="fusion_pred_disorder_r2.png"):
    """Scatter predicted against actual values with an identity line and R^2.

    Args:
        actual_values: Sequence of reference values (x axis, labelled
            "AlphaFold-pLDDT").
        predicted_values: Sequence of predicted values (y axis, labelled
            "FusOn-pLM-Diso"), aligned with ``actual_values``.
        save_path: Output image path (written at 300 dpi).

    The R^2 annotation is placed at data coordinates (0, 92).
    NOTE(review): that position presumably assumes values on a 0-100 percent
    scale -- confirm with the caller.
    """
    set_font()

    plt.figure(figsize=(6, 6))
    r2 = r2_score(actual_values, predicted_values)

    plt.scatter(actual_values, predicted_values, alpha=0.5, label="Predictions", color="#ff68b4")

    # Identity line spanning the range of the actual values.
    lowest, highest = min(actual_values), max(actual_values)
    plt.plot([lowest, highest], [lowest, highest], 'k--', label='Ideal Fit')
    plt.text(0, 92, f"$R^2$={r2:.2f}", fontsize=32)

    plt.xlabel('AlphaFold-pLDDT', size=32)
    plt.ylabel('FusOn-pLM-Diso', size=32)
    plt.title(f"% Disordered (n={len(actual_values)})", size=32)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(prop={'size': 16})
    plt.tight_layout()

    # Save BEFORE showing: once a blocking show() returns the figure has been
    # closed, and a later savefig() would write a new, empty figure.
    plt.savefig(save_path, dpi=300)
    plt.show()
def main():
    """Regenerate the benchmark figures produced by this script.

    In order:
      1. Call ``make_auroc_curve`` for the FusOn-pLM and ESM-2 test
         probabilities, using sequence -> label / ID lookups built from
         ``splits/test_df.csv``.
      2. Build label and ID lists for the CAID2 test split
         (``splits/splits.csv``, rows with ``Split == 'Test'``) and for the
         de-duplicated head, tail and fusion sequences returned by
         ``prep_data_for_ht_disorder_comparison``.
      3. Call ``plot_disorder_content_hist`` once per group, writing into
         ``processed_data/figures/histograms``.
      4. Call ``make_all_favorite_fusion_pLDDT_plots`` for four fusions.

    All paths are relative to the current working directory.
    """
    set_font()

    output_dir = "results/final"

    # Sequence -> ID and sequence -> label lookups for the test split.
    test_df = pd.read_csv('splits/test_df.csv')
    seq_ids_dict = dict(zip(test_df['Sequence'], test_df['IDs']))
    seq_label_dict = dict(zip(test_df['Sequence'], test_df['Label']))

    make_auroc_curve(
        results_dir=output_dir,
        seq_label_dict=seq_label_dict,
        seq_ids_dict=seq_ids_dict,
        path_to_results_of_interest="trained_models/fuson_plm/best/caid_hyperparam_screen_test_probs.csv",
        model_alias="FusOn-pLM",
        path_to_esm_results="trained_models/esm2_t33_650M_UR50D/best/caid_hyperparam_screen_test_probs.csv",
        with_rankings=True,
    )

    caid2_test_data = pd.read_csv("splits/splits.csv")
    caid2_test_data = caid2_test_data.loc[caid2_test_data['Split'] == 'Test']
    caid2_test_labels = caid2_test_data['Label'].tolist()
    caid2_test_ids = caid2_test_data['IDs'].tolist()

    # Fusions, heads, and tails: each group is de-duplicated on its own
    # sequence column so a sequence is counted once.
    fusion_ht_data = prep_data_for_ht_disorder_comparison()
    os.makedirs("processed_data/figures", exist_ok=True)

    head_data = fusion_ht_data.drop_duplicates(['hg_seq']).reset_index(drop=True)
    head_labels = head_data['hg_label'].tolist()
    head_ids = head_data['hgUniProt'].tolist()

    tail_data = fusion_ht_data.drop_duplicates(['tg_seq']).reset_index(drop=True)
    tail_labels = tail_data['tg_label'].tolist()
    tail_ids = tail_data['tgUniProt'].tolist()

    fusion_data = fusion_ht_data.drop_duplicates(['Fusion_Seq']).reset_index(drop=True)
    fusion_labels = fusion_data['fusion_label'].tolist()
    fusion_ids = fusion_data['seq_id'].tolist()

    plt.rc('text', usetex=False)
    math_part = r"$n$"

    histogram_dir = "processed_data/figures/histograms"
    os.makedirs(histogram_dir, exist_ok=True)

    # (labels, ids, title prefix, colour, output file name)
    histogram_specs = [
        (caid2_test_labels, caid2_test_ids, "CAID2 Disorder-NOX", "black", "disorder_nox_histogram.png"),
        (head_labels, head_ids, "Head Proteins", "#df8385", "heads_histogram.png"),
        (tail_labels, tail_ids, "Tail Proteins", "#6ea4da", "tails_histogram.png"),
        (fusion_labels, fusion_ids, "Fusion Oncoproteins", "mediumpurple", "fusions_histogram.png"),
    ]
    for labels, ids, group_name, color, filename in histogram_specs:
        plot_disorder_content_hist(
            labels,
            ids,
            title=f"{group_name} ({math_part}={len(labels):,})",
            color=color,
            savepath=f"{histogram_dir}/{filename}",
        )

    os.makedirs("processed_data/figures/fusion_disorder", exist_ok=True)
    make_all_favorite_fusion_pLDDT_plots(
        [
            "EWSR1::FLI1",
            "PAX3::FOXO1",
            "EML4::ALK",
            "SS18::SSX1",
        ],
        left_to_right=True,
    )


if __name__ == "__main__":
    main()