import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
from sklearn.metrics import r2_score
import matplotlib.colors as mcolors
from fuson_plm.utils.visualizing import set_font

# Default colour for each dataset / property name used in the histogram grids.
default_cmap_dict = {
    'Asphericity': '#785EF0',
    'End-to-End Distance (Re)': '#DC267F',
    'Radius of Gyration (Rg)': '#FE6100',
    'Scaling Exponent': '#FFB000'
}


# Method for lengthening the model name
def lengthen_model_name(model_name, model_epoch):
    """
    Return a model name that also identifies the epoch.

    Names containing 'esm' are returned unchanged; any other name gets an
    ``_e{model_epoch}`` suffix.
    """
    if 'esm' in model_name:
        return model_name
    return f'{model_name}_e{model_epoch}'


def plot_train_val_test_values_hist(train_values_list, val_values_list, test_values_list,
                                    dataset_name="Data", color="black", save_path=None, ax=None):
    """
    Plot overlaid histograms of the train / test / val values of one dataset.

    Args:
        train_values_list: values of the training split (required).
        val_values_list: values of the validation split, or None to skip it.
        test_values_list: values of the test split, or None to skip it.
        dataset_name: used for the title and the x-axis label.
        color: colour of the train histogram (test is black, val is grey).
        save_path: if not None, the figure owning ``ax`` is saved there.
        ax: axes to draw on; a new 6x4 inch, 300 dpi figure is created if None.
    """
    set_font()
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

    # A missing split is passed as None by the grid caller, so it must count
    # as zero sequences instead of being handed to len().
    n_train = len(train_values_list)
    n_val = 0 if val_values_list is None else len(val_values_list)
    n_test = 0 if test_values_list is None else len(test_values_list)
    total_seqs = n_train + n_val + n_test

    ax.hist(train_values_list, color=color, alpha=0.7, label=f"train (n={n_train})")
    if test_values_list is not None:
        ax.hist(test_values_list, color='black', alpha=0.7, label=f"test (n={n_test})")
    if val_values_list is not None:
        ax.hist(val_values_list, color='grey', alpha=0.7, label=f"val (n={n_val})")

    ax.grid(True)
    ax.set_axisbelow(True)
    ax.set_title(f'{dataset_name} Distribution (n={total_seqs})')
    ax.set_xlabel(dataset_name)
    ax.legend()

    # Operate on the figure that owns `ax`, not on whatever figure happens to
    # be current in pyplot.
    ax.figure.tight_layout()

    if save_path is not None:
        ax.figure.savefig(save_path)


def plot_values_hist(values_list, dataset_name="Data", color="black", save_path=None, ax=None):
    """
    Plot a histogram showing the range of values of one dataset.

    Args:
        values_list: values to histogram.
        dataset_name: used for the title and the x-axis label.
        color: histogram colour.
        save_path: if not None, the figure owning ``ax`` is saved there.
        ax: axes to draw on; a new 6x4 inch, 300 dpi figure is created if None.
    """
    set_font()
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

    ax.hist(values_list, color=color)

    ax.grid(True)
    ax.set_axisbelow(True)
    ax.set_title(f'{dataset_name} Distribution')
    ax.set_xlabel(dataset_name)

    # Operate on the figure that owns `ax`, not on pyplot's current figure.
    ax.figure.tight_layout()

    if save_path is not None:
        ax.figure.savefig(save_path)


def plot_all_values_hist_grid(values_dict, cmap_dict=default_cmap_dict,
                              save_path="processed_data/value_histograms.png"):
    """
    Plot one histogram per dataset on a 2x2 grid and save the figure.

    The grid has four axes, so ``values_dict`` may hold at most four datasets
    (a fifth entry raises IndexError).

    Args:
        values_dict: dictionary where keys are dataset names and values are value lists
        cmap_dict: dictionary where keys are dataset names (same as in values_dict)
            and values are the colours to plot them in
        save_path: where the grid figure is written
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8), dpi=300)
    axes = axes.flatten()

    for i, (dataset_name, values_list) in enumerate(values_dict.items()):
        plot_values_hist(values_list, dataset_name=dataset_name,
                         color=cmap_dict[dataset_name], ax=axes[i])

    fig.set_tight_layout(True)
    fig.savefig(save_path)


def plot_all_train_val_test_values_hist_grid(values_dict, cmap_dict=default_cmap_dict,
                                             save_path="processed_data/value_histograms.png"):
    """
    Plot one train/val/test histogram per dataset on a 2x2 grid and save the figure.

    The grid has four axes, so ``values_dict`` may hold at most four datasets
    (a fifth entry raises IndexError).

    Args:
        values_dict: dictionary where keys are dataset names and values are another dict:
            {'train': train_values_list, 'test': test_values_list, 'val': val_values_list}.
            'train' is required; 'test' and 'val' are optional.
        cmap_dict: dictionary where keys are dataset names (same as in values_dict)
            and values are the colours of the train histograms
        save_path: where the grid figure is written
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8), dpi=300)
    axes = axes.flatten()

    for i, (dataset_name, train_val_test_dict) in enumerate(values_dict.items()):
        plot_train_val_test_values_hist(
            train_val_test_dict['train'],
            train_val_test_dict.get('val'),    # None when the split is absent
            train_val_test_dict.get('test'),   # None when the split is absent
            dataset_name=dataset_name,
            color=cmap_dict[dataset_name],
            ax=axes[i],
        )

    fig.set_tight_layout(True)
    fig.savefig(save_path)


# only need to change labels at bottom depending on what embeddings+dimension is being looked at
def plot_r2(model_type, idr_property, test_preds, save_path):
    """
    Scatter-plot predicted vs. true values for one IDR property with its R^2.

    Besides the figure, the plotted source data (true values, predictions and
    sequence IDs) is written next to it as ``<save_path stem>_source_data.csv``.

    Reads ``splits/{idr_property}/test_df.csv`` and
    ``processed_data/all_albatross_seqs_and_properties.csv`` relative to the
    current working directory.

    Args:
        model_type: display name of the model; currently not drawn on the plot.
        idr_property: one of 'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'.
        test_preds: DataFrame with 'true_values' and 'predictions' columns.
        save_path: path of the output image.

    Raises:
        KeyError: if ``idr_property`` is not one of the supported keys.
    """
    set_font()

    # prepare ylabels from idr_property
    ylabel_dict = {'asph': 'Asphericity',
                   'scaled_re': 'End-to-End Radius, $R_e$',
                   'scaled_rg': 'Radius of Gyration, $R_g$',
                   'scaling_exp': 'Polymer Scaling Exponent'}
    y_unitlabel_dict = {'asph': 'Asphericity',
                        'scaled_re': '$R_e$ (Å)',
                        'scaled_rg': '$R_g$ (Å)',
                        'scaling_exp': 'Exponent'}
    y_label = ylabel_dict[idr_property]
    y_unitlabel = y_unitlabel_dict[idr_property]

    # get true values and predictions
    true_values = test_preds['true_values'].tolist()
    predictions = test_preds['predictions'].tolist()

    # save this source data, including the IDs of the sequences
    test_df = pd.read_csv(f"splits/{idr_property}/test_df.csv")
    processed_data = pd.read_csv("processed_data/all_albatross_seqs_and_properties.csv")
    seq_id_dict = dict(zip(processed_data['Sequence'], processed_data['IDs']))
    test_df['IDs'] = test_df['Sequence'].map(seq_id_dict)
    # Explicit copy: adding a column to a slice of `test_preds` would trigger
    # pandas' SettingWithCopyWarning and is not guaranteed to take effect.
    test_df_with_preds = test_preds[['true_values', 'predictions']].copy()
    # Index-aligned assignment; assumes test_preds rows are in the same order
    # as test_df -- TODO confirm against the code that writes the predictions.
    test_df_with_preds['IDs'] = test_df['IDs']
    print("number of sequences with no ID: ", int(test_df_with_preds['IDs'].isna().sum()))
    # splitext swaps only the real extension, so a non-.png save_path no longer
    # makes the CSV land on the image path itself.
    source_data_path = os.path.splitext(save_path)[0] + "_source_data.csv"
    test_df_with_preds.to_csv(source_data_path, index=False)

    r2 = r2_score(true_values, predictions)

    # Plotting
    fig = plt.figure(figsize=(10, 8))
    plt.scatter(true_values, predictions, alpha=0.5, label='Predictions')
    lowest, highest = min(true_values), max(true_values)
    plt.plot([lowest, highest], [lowest, highest], 'r--', label='Ideal Fit')
    plt.text(0.65, 0.35, f"$R^2$ = {r2:.2f}", transform=plt.gca().transAxes, fontsize=44)

    # Adjusting font sizes and setting font properties
    plt.xlabel(f'True {y_unitlabel}', size=44)
    plt.ylabel(f'Predicted {y_unitlabel}', size=44)
    plt.title(f"{y_label}", size=50)

    # Create legend and set font properties
    legend = plt.legend(fontsize=32)
    for text in legend.get_texts():
        text.set_fontsize(32)

    # Adjust marker size in the legend. `legendHandles` was renamed to
    # `legend_handles` in matplotlib 3.7 and the old name was later removed,
    # so prefer the new attribute and fall back for older versions.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        # Only the scatter handle has marker sizes; the line handle does not.
        if hasattr(handle, 'set_sizes'):
            handle.set_sizes([100])

    # Enable grid
    plt.grid(True)

    # Adjusting tick labels font size
    plt.xticks(fontsize=36)
    plt.yticks(fontsize=36)

    # Setting font properties for tick labels individually; this runs last and
    # therefore determines the final tick label size.
    for label in plt.gca().get_xticklabels():
        label.set_fontsize(32)
    for label in plt.gca().get_yticklabels():
        label.set_fontsize(32)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise every call from plot_all_r2 leaves another
    # open figure behind.
    plt.close(fig)


def plot_all_r2(output_dir, idr_properties):
    """
    Make the R^2 plot of the best model of each type for every IDR property.

    For each property, ``{output_dir}/{idr_property}_best_test_r2.csv`` is read
    (columns 'model_type' and 'path_to_model'), the test predictions stored
    next to each model checkpoint are loaded, and the plot is written to
    ``{output_dir}/r2_plots/{idr_property}/{model_type}_{idr_property}_R2.png``.

    Args:
        output_dir: directory holding the best-results CSVs; plots go below it.
        idr_properties: iterable of property keys accepted by ``plot_r2``.
    """
    # Display names of the known model types.
    model_type_dict = {
        'fuson_plm': 'FusOn-pLM',
        'esm2_t33_650M_UR50D': 'ESM-2'
    }

    for idr_property in idr_properties:
        # make the R^2 Plots for the BEST one
        best_results = pd.read_csv(f"{output_dir}/{idr_property}_best_test_r2.csv")
        model_type_to_path_dict = dict(zip(best_results['model_type'],
                                           best_results['path_to_model']))

        # make paths for R^2 plots (makedirs also creates the r2_plots parent)
        plot_dir = f"{output_dir}/r2_plots/{idr_property}"
        os.makedirs(plot_dir, exist_ok=True)

        for model_type, path_to_model in model_type_to_path_dict.items():
            model_preds_folder = path_to_model.split('/best-checkpoint.ckpt')[0]
            test_preds = pd.read_csv(f"{model_preds_folder}/{idr_property}_test_predictions.csv")

            r2_save_path = f"{plot_dir}/{model_type}_{idr_property}_R2.png"
            # Fall back to the raw name for model types without a display name
            # instead of failing with a KeyError.
            display_name = model_type_dict.get(model_type, model_type)
            plot_r2(display_name, idr_property, test_preds, r2_save_path)


def main():
    """Create the R^2 plots for all four IDR properties under results/final."""
    plot_all_r2("results/final", ["asph", "scaled_re", "scaled_rg", "scaling_exp"])


if __name__ == '__main__':
    main()