import matplotlib.pyplot as plt |
import seaborn as sns |
import pandas as pd |
import numpy as np |
import os |
from sklearn.metrics import r2_score |
import matplotlib.colors as mcolors |
from fuson_plm.utils.visualizing import set_font |
# Default histogram color (hex string) for each dataset name. The keys are the
# dataset names that the grid-plotting functions in this file look up via
# cmap_dict[dataset_name]. A name assigned at module level is already global,
# so no `global` statement is needed.
default_cmap_dict = {
    'Asphericity': '#785EF0',
    'End-to-End Distance (Re)': '#DC267F',
    'Radius of Gyration (Rg)': '#FE6100',
    'Scaling Exponent': '#FFB000',
}
def lengthen_model_name(model_name, model_epoch):
    """
    Build the long-form name of a model.

    Names containing the substring 'esm' are returned untouched; every other
    name gets an '_e<model_epoch>' suffix appended.
    """
    is_esm = 'esm' in model_name
    return model_name if is_esm else f'{model_name}_e{model_epoch}'
def plot_train_val_test_values_hist(train_values_list, val_values_list, test_values_list, dataset_name="Data", color="black", save_path=None, ax=None):
    """
    Plot overlaid histograms showing the value ranges of the train/val/test splits.

    Args:
        train_values_list: values of the train split (always drawn, in `color`)
        val_values_list: values of the val split, or None to leave it out (drawn in grey)
        test_values_list: values of the test split, or None to leave it out (drawn in black)
        dataset_name: used in the title and as the x-axis label
        color: color of the train histogram
        save_path: if not None, the current pyplot figure is saved to this path
        ax: axes to draw on; when None a new 6x4 inch, 300 dpi figure is created
    """
    set_font()
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

    # The val and test splits are optional, so only count the ones that were
    # actually supplied (len(None) would raise a TypeError).
    total_seqs = len(train_values_list)
    ax.hist(train_values_list, color=color, alpha=0.7, label=f"train (n={len(train_values_list)})")
    if test_values_list is not None:
        total_seqs += len(test_values_list)
        ax.hist(test_values_list, color='black', alpha=0.7, label=f"test (n={len(test_values_list)})")
    if val_values_list is not None:
        total_seqs += len(val_values_list)
        ax.hist(val_values_list, color='grey', alpha=0.7, label=f"val (n={len(val_values_list)})")

    ax.grid(True)
    ax.set_axisbelow(True)  # keep the grid lines behind the bars
    ax.set_title(f'{dataset_name} Distribution (n={total_seqs})')
    ax.set_xlabel(dataset_name)
    ax.legend()

    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path)
def plot_values_hist(values_list, dataset_name="Data", color="black", save_path=None, ax=None):
    """
    Draw a single histogram of `values_list`.

    The title reads '<dataset_name> Distribution' and the x-axis is labelled
    with `dataset_name`. When `ax` is None a new 6x4 inch, 300 dpi figure is
    created; when `save_path` is given the current pyplot figure is saved there.
    """
    set_font()

    target_ax = ax
    if target_ax is None:
        _, target_ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

    target_ax.hist(values_list, color=color)

    # Grid behind the bars.
    target_ax.set_axisbelow(True)
    target_ax.grid(True)

    target_ax.set_xlabel(dataset_name)
    target_ax.set_title(f'{dataset_name} Distribution')

    plt.tight_layout()

    if save_path is None:
        return
    plt.savefig(save_path)
def plot_all_values_hist_grid(values_dict, cmap_dict=default_cmap_dict, save_path="processed_data/value_histograms.png"):
    """
    Draw one histogram per dataset on a 2x2 grid of axes and save the figure.

    Args:
        values_dict: dictionary mapping dataset name -> list of values
        cmap_dict: dictionary mapping dataset name (same keys as values_dict) -> histogram color
        save_path: path the grid figure is written to
    """
    grid_fig, grid_axes = plt.subplots(2, 2, figsize=(12, 8), dpi=300)
    panels = grid_axes.flatten()

    # Datasets fill the panels in the iteration order of values_dict.
    for panel_idx, dataset_name in enumerate(values_dict):
        panel = panels[panel_idx]
        plot_values_hist(
            values_dict[dataset_name],
            dataset_name=dataset_name,
            color=cmap_dict[dataset_name],
            ax=panel,
        )

    grid_fig.set_tight_layout(True)
    grid_fig.savefig(save_path)
def plot_all_train_val_test_values_hist_grid(values_dict, cmap_dict=default_cmap_dict, save_path="processed_data/value_histograms.png"):
    """
    Draw train/val/test histograms for each dataset on a 2x2 grid and save the figure.

    Args:
        values_dict: dictionary mapping dataset name -> dict of splits, e.g.
            {'train': train_values_list, 'test': test_values_list}. 'train' is
            required; 'val' and 'test' are optional and passed on as None when absent.
        cmap_dict: dictionary mapping dataset name (same keys as values_dict) -> color of the train histogram
        save_path: path the grid figure is written to
    """
    grid_fig, grid_axes = plt.subplots(2, 2, figsize=(12, 8), dpi=300)
    panels = grid_axes.flatten()

    for panel_idx, dataset_name in enumerate(values_dict):
        panel = panels[panel_idx]
        splits = values_dict[dataset_name]

        train_split = splits['train']
        # Optional splits: a missing key is forwarded as None.
        test_split = splits.get('test')
        val_split = splits.get('val')

        plot_train_val_test_values_hist(
            train_split,
            val_split,
            test_split,
            dataset_name=dataset_name,
            color=cmap_dict[dataset_name],
            ax=panel,
        )

    grid_fig.set_tight_layout(True)
    grid_fig.savefig(save_path)
def plot_r2(model_type, idr_property, test_preds, save_path):
    """
    Scatter-plot predicted vs. true values for one property, annotated with R^2.

    Besides the figure, a '<save_path without extension>_source_data.csv' file
    with the true values, predictions and sequence IDs is written.

    Reads two files relative to the working directory:
        splits/<idr_property>/test_df.csv (needs a 'Sequence' column)
        processed_data/all_albatross_seqs_and_properties.csv (needs 'Sequence' and 'IDs' columns)

    Args:
        model_type: name of the model; accepted for the caller's convenience, not used in the plot
        idr_property: one of 'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'
        test_preds: DataFrame with 'true_values' and 'predictions' columns
        save_path: path the figure is written to

    Raises:
        KeyError: if idr_property is not one of the four keys above
    """
    set_font()
    ylabel_dict = {'asph': 'Asphericity',
                   'scaled_re': 'End-to-End Radius, $R_e$',
                   'scaled_rg': 'Radius of Gyration, $R_g$',
                   'scaling_exp': 'Polymer Scaling Exponent'}
    y_unitlabel_dict = {'asph': 'Asphericity',
                        'scaled_re': '$R_e$ (Å)',
                        'scaled_rg': '$R_g$ (Å)',
                        'scaling_exp': 'Exponent'}
    y_label = ylabel_dict[idr_property]
    y_unitlabel = y_unitlabel_dict[idr_property]

    true_values = test_preds['true_values'].tolist()
    predictions = test_preds['predictions'].tolist()

    # Look up the ID of every test sequence so the source data is traceable.
    test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv")
    processed_data = pd.read_csv("processed_data/all_albatross_seqs_and_properties.csv")
    seq_id_dict = dict(zip(processed_data['Sequence'], processed_data['IDs']))
    test_ids = test_df['Sequence'].map(seq_id_dict)

    # .copy() so that adding a column does not write into a slice of the
    # caller's DataFrame (avoids pandas' SettingWithCopyWarning).
    test_df_with_preds = test_preds[['true_values', 'predictions']].copy()
    # NOTE(review): the assignment aligns on the index, i.e. it presumes rows of
    # test_preds and test_df.csv are in the same order - confirm with the caller.
    test_df_with_preds['IDs'] = test_ids
    print("number of sequences with no ID: ", int(test_df_with_preds['IDs'].isna().sum()))

    # Derive the CSV name from the path stem so a save_path without ".png"
    # cannot make the CSV and the figure share one file name.
    save_stem, _ = os.path.splitext(save_path)
    test_df_with_preds.to_csv(f"{save_stem}_source_data.csv", index=False)

    r2 = r2_score(true_values, predictions)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(true_values, predictions, alpha=0.5, label='Predictions')
    # Diagonal y = x spanning the range of the true values.
    lowest, highest = min(true_values), max(true_values)
    ax.plot([lowest, highest], [lowest, highest], 'r--', label='Ideal Fit')
    ax.text(0.65, 0.35, f"$R^2$ = {r2:.2f}", transform=ax.transAxes, fontsize=44)
    ax.set_xlabel(f'True {y_unitlabel}', size=44)
    ax.set_ylabel(f'Predicted {y_unitlabel}', size=44)
    ax.set_title(f"{y_label}", size=50)

    legend = ax.legend(fontsize=32)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles; support both.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for handle in legend_handles:
        # Only the scatter handle has marker sizes; the line handle does not.
        if hasattr(handle, 'set_sizes'):
            handle.set_sizes([100])

    ax.grid(True)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(32)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def plot_all_r2(output_dir, idr_properties):
    """
    Make an R^2 scatter plot for every (property, model type) pair.

    For each property, '<output_dir>/<idr_property>_best_test_r2.csv' is read
    (columns 'model_type' and 'path_to_model'). The predictions are loaded from
    '<model folder>/<idr_property>_test_predictions.csv', where the model folder
    is path_to_model with its '/best-checkpoint.ckpt' suffix removed. Each plot is
    saved to '<output_dir>/r2_plots/<idr_property>/<model_type>_<idr_property>_R2.png'.

    Args:
        output_dir: directory holding the *_best_test_r2.csv files; plots go in its 'r2_plots' subfolder
        idr_properties: iterable of property keys understood by plot_r2
    """
    # Display names handed to plot_r2; a model type not listed here falls back
    # to its raw name instead of raising KeyError.
    model_type_dict = {
        'fuson_plm': 'FusOn-pLM',
        'esm2_t33_650M_UR50D': 'ESM-2'
    }

    for idr_property in idr_properties:
        best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv")
        model_type_to_path_dict = dict(zip(best_results['model_type'], best_results['path_to_model']))

        # makedirs also creates the intermediate 'r2_plots' directory, and
        # exist_ok=True makes a separate existence check unnecessary.
        plot_dir = f"{output_dir}/r2_plots/{idr_property}"
        os.makedirs(plot_dir, exist_ok=True)

        for model_type, path_to_model in model_type_to_path_dict.items():
            model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0]
            test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv")

            r2_save_path = f"{plot_dir}/{model_type}_{idr_property}_R2.png"
            plot_r2(model_type_dict.get(model_type, model_type), idr_property, test_preds, r2_save_path)
def main():
    """Create the R^2 plots for all four properties from the 'results/final' directory."""
    results_dir = "results/final"
    properties = ["asph", "scaled_re", "scaled_rg", "scaling_exp"]
    plot_all_r2(results_dir, properties)


# Only run when executed as a script, not when imported.
if __name__ == '__main__':
    main()