import pandas as pd
import os
import pickle
from fuson_plm.data.config import SPLIT
from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.utils.splitting import split_clusters, check_split_validity
from fuson_plm.utils.visualizing import set_font, visualize_splits
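# Expected input schema (inferred from the column accesses below):
#   - fusion database CSV (SPLIT.FUSON_DB_PATH): at least 'seq_id', 'aa_seq', and
#     'benchmark' (non-null marks a benchmark sequence) columns.
#   - cluster output CSV (SPLIT.CLUSTER_OUTPUT_PATH): one row per cluster member, with
#     'representative seq_id', 'member seq_id', 'representative seq', and 'member seq' columns.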
def get_benchmark_data(fuson_db_path, clusters):
    """
    Return the representative seq_ids of every cluster containing a benchmark sequence
    (these clusters will be reserved for the test set), along with the benchmark
    sequences that made it into clustering.
    """
# Read the fusion database
fuson_db = pd.read_csv(fuson_db_path)
# Get original benchmark sequences, and benchmark sequences that were clustered
    original_benchmark_sequences = fuson_db.loc[fuson_db['benchmark'].notna()]
benchmark_sequences = fuson_db.loc[
(fuson_db['benchmark'].notna()) & # it's a benchmark sequence
(fuson_db['aa_seq'].isin(list(clusters['member seq']))) # it was clustered (it's under the length limit specified for clustering)
]['aa_seq'].to_list()
    # Get the sequence IDs of all benchmark sequences (clustered or not).
benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']
# Use benchmark_seq_ids to find which clusters contain benchmark sequences.
benchmark_cluster_reps = clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()
log_update(f"\t{len(benchmark_sequences)}/{len(original_benchmark_sequences)} benchmarking sequences (only those shorter than config.CLUSTERING[\'max_seq_length\']) were grouped into {len(benchmark_cluster_reps)} clusters. These will be reserved for the test set.")
return benchmark_cluster_reps, benchmark_sequences
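# A minimal, self-contained sketch (not part of the pipeline) showing the input shapes
# get_benchmark_data expects and what it returns. The toy ids, sequences, and cluster
# layout below are invented for illustration; run it inside open_logfile(...) if
# log_update requires an active log file.
def _demo_get_benchmark_data():
    toy_db = pd.DataFrame({
        'seq_id': ['seq1', 'seq2', 'seq3'],
        'aa_seq': ['MKLV', 'MKIV', 'GGSA'],
        'benchmark': ['yes', None, None],  # seq1 is the only benchmark sequence
    })
    toy_clusters = pd.DataFrame({
        'representative seq_id': ['seq1', 'seq1', 'seq3'],
        'member seq_id': ['seq1', 'seq2', 'seq3'],
        'representative seq': ['MKLV', 'MKLV', 'GGSA'],
        'member seq': ['MKLV', 'MKIV', 'GGSA'],
    })
    demo_path = "_demo_fuson_db.csv"  # throwaway file, since the function reads from disk
    toy_db.to_csv(demo_path, index=False)
    reps, seqs = get_benchmark_data(demo_path, toy_clusters)
    os.remove(demo_path)
    # Expected: reps == ['seq1'] (the cluster containing the benchmark sequence), seqs == ['MKLV']
    return reps, seqs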
def get_training_dfs(train, val, test):
    """
    Convert the train/val/test cluster splits into dataframes for ESM finetuning:
    drop cluster bookkeeping columns and rename 'member seq' to 'sequence'.
    """
    log_update('\nMaking dataframes for ESM finetuning...')
    # Delete cluster-related columns we don't need
    cluster_cols = ['representative seq_id', 'member seq_id', 'representative seq']
    train = train.drop(columns=cluster_cols).rename(columns={'member seq': 'sequence'})
    val = val.drop(columns=cluster_cols).rename(columns={'member seq': 'sequence'})
    test = test.drop(columns=cluster_cols).rename(columns={'member seq': 'sequence'})
    return train, val, test
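# For reference, the transformation above on the cluster-split columns
# (representative seq_id, member seq_id, representative seq, member seq, ...):
# only 'member seq' survives, renamed to 'sequence', alongside any remaining columns.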
def main():
    """
    Split the clustered fusion database into train/val/test sets, reserving benchmark
    clusters for the test set, then save the cluster splits and finetuning dataframes.
    """
    # Load paths and split parameters from the config
LOG_PATH = "splitting_log.txt"
FUSON_DB_PATH = SPLIT.FUSON_DB_PATH
CLUSTER_OUTPUT_PATH = SPLIT.CLUSTER_OUTPUT_PATH
RANDOM_STATE_1 = SPLIT.RANDOM_STATE_1
TEST_SIZE_1 = SPLIT.TEST_SIZE_1
RANDOM_STATE_2 = SPLIT.RANDOM_STATE_2
TEST_SIZE_2 = SPLIT.TEST_SIZE_2
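    # The *_1 values govern the first of two random splits and the *_2 values the second
    # (see the call to split_clusters below). Illustrative reading only: TEST_SIZE_1 = 0.2
    # followed by TEST_SIZE_2 = 0.5 would give roughly an 80/10/10 train/val/test split,
    # if the second split divides the held-out 20% between validation and test.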
# set font
set_font()
# Prepare the log file
with open_logfile(LOG_PATH):
log_update("Loaded data-splitting configurations from config.py")
SPLIT.print_config(indent='\t')
# Prepare directory to save results
os.makedirs("splits",exist_ok=True)
# Read the clusters and get a list of the representative IDs for splitting
clusters = pd.read_csv(CLUSTER_OUTPUT_PATH)
reps = clusters['representative seq_id'].unique().tolist()
log_update(f"\nPreparing clusters...\n\tCollected {len(reps)} clusters for splitting")
# Get the benchmark cluster representatives and sequences
benchmark_cluster_reps, benchmark_sequences = get_benchmark_data(FUSON_DB_PATH, clusters)
# Make the splits and extract the results
splits = split_clusters(reps, benchmark_cluster_reps=benchmark_cluster_reps,
random_state_1 = RANDOM_STATE_1, random_state_2 = RANDOM_STATE_2, test_size_1 = TEST_SIZE_1, test_size_2 = TEST_SIZE_2)
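        # Hedged reading of split_clusters (fuson_plm.utils.splitting): two sequential
        # random splits over the cluster representatives, the first controlled by
        # (TEST_SIZE_1, RANDOM_STATE_1) and the second by (TEST_SIZE_2, RANDOM_STATE_2),
        # with benchmark clusters pinned to the test set as noted in get_benchmark_data.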
X_train = splits['X_train']
X_val = splits['X_val']
X_test = splits['X_test']
# Make slices of clusters dataframe for train, val, and test
train_clusters = clusters.loc[clusters['representative seq_id'].isin(X_train)].reset_index(drop=True)
val_clusters = clusters.loc[clusters['representative seq_id'].isin(X_val)].reset_index(drop=True)
test_clusters = clusters.loc[clusters['representative seq_id'].isin(X_test)].reset_index(drop=True)
# Check validity
check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=benchmark_sequences)
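        # check_split_validity (fuson_plm.utils.splitting) presumably verifies that the
        # three splits share no sequences and that every clustered benchmark sequence
        # landed in the test set, hence benchmark_sequences is passed in.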
# Print min and max sequence lengths
min_train_seqlen = min(train_clusters['member seq'].str.len())
max_train_seqlen = max(train_clusters['member seq'].str.len())
min_val_seqlen = min(val_clusters['member seq'].str.len())
max_val_seqlen = max(val_clusters['member seq'].str.len())
min_test_seqlen = min(test_clusters['member seq'].str.len())
max_test_seqlen = max(test_clusters['member seq'].str.len())
log_update(f"\nLength breakdown summary...\n\tTrain: min seq length = {min_train_seqlen}, max seq length = {max_train_seqlen}")
log_update(f"\tVal: min seq length = {min_val_seqlen}, max seq length = {max_val_seqlen}")
log_update(f"\tTest: min seq length = {min_test_seqlen}, max seq length = {max_test_seqlen}")
# Make plots to visualize the splits
visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)
# cols = representative seq_id,member seq_id,representative seq,member seq
        train_clusters.to_csv("splits/train_cluster_split.csv", index=False)
        val_clusters.to_csv("splits/val_cluster_split.csv", index=False)
        test_clusters.to_csv("splits/test_cluster_split.csv", index=False)
        log_update('\nSaved cluster splits to splits/train_cluster_split.csv, splits/val_cluster_split.csv, splits/test_cluster_split.csv')
cols=','.join(list(train_clusters.columns))
log_update(f'\tColumns: {cols}')
        # If SnP vectors have been computed already, make train_df, val_df, test_df: the data that will be input to the training script
train_df, val_df, test_df = get_training_dfs(train_clusters, val_clusters, test_clusters)
        train_df.to_csv("splits/train_df.csv", index=False)
        val_df.to_csv("splits/val_df.csv", index=False)
        test_df.to_csv("splits/test_df.csv", index=False)
log_update('\nSaved training dataframes to splits/train_df.csv, splits/val_df.csv, splits/test_df.csv')
cols=','.join(list(train_df.columns))
log_update(f'\tColumns: {cols}')
if __name__ == "__main__":
main()