# Page header captured with this file (not part of the program):
#   svincoff's picture
#   fixed READMEs and added IDR Prediction benchmark
#   e048d40
from fuson_plm.utils.logging import open_logfile, log_update
from fuson_plm.utils.visualizing import set_font
from fuson_plm.benchmarking.idr_prediction.config import SPLIT
from fuson_plm.utils.splitting import split_clusters, check_split_validity
import os
import pandas as pd
def get_training_dfs(train, val, test, idr_db):
    """
    Build one train/val/test dataframe triple per IDR property.

    Cluster bookkeeping columns are dropped from each split, 'member seq' is renamed to
    'Sequence', and the property values are attached from idr_db by matching on 'Sequence'.

    Args:
        train, val, test: cluster-split dataframes. Each must contain the columns
            'representative seq_id', 'member seq_id', 'member length', 'representative seq'
            (all dropped) and 'member seq' (renamed to 'Sequence'). Any other columns are kept.
        idr_db: dataframe with a 'Sequence' column and the value columns
            'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'.

    Returns:
        dict mapping each value column name to {'train': df, 'val': df, 'test': df}, where
        every df carries the property in a column named 'Value'. Rows whose sequence has
        no value for that property are dropped.
    """
    log_update('\nMaking dataframes for IDR prediction benchmark...')

    # Cluster-related columns we don't need downstream
    cluster_cols = ['representative seq_id', 'member seq_id', 'member length', 'representative seq']
    # idr_db values are in columns: asph,scaled_re,scaled_rg,scaling_exp
    value_cols = ['asph', 'scaled_re', 'scaled_rg', 'scaling_exp']

    # Clean every split once, up front, instead of repeating the same chain three times
    cleaned_splits = {
        split_name: split_df.drop(columns=cluster_cols).rename(columns={'member seq': 'Sequence'})
        for split_name, split_df in (('train', train), ('val', val), ('test', test))
    }

    return_dict = {}
    for col in value_cols:
        property_values = idr_db[['Sequence', col]]
        return_dict[col] = {
            split_name: (
                pd.merge(split_df, property_values, on='Sequence', how='left')
                .rename(columns={col: 'Value'})
                # subset is passed as a list: a bare string is only accepted by newer pandas releases
                .dropna(subset=['Value'])
            )
            for split_name, split_df in cleaned_splits.items()
        }

    return return_dict
def main():
    """
    Split the clustered sequences into train/val/test sets for the IDR prediction benchmark.

    Steps:
        1. Read the splitting settings from SPLIT and log them to splitting_log.txt.
        2. Split the cluster representatives with split_clusters; every row of the cluster
           table goes to the split that holds its representative. The result is passed to
           check_split_validity.
        3. Save the cluster-level splits to splits/{train,val,test}_cluster_split.csv.
        4. For each IDR property returned by get_training_dfs, save
           splits/<property>/{train,val,test}_df.csv and log the split sizes.

    Raises:
        AssertionError: if, for any property, the three saved dataframes together do not
            contain exactly as many rows as idr_db has non-null values for that property.
    """
    # Read all the input settings
    LOG_PATH = "splitting_log.txt"
    IDR_DB_PATH = SPLIT.IDR_DB_PATH
    CLUSTER_OUTPUT_PATH = SPLIT.CLUSTER_OUTPUT_PATH
    RANDOM_STATE_1 = SPLIT.RANDOM_STATE_1
    TEST_SIZE_1 = SPLIT.TEST_SIZE_1
    RANDOM_STATE_2 = SPLIT.RANDOM_STATE_2
    TEST_SIZE_2 = SPLIT.TEST_SIZE_2

    # set font
    set_font()

    # Prepare the log file
    with open_logfile(LOG_PATH):
        log_update("Loaded data-splitting configurations from config.py")
        SPLIT.print_config(indent='\t')

        # Prepare directory to save results
        os.makedirs("splits", exist_ok=True)

        # Read the clusters and get a list of the representative IDs for splitting
        clusters = pd.read_csv(CLUSTER_OUTPUT_PATH)
        reps = clusters['representative seq_id'].unique().tolist()
        log_update(f"\nPreparing clusters...\n\tCollected {len(reps)} clusters for splitting")

        # Make the splits and extract the results
        splits = split_clusters(reps,
                                random_state_1=RANDOM_STATE_1, test_size_1=TEST_SIZE_1,
                                random_state_2=RANDOM_STATE_2, test_size_2=TEST_SIZE_2)
        X_train = splits['X_train']
        X_val = splits['X_val']
        X_test = splits['X_test']

        # Make slices of clusters dataframe for train, val, and test
        train_clusters = clusters.loc[clusters['representative seq_id'].isin(X_train)].reset_index(drop=True)
        val_clusters = clusters.loc[clusters['representative seq_id'].isin(X_val)].reset_index(drop=True)
        test_clusters = clusters.loc[clusters['representative seq_id'].isin(X_test)].reset_index(drop=True)

        # Check validity
        check_split_validity(train_clusters, val_clusters, test_clusters)

        # Print min and max sequence lengths (lengths are computed once per split)
        train_seqlens = train_clusters['member seq'].str.len()
        val_seqlens = val_clusters['member seq'].str.len()
        test_seqlens = test_clusters['member seq'].str.len()
        log_update(f"\nLength breakdown summary...\n\tTrain: min seq length = {train_seqlens.min()}, max seq length = {train_seqlens.max()}")
        # Val and Test are indented under the summary heading, like the Train line
        log_update(f"\tVal: min seq length = {val_seqlens.min()}, max seq length = {val_seqlens.max()}")
        log_update(f"\tTest: min seq length = {test_seqlens.min()}, max seq length = {test_seqlens.max()}")

        # Save the cluster-level splits
        train_clusters.to_csv("splits/train_cluster_split.csv", index=False)
        val_clusters.to_csv("splits/val_cluster_split.csv", index=False)
        test_clusters.to_csv("splits/test_cluster_split.csv", index=False)
        log_update('\nSaved cluster splits to splits/train_cluster_split.csv, splits/val_cluster_split.csv, splits/test_cluster_split.csv')
        cols = ','.join(list(train_clusters.columns))
        log_update(f'\tColumns: {cols}')

        # Get final dataframes for training, and check their sizes
        idr_db = pd.read_csv(IDR_DB_PATH)
        train_dfs_dict = get_training_dfs(train_clusters, val_clusters, test_clusters, idr_db)

        for idr_property, dfs in train_dfs_dict.items():
            out_dir = f"splits/{idr_property}"
            os.makedirs(out_dir, exist_ok=True)

            train_df = dfs['train']
            val_df = dfs['val']
            test_df = dfs['test']
            total_seqs = len(train_df) + len(val_df) + len(test_df)

            train_df.to_csv(f"{out_dir}/train_df.csv", index=False)
            val_df.to_csv(f"{out_dir}/val_df.csv", index=False)
            test_df.to_csv(f"{out_dir}/test_df.csv", index=False)
            log_update(f"\nSaved {idr_property} training dataframes to {out_dir}/train_df.csv, {out_dir}/val_df.csv, {out_dir}/test_df.csv")

            for label, split_df in (('Train', train_df), ('Val', val_df), ('Test', test_df)):
                # Guard the percentage so an empty property fails on the count check below,
                # not on a ZeroDivisionError here
                pct = 100 * len(split_df) / total_seqs if total_seqs else 0.0
                log_update(f"\t{label} sequences: {len(split_df)} ({pct:.2f}%)")
            log_update(f"\tTotal: {total_seqs}")

            # Make sure the lengths are right: the splits must account for every
            # idr_db row that has a value for this property
            expected_seqs = int(idr_db[idr_property].notna().sum())
            log_update(f"\tSequences with a {idr_property} value in the IDR database: {expected_seqs}")
            if total_seqs != expected_seqs:
                # Raised explicitly (rather than via `assert`) so the check still runs under `python -O`
                raise AssertionError(
                    f"{idr_property}: splits contain {total_seqs} sequences, "
                    f"but the IDR database has {expected_seqs} sequences with a value"
                )
# Run the splitting pipeline only when this file is executed as a script
if __name__ == "__main__":
    main()