|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
from sklearn.metrics import r2_score |
|
import matplotlib.colors as mcolors |
|
from fuson_plm.utils.visualizing import set_font |
|
|
|
# Default histogram colour (hex string) for each dataset name. The grid
# plotting functions look colours up by the same dataset names that key
# their ``values_dict`` argument.
# (A ``global`` statement at module level has no effect, so none is used.)
default_cmap_dict = {
    'Asphericity': '#785EF0',
    'End-to-End Distance (Re)': '#DC267F',
    'Radius of Gyration (Rg)': '#FE6100',
    'Scaling Exponent': '#FFB000',
}
|
|
|
|
|
def lengthen_model_name(model_name, model_epoch):
    """
    Return the long-form name of a model.

    A name containing the substring 'esm' is returned untouched; any other
    name gets '_e' followed by ``model_epoch`` appended.
    """
    is_esm = 'esm' in model_name
    return model_name if is_esm else f'{model_name}_e{model_epoch}'
|
|
|
def plot_train_val_test_values_hist(train_values_list, val_values_list, test_values_list, dataset_name="Data", color="black", save_path=None, ax=None):
    """
    Plot overlaid histograms of the train / val / test values of one dataset.

    Args:
        train_values_list: values of the train split (always plotted)
        val_values_list: values of the validation split, or None to leave it out
        test_values_list: values of the test split, or None to leave it out
        dataset_name: used in the title and as the x-axis label
        color: colour of the train histogram (test is drawn black, val grey)
        save_path: if not None, the current figure is saved to this path
        ax: axes to draw on; when None a new 6x4 inch, 300 dpi figure is created
    """
    set_font()
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

    ax.hist(train_values_list, color=color, alpha=0.7, label=f"train (n={len(train_values_list)})")
    total_seqs = len(train_values_list)

    # Optional splits, drawn in the order test -> val. A split passed as None
    # is skipped and contributes nothing to the total shown in the title
    # (calling len() on it would raise TypeError).
    optional_splits = (
        ("test", test_values_list, 'black'),
        ("val", val_values_list, 'grey'),
    )
    for split_name, split_values, split_color in optional_splits:
        if split_values is None:
            continue
        ax.hist(split_values, color=split_color, alpha=0.7, label=f"{split_name} (n={len(split_values)})")
        total_seqs += len(split_values)

    # Grid lines go behind the bars.
    ax.grid(True)
    ax.set_axisbelow(True)
    ax.set_title(f'{dataset_name} Distribution (n={total_seqs})')
    ax.set_xlabel(dataset_name)
    ax.legend()
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path)
|
|
|
def plot_values_hist(values_list, dataset_name="Data", color="black", save_path=None, ax=None):
    """
    Draw one histogram of ``values_list``.

    Args:
        values_list: the values to bin
        dataset_name: used in the title and as the x-axis label
        color: bar colour
        save_path: if not None, the current figure is saved to this path
        ax: axes to draw on; when None a new 6x4 inch, 300 dpi figure is created
    """
    set_font()
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

    ax.hist(values_list, color=color)

    # Labels first, then the grid (kept behind the bars).
    ax.set_title(f'{dataset_name} Distribution')
    ax.set_xlabel(dataset_name)
    ax.grid(True)
    ax.set_axisbelow(True)

    plt.tight_layout()

    if save_path is None:
        return
    plt.savefig(save_path)
|
|
|
def plot_all_values_hist_grid(values_dict, cmap_dict=default_cmap_dict, save_path="processed_data/value_histograms.png"):
    """
    Plot one histogram per dataset on a two-column grid and save the figure.

    Args:
        values_dict: dictionary where keys are dataset names and values are value lists
        cmap_dict: dictionary where keys are dataset names (same as in values_dict) and values are the histogram colors
        save_path: path the figure is written to

    Raises:
        KeyError: if a dataset name in values_dict has no entry in cmap_dict
    """
    # Keep the 2x2 layout (12x8 inches) for up to four datasets; add rows of
    # the same height beyond that so indexing the axes cannot run out of range.
    n_cols = 2
    n_rows = max(2, (len(values_dict) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), dpi=300)
    axes = axes.flatten()

    for ax, (dataset_name, values_list) in zip(axes, values_dict.items()):
        plot_values_hist(values_list, dataset_name=dataset_name, color=cmap_dict[dataset_name], ax=ax)

    fig.set_tight_layout(True)
    fig.savefig(save_path)
|
|
|
|
|
def plot_all_train_val_test_values_hist_grid(values_dict, cmap_dict=default_cmap_dict, save_path="processed_data/value_histograms.png"):
    """
    Plot one train/val/test histogram panel per dataset on a two-column grid and save the figure.

    Args:
        values_dict: dictionary where keys are dataset names and values are another dict:
            {'train': train_values_list, 'val': val_values_list, 'test': test_values_list}.
            'train' is required; 'val' and 'test' are optional.
        cmap_dict: dictionary where keys are dataset names (same as in values_dict) and values are the colors of the train histograms
        save_path: path the figure is written to

    Raises:
        KeyError: if a dataset has no 'train' entry, or no entry in cmap_dict
    """
    # Keep the 2x2 layout (12x8 inches) for up to four datasets; add rows of
    # the same height beyond that so indexing the axes cannot run out of range.
    n_cols = 2
    n_rows = max(2, (len(values_dict) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), dpi=300)
    axes = axes.flatten()

    for ax, (dataset_name, train_val_test_dict) in zip(axes, values_dict.items()):
        train_values_list = train_val_test_dict['train']
        # Missing optional splits are passed on as None.
        val_values_list = train_val_test_dict.get('val')
        test_values_list = train_val_test_dict.get('test')
        plot_train_val_test_values_hist(train_values_list, val_values_list, test_values_list, dataset_name=dataset_name, color=cmap_dict[dataset_name], ax=ax)

    fig.set_tight_layout(True)
    fig.savefig(save_path)
|
|
|
|
|
def _save_r2_source_data(idr_property, test_preds, save_path):
    """Write the true/predicted values plus sequence IDs behind an R^2 plot to '{save_path minus extension}_source_data.csv'."""
    test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv")
    processed_data = pd.read_csv("processed_data/all_albatross_seqs_and_properties.csv")
    seq_id_dict = dict(zip(processed_data['Sequence'], processed_data['IDs']))
    test_df['IDs'] = test_df['Sequence'].map(seq_id_dict)

    # .copy() so that adding the IDs column writes to an independent frame
    # rather than to a slice of test_preds.
    test_df_with_preds = test_preds[['true_values', 'predictions']].copy()
    # NOTE(review): the assignment aligns on the index, so this assumes the rows
    # of test_preds are in the same order as the rows of test_df - confirm.
    test_df_with_preds['IDs'] = test_df['IDs']
    print("number of sequences with no ID: ", int(test_df_with_preds['IDs'].isna().sum()))

    # Swap only the file extension, so a save_path that does not end in
    # '.png' cannot make the CSV and the image share one file name.
    source_data_path = f"{os.path.splitext(save_path)[0]}_source_data.csv"
    test_df_with_preds.to_csv(source_data_path, index=False)


def plot_r2(model_type, idr_property, test_preds, save_path):
    """
    Scatter-plot predicted against true values, annotate the R^2 score, and save the figure.

    The plotted values and their sequence IDs are also written to a
    '_source_data.csv' file next to ``save_path``.

    Args:
        model_type: name of the model; accepted but not used by the plot
        idr_property: one of 'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'
        test_preds: DataFrame with 'true_values' and 'predictions' columns
        save_path: path the figure is written to

    Raises:
        KeyError: if idr_property is not one of the supported keys
    """
    set_font()

    ylabel_dict = {'asph': 'Asphericity',
                   'scaled_re': 'End-to-End Radius, $R_e$',
                   'scaled_rg': 'Radius of Gyration, $R_g$',
                   'scaling_exp': 'Polymer Scaling Exponent'}
    y_unitlabel_dict = {'asph': 'Asphericity',
                        'scaled_re': '$R_e$ (Å)',
                        'scaled_rg': '$R_g$ (Å)',
                        'scaling_exp': 'Exponent'}
    y_label = ylabel_dict[idr_property]
    y_unitlabel = y_unitlabel_dict[idr_property]

    true_values = test_preds['true_values'].tolist()
    predictions = test_preds['predictions'].tolist()

    _save_r2_source_data(idr_property, test_preds, save_path)

    r2 = r2_score(true_values, predictions)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(true_values, predictions, alpha=0.5, label='Predictions')
    # Diagonal y = x across the range of the true values.
    lowest, highest = min(true_values), max(true_values)
    ax.plot([lowest, highest], [lowest, highest], 'r--', label='Ideal Fit')
    ax.text(0.65, 0.35, f"$R^2$ = {r2:.2f}", transform=ax.transAxes, fontsize=44)

    ax.set_xlabel(f'True {y_unitlabel}', size=44)
    ax.set_ylabel(f'Predicted {y_unitlabel}', size=44)
    ax.set_title(f"{y_label}", size=50)

    legend = ax.legend(fontsize=32)
    # Newer matplotlib exposes `legend_handles`; older releases only have
    # `legendHandles`.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    # Enlarge the scatter marker in the legend; line handles have no sizes.
    for handle in handles:
        if hasattr(handle, 'set_sizes'):
            handle.set_sizes([100])

    ax.grid(True)

    # 32 is the size the tick labels ended up with previously (an initial 36
    # was immediately overridden); tick_params also covers ticks created later.
    ax.tick_params(axis='both', labelsize=32)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
def plot_all_r2(output_dir, idr_properties):
    """
    Make an R^2 plot for every (IDR property, model type) pair listed in the best-results files.

    For each property, '{output_dir}/{idr_property}_best_test_r2.csv' is read
    for its 'model_type' and 'path_to_model' columns. Each model's test
    predictions are loaded from the folder holding its checkpoint and plotted
    to '{output_dir}/r2_plots/{idr_property}/{model_type}_{idr_property}_R2.png'.

    Args:
        output_dir: directory containing the best-results CSVs; plots are written beneath it
        idr_properties: iterable of property keys understood by plot_r2
    """
    # Display names of the known model types; built once, outside the loops.
    model_type_dict = {
        'fuson_plm': 'FusOn-pLM',
        'esm2_t33_650M_UR50D': 'ESM-2'
    }

    for idr_property in idr_properties:
        best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv")
        # One path per model type; on duplicate model types the last row wins.
        model_type_to_path_dict = dict(zip(best_results['model_type'], best_results['path_to_model']))

        # makedirs creates the intermediate 'r2_plots' directory as well.
        plot_dir = f"{output_dir}/r2_plots/{idr_property}"
        os.makedirs(plot_dir, exist_ok=True)

        for model_type, path_to_model in model_type_to_path_dict.items():
            model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0]
            test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv")

            r2_save_path = f"{plot_dir}/{model_type}_{idr_property}_R2.png"
            # Fall back to the raw model type when no display name is known.
            display_name = model_type_dict.get(model_type, model_type)
            plot_r2(display_name, idr_property, test_preds, r2_save_path)
|
|
|
def main():
    """Produce the R^2 plots for all four properties from the 'results/final' directory."""
    results_dir = "results/final"
    idr_properties = ["asph", "scaled_re", "scaled_rg", "scaling_exp"]
    plot_all_r2(results_dir, idr_properties)


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()