svincoff's picture
puncta benchmark
8d9d9da
# Cleans raw data to prepare FO labels and embeddings
from fuson_plm.utils.logging import open_logfile, log_update
from fuson_plm.utils.data_cleaning import find_invalid_chars
from fuson_plm.utils.constants import VALID_AAS
import pandas as pd
import numpy as np
import pickle
def find_localization(row):
    """
    Summarize where puncta were observed for a single FO.

    Args:
        row: mapping-like record (e.g. a DataFrame row) with the keys
            'Puncta_Status', 'Cytoplasm' and 'Nucleus'.

    Returns:
        'Both', 'Cytoplasm' or 'Nucleus' when Puncta_Status is 'YES' and the
        matching compartment(s) are marked 'Punctate'; np.nan otherwise
        (status is not 'YES', or neither compartment is 'Punctate').
    """
    # Read every field up front so a missing key fails regardless of status.
    status = row['Puncta_Status']
    in_cytoplasm = (row['Cytoplasm'] == 'Punctate')
    in_nucleus = (row['Nucleus'] == 'Punctate')

    if status != 'YES':
        return np.nan
    if in_cytoplasm and in_nucleus:
        return 'Both'
    if in_cytoplasm:
        return 'Cytoplasm'
    if in_nucleus:
        return 'Nucleus'
    # Puncta-positive, but neither compartment is annotated as punctate.
    return np.nan
def clean_s5(df):
    """
    Pull the names of the low-MI features used by the ML model out of
    FOdb Supplementary Table 5.

    Args:
        df: Supplementary Table 5 as a DataFrame; must contain the columns
            'Low MI Set: Used In ML Model' and
            'Parameter Label (Sup Table 2 & Matlab Scripts)'.

    Returns:
        Sorted list of parameter labels whose "used in ML model" flag is
        'Yes' (or 'Yet', which is accepted as a typo of 'Yes').
    """
    log_update("Cleaning FOdb Supplementary Table 5")

    # 'Yet' is tolerated alongside 'Yes' to absorb a typo in the table
    used_in_model = df['Low MI Set: Used In ML Model'].isin(['Yes', 'Yet'])
    retained_features = df.loc[
        used_in_model, 'Parameter Label (Sup Table 2 & Matlab Scripts)'
    ].tolist()
    retained_features.sort()

    # report how many features were kept, then list them 1-indexed
    log_update(f'\tIsolated the {len(retained_features)} low-MI features used to train ML model')
    for rank, name in enumerate(retained_features, start=1):
        log_update(f'\t\t{rank}. {name}')

    return retained_features
def make_label_df(df):
    """
    Build the per-FO label table from the cleaned Supplementary Table 4.

    Args:
        df: cleaned s4 DataFrame with the columns 'FO_Name', 'AAseq',
            'Localization', 'Puncta_Status' and 'Dataset'.

    Returns:
        DataFrame with the columns fusiongene, aa_seq, dataset, split,
        nucleus, cytoplasm, formation, where split is 'train' for
        Expressed_Set and 'test' for Verification_Set, and the last three
        columns are 0/1 labels.

    Raises:
        KeyError: if 'Dataset' holds a value other than 'Expressed_Set' or
            'Verification_Set'.
    """
    split_by_dataset = {'Expressed_Set': 'train', 'Verification_Set': 'test'}
    nuclear_labels = ('Nucleus', 'Both')
    cytoplasmic_labels = ('Cytoplasm', 'Both')

    renamed = df[['FO_Name', 'AAseq', 'Localization', 'Puncta_Status', 'Dataset']].rename(
        columns={'FO_Name': 'fusiongene', 'AAseq': 'aa_seq', 'Dataset': 'dataset'}
    )

    # Derive the split and the three binary labels in one pass.
    labelled = renamed.assign(
        split=renamed['dataset'].apply(lambda name: split_by_dataset[name]),
        nucleus=renamed['Localization'].apply(lambda loc: int(loc in nuclear_labels)),
        cytoplasm=renamed['Localization'].apply(lambda loc: int(loc in cytoplasmic_labels)),
        formation=renamed['Puncta_Status'].apply(lambda status: int(status == 'YES')),
    )

    output_columns = ['fusiongene', 'aa_seq', 'dataset', 'split', 'nucleus', 'cytoplasm', 'formation']
    return labelled[output_columns]
def make_embeddings(df, physicochemical_features):
    """
    Build a physicochemical feature vector for every unique sequence.

    Args:
        df: DataFrame with an 'AAseq' column plus one column per entry of
            physicochemical_features.
        physicochemical_features: ordered list of feature column names; the
            order fixes the order of values inside each vector.

    Returns:
        dict mapping each unique 'AAseq' value to a list of feature values.
        When a sequence occurs in several rows, the first row is used.
    """
    feat_string = '\n\t' + '\n\t'.join([str(i)+'. '+feat for i,feat in enumerate(physicochemical_features)])
    log_update(f"\nMaking phyisochemical feature vectors.\nFeature Order: {feat_string}")

    # One pass over the table: keep the first row of each sequence instead of
    # re-filtering the whole frame once per unique sequence.
    first_rows = df.drop_duplicates(subset='AAseq', keep='first')
    feature_rows = first_rows[physicochemical_features].to_numpy().tolist()
    return dict(zip(first_rows['AAseq'].tolist(), feature_rows))
def _log_dataset_breakdown(df, puncta_positive):
    """
    Log the FO count per dataset; for Expressed_Set and Verification_Set also
    log the puncta-status counts and, under 'YES', the localization counts.
    """
    dataset_counts = df['Dataset'].value_counts()
    total_fos = sum(dataset_counts)
    for dataset, n_fos in dataset_counts.items():
        pcnt = 100*n_fos/total_fos
        log_update(f'\t\t{dataset}: \t{n_fos} ({pcnt:.2f}%)')
        # only the two known datasets get a detailed breakdown
        if dataset not in ('Expressed_Set', 'Verification_Set'):
            continue
        status_counts = df.loc[df['Dataset']==dataset]['Puncta_Status'].value_counts()
        positive_locs = puncta_positive.loc[
            puncta_positive['Dataset']==dataset
        ]['Localization'].value_counts()
        for status, n_status in status_counts.items():
            pcnt = 100*n_status/n_fos
            log_update(f'\t\t\t{status}: \t{n_status} ({pcnt:.2f}%)')
            if status=='YES':
                log_update('\t\t\t\tLocalizations...')
                for loc, n_loc in positive_locs.items():
                    pcnt = 100*n_loc/n_status
                    log_update(f'\t\t\t\t\t{loc}: \t{n_loc} ({pcnt:.2f}%)')


def clean_s4(df, retained_features):
    """
    Clean FOdb Supplementary Table 4 and log summary statistics.

    Steps: keep only FOs whose Puncta_Status is 'YES' or 'NO', report
    duplicated sequences and invalid characters, strip '-' from sequences,
    rewrite FO names from Head_Tail to Head::Tail, add a 'Localization'
    column, report NaNs in the retained features, and restrict the table to
    the annotation columns plus the retained features.

    Args:
        df: raw Supplementary Table 4 as a DataFrame. Must contain 'FO_Name',
            'Nucleus', 'Nucleolus', 'Cytoplasm', 'Puncta_Status', 'Dataset',
            'AAseq', 'Puncta.pred', 'Puncta.prob' and every retained feature.
        retained_features: list of feature column names to keep.

    Returns:
        The cleaned DataFrame.

    Raises:
        KeyError: if any required column or retained feature is missing.
    """
    log_update("Cleaning FOdb Supplementary Table 4")
    df = df.loc[
        df['Puncta_Status'].isin(['YES','NO'])
    ].reset_index(drop=True)
    log_update(f'\tRemoved invalid FOs (puncta status = "Other" or "Nucleolar"). Remaining FOs: {len(df)}')

    # check for duplicate sequences
    dup_seqs = df.loc[df['AAseq'].duplicated()]['AAseq'].unique()
    log_update(f"\tTotal duplicated sequences: {len(dup_seqs)}")

    # check for invalid characters
    df['invalid_chars'] = df['AAseq'].apply(lambda x: find_invalid_chars(x, VALID_AAS))
    all_invalid_chars = set().union(*df['invalid_chars'])
    log_update(f"\tChecking for invalid characters...\n\t\tFound {len(all_invalid_chars)} invalid characters")
    for c in all_invalid_chars:
        # literal match: characters such as '*' or '.' must not be read as regex
        subset = df.loc[df['AAseq'].str.contains(c, regex=False)]['AAseq'].tolist()
        for seq in subset:
            log_update(f"\t\tInvalid char {c} at index {seq.index(c)}/{len(seq)-1} of sequence {seq}")
    df = df.drop(columns=['invalid_chars'])

    # remove "-" from sequences; works for zero, one or many affected rows
    df['AAseq'] = df['AAseq'].str.replace('-', '', regex=False)

    # change FO format to ::
    df['FO_Name'] = df['FO_Name'].apply(lambda x: x.replace('_','::'))
    log_update(f'\tChanged FO names to Head::Tail format')

    # annotate localization, then isolate the positive set for the summaries below
    df['Localization'] = df.apply(find_localization, axis=1)
    puncta_positive = df.loc[
        df['Puncta_Status']=='YES'
    ].reset_index(drop=True)

    # Only keeping retained features
    mi_feats_included = set(retained_features).intersection(df.columns)
    log_update(f"\tChecking for the {len(retained_features)} low-MI features... {len(mi_feats_included)} found")

    # if a feature has NaN, log which datasets the affected FOs belong to
    for rf in retained_features:
        missing = df[rf].isna()
        if missing.sum()>0:
            nas = df.loc[missing]
            log_update(f"\t\tFeature {rf} has {len(nas)} np.nan values in the following datasets:")
            for k,v in nas['Dataset'].value_counts().items():
                log_update(f'\t\t\t{k}: {v}')

    df = df[['FO_Name', 'Nucleus', 'Nucleolus', 'Cytoplasm','Puncta_Status', 'Dataset', 'Localization', 'AAseq',
             'Puncta.pred', 'Puncta.prob']+retained_features]

    # Quantify localization (value_counts skips NaN, so percentages are over annotated FOs)
    log_update(f'\n\tPuncta localization for {len(puncta_positive)} FOs where Puncta_Status==YES')
    loc_counts = puncta_positive['Localization'].value_counts()
    n_localized = sum(loc_counts)
    for k, v in loc_counts.items():
        pcnt = 100*v/n_localized
        log_update(f'\t\t{k}: \t{v} ({pcnt:.2f}%)')

    log_update("\tDataset breakdown...")
    _log_dataset_breakdown(df, puncta_positive)

    return df
def main():
    """
    Run the full cleaning pipeline and write its outputs to the working directory.

    Reads the two raw FOdb CSVs, selects the retained features, cleans table
    S4, derives the label table and the feature-vector dictionary, and saves:
      - cleaned_dataset_s4.csv
      - splits.csv
      - FOdb_physicochemical_embeddings.pkl
    Progress is logged via log_update while the log file is open.
    """
    LOG_PATH = 'cleaning_log.txt'
    FODB_S4_PATH = '../../data/raw_data/FOdb_puncta.csv'
    FODB_S5_PATH = '../../data/raw_data/FOdb_SD5.csv'

    with open_logfile(LOG_PATH):
        s4 = pd.read_csv(FODB_S4_PATH)
        s5 = pd.read_csv(FODB_S5_PATH)

        retained_features = clean_s5(s5)
        cleaned_s4 = clean_s4(s4, retained_features)
        label_df = make_label_df(cleaned_s4)
        embeddings = make_embeddings(cleaned_s4, retained_features)

        # save the results
        cleaned_s4.to_csv('cleaned_dataset_s4.csv', index=False)
        # the file written here is the cleaned S4 table
        log_update("\nSaved cleaned table S4 to cleaned_dataset_s4.csv")

        label_df.to_csv('splits.csv', index=False)
        log_update("\nSaved train-test splits with nucleus, cytoplasm, and formation labels to splits.csv")

        with open('FOdb_physicochemical_embeddings.pkl','wb') as f:
            pickle.dump(embeddings, f)
        log_update("\nSaved physicochemical embeddings as a dictionary to FOdb_physicochemical_embeddings.pkl")


if __name__ == '__main__':
    main()