|
from tensorflow.keras.preprocessing import text, sequence |
|
from tensorflow.keras.preprocessing.text import Tokenizer |
|
|
|
from torch.utils.data import DataLoader,Dataset |
|
import pandas as pd |
|
import seaborn as sns |
|
|
|
import torchvision |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
from torch import nn |
|
from torch import optim |
|
import torch.nn.functional as F |
|
from torchvision import datasets, transforms, models |
|
|
|
import torch.optim as optim |
|
from torch.optim.lr_scheduler import ExponentialLR, StepLR |
|
from functools import partial, wraps |
|
|
|
from sklearn.model_selection import train_test_split |
|
from sklearn.preprocessing import QuantileTransformer |
|
from sklearn.preprocessing import RobustScaler |
|
|
|
from matplotlib.ticker import MaxNLocator |
|
|
|
import torch |
|
|
|
import esm |
|
|
|
|
|
|
|
|
|
class RegressionDataset(Dataset):
    """Paired (features, target) dataset for regression-style training.

    Thin wrapper around two indexable containers of equal length; length
    and indexing follow the feature container.
    """

    def __init__(self, X_data, y_data):
        # Store references as given; no copying or validation is performed.
        self.X_data = X_data
        self.y_data = y_data

    def __getitem__(self, index):
        # A sample is the (feature, target) pair at the same position.
        return self.X_data[index], self.y_data[index]

    def __len__(self):
        # Size is defined by the feature container.
        return len(self.X_data)
|
|
|
|
|
def pad_a_np_arr(x0, add_x, n_len):
    """Right-pad a 1D sequence with a constant value up to a target length.

    Parameters
    ----------
    x0 : 1D array-like (np.ndarray or list)
        Sequence to pad.
    add_x : scalar
        Value used for padding.
    n_len : int
        Target length.

    Returns
    -------
    np.ndarray of length ``n_len`` when padding occurred; otherwise an
    unmodified copy of ``x0`` (original no-pad behavior kept, including
    the informational print).
    """
    n0 = len(x0)
    x1 = x0.copy()
    if n0 < n_len:
        # Single vectorized append instead of one np.append per element:
        # the old per-element loop reallocated the whole array each
        # iteration, i.e. O(n^2) total work.
        x1 = np.append(x1, [add_x] * (n_len - n0))
    else:
        print('No padding is needed')
    return x1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_a_np_arr_esm(x0, add_x, n_len):
    """Right-pad a 1D sequence with a constant value up to a target length.

    Same contract as :func:`pad_a_np_arr`; kept as a separate entry point
    for the ESM-based pipelines that call it.

    Parameters
    ----------
    x0 : 1D array-like (np.ndarray or list)
        Sequence to pad.
    add_x : scalar
        Value used for padding.
    n_len : int
        Target length.

    Returns
    -------
    np.ndarray of length ``n_len`` when padding occurred; otherwise an
    unmodified copy of ``x0`` (original no-pad behavior kept, including
    the informational print).
    """
    x1 = x0.copy()
    n0 = len(x1)
    if n0 < n_len:
        # One vectorized append replaces the former per-element np.append
        # loop, which copied the whole array on every iteration (O(n^2)).
        x1 = np.append(x1, [add_x] * (n_len - n0))
    else:
        print('No padding is needed')
    return x1
|
|
|
|
|
|
|
|
|
|
|
def screen_dataset_MD(
    csv_file=None,
    pk_file =None,
    PKeys=None,
    CKeys=None,
):
    """Load an SMD dataset (CSV or pickle), screen entries, and plot diagnostics.

    Reads raw records from `csv_file` (parsing space-separated numeric
    columns listed in PKeys['arr_key']) or from `pk_file`, then filters by
    sequence length and max smoothed force, saving/showing diagnostic plots.

    Parameters
    ----------
    csv_file : str or None
        Path to the raw CSV; takes priority over `pk_file`.
    pk_file : str or None
        Path to a pickled DataFrame, used only when `csv_file` is None.
        NOTE(review): if both are None, `df_raw` is never assigned and the
        function raises NameError below — confirm callers always pass one.
    PKeys : dict
        Expects 'data_dir', 'min_AA_seq_len', 'max_AA_seq_len',
        'max_Force_cap', and (CSV path only) 'arr_key'.
    CKeys : dict
        Expects 'SlientRun' (1 => save figures instead of showing them).

    Returns
    -------
    (df_raw, protein_df) : the unfiltered DataFrame and the screened one.
    """
    store_path = PKeys['data_dir']
    # 'SlientRun' is the project's (misspelled) key name; 1 = headless mode.
    IF_SaveFig = CKeys['SlientRun']
    min_AASeq_len = PKeys['min_AA_seq_len']
    max_AASeq_len = PKeys['max_AA_seq_len']
    max_used_Smo_F = PKeys['max_Force_cap']

    if csv_file != None:
        print('=============================================')
        print('1. read in the csv file...')
        print('=============================================')
        arr_key = PKeys['arr_key']

        df_raw = pd.read_csv(csv_file)
        print(df_raw.keys())

        # Columns in arr_key are stored as space-separated numbers; convert
        # each cell into a float ndarray.
        for this_key in arr_key:
            df_raw[this_key] = df_raw[this_key].apply(lambda x: np.array(list(map(float, x.split(" ")))))

        # Normalize the column name so both loading paths use the same key.
        df_raw.rename(columns={"sample_FORCEpN_data":"sample_FORCE_data"}, inplace=True)
        print('Updated keys: \n', df_raw.keys())

    elif pk_file != None:
        # Pickled DataFrames are assumed to already hold ndarray cells.
        df_raw = pd.read_pickle(pk_file)
        print(df_raw.keys())

    # Overview plot of all raw force-extension traces.
    # NOTE(review): the first figure is immediately replaced by plt.subplots();
    # it is created and discarded unused.
    fig = plt.figure(figsize=(24,16),dpi=200)
    fig, ax0 = plt.subplots()
    for ii in range(len( df_raw )):
        if df_raw['seq_len'][ii]<=6400:
            ax0.plot(
                df_raw['sample_NormPullGap_data'][ii],
                df_raw['sample_FORCE_data'][ii],
                alpha=0.1,
            )
            # Mark the peak smoothed force for each record.
            ax0.scatter(
                df_raw['NPullGap_for_MaxSmoF'][ii],
                df_raw['Max_Smo_Force'][ii],
            )
        else:
            # Overly long sequences are only reported, not plotted.
            print(df_raw['pdb_id'][ii])

    plt.xlabel('Normalized distance btw pulling ends')
    plt.ylabel('Force (pF)')
    outname = store_path+'CSV_0_SMD_sim_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print('=============================================')
    print('2. screen the entries...')
    print('=============================================')

    # Percentage of missing data per column, for a quick sanity check.
    df_isnull = pd.DataFrame(
        round(
            (df_raw.isnull().sum().sort_values(ascending=False)/df_raw.shape[0])*100,
            1
        )
    ).reset_index()
    df_isnull.style.format({'% of Missing Data': lambda x:'{:.1%}'.format(abs(x))})
    cm = sns.light_palette("skyblue", as_cmap=True)
    df_isnull = df_isnull.style.background_gradient(cmap=cm)
    print('Check null...')
    print( df_isnull )

    print('Working on a dataframe with useful keywords')
    # Keep only the columns the downstream pipeline consumes.
    protein_df = pd.DataFrame().assign(
        pdb_id=df_raw['pdb_id'],
        AA=df_raw['AA'],
        seq_len=df_raw['seq_len'],
        Max_Smo_Force=df_raw['Max_Smo_Force'],
        NPullGap_for_MaxSmoF=df_raw['NPullGap_for_MaxSmoF'],
        sample_FORCE_data=df_raw['sample_FORCE_data'],
        sample_NormPullGap_data=df_raw['sample_NormPullGap_data'],
        ini_gap=df_raw['ini_gap'],
        Fmax_pN_BSDB=df_raw['Fmax_pN_BSDB'],
        pull_data=df_raw['pull_data'],
        forc_data=df_raw['forc_data'],
        Int_Smo_ForcPull=df_raw['Int_Ene'],
    )

    print('a. screen using sequence length...')
    print('original sequences #: ', len(protein_df))

    # max_AASeq_len-2 reserves room for special tokens added later
    # (presumably BOS/EOS — confirm against the tokenizer usage).
    protein_df.drop(
        protein_df[protein_df['seq_len']>max_AASeq_len-2].index,
        inplace = True
    )
    protein_df.drop(
        protein_df[protein_df['seq_len'] <min_AASeq_len].index,
        inplace = True
    )
    protein_df=protein_df.reset_index(drop=True)
    print('used sequences #: ', len(protein_df))

    print('b. screen using force values...')
    print('original sequences #: ', len(protein_df))

    # Cap the maximum smoothed force to drop outliers.
    protein_df.drop(
        protein_df[protein_df['Max_Smo_Force']>max_used_Smo_F].index,
        inplace = True
    )
    protein_df=protein_df.reset_index(drop=True)
    print('afterwards, sequences #: ', len(protein_df))

    # Sequence-length distribution: screened vs. full record.
    fig = plt.figure()
    # NOTE(review): sns.distplot is deprecated in seaborn >= 0.11; this code
    # requires an older seaborn (or the migration to histplot/displot).
    sns.distplot(
        protein_df['seq_len'],
        bins=50,kde=False,
        rug=False,norm_hist=False)
    sns.distplot(
        df_raw['seq_len'],
        bins=50,kde=False,
        rug=False,norm_hist=False)

    plt.legend(['Selected','Full recrod'])
    plt.xlabel('AA length')
    outname = store_path+'CSV_1_AALen_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    # Max-force distribution: screened vs. full record.
    fig = plt.figure()
    sns.distplot(protein_df['Max_Smo_Force'],kde=True, rug=False,norm_hist=False)
    sns.distplot(df_raw['Max_Smo_Force'],kde=True, rug=False,norm_hist=False)

    plt.legend(['Selected','Full recrod'])
    plt.xlabel('Max. Force (pN) from MD')
    outname = store_path+'CSV_2_MaxSmoF_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print('Check selected in SMD records...')

    # Re-plot the traces that survived screening.
    fig = plt.figure(figsize=(24,16),dpi=200)
    fig, ax0 = plt.subplots()
    for ii in range(len( protein_df )):
        if protein_df['seq_len'][ii]<=6400:
            ax0.plot(
                protein_df['sample_NormPullGap_data'][ii],
                protein_df['sample_FORCE_data'][ii],
                alpha=0.1,
            )
            ax0.scatter(
                protein_df['NPullGap_for_MaxSmoF'][ii],
                protein_df['Max_Smo_Force'][ii],
            )
        else:
            print(protein_df['pdb_id'][ii])

    plt.xlabel('Normalized distance btw pulling ends')
    plt.ylabel('Force (pF)')
    outname = store_path+'CSV_3_Screened_SMD_sim_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    # Integrated-energy distribution: screened vs. full record.
    # NOTE(review): file name reuses 'MaxSmoF' although this plots Int_Ene.
    fig = plt.figure()
    sns.distplot(protein_df['Int_Smo_ForcPull'],kde=True, rug=False,norm_hist=False)
    sns.distplot(df_raw['Int_Ene'],kde=True, rug=False,norm_hist=False)

    plt.legend(['Selected','Full recrod'])
    plt.xlabel('Integrated energy (pN*Angstrom) from smoothed MD hist.')
    outname = store_path+'CSV_4_MaxSmoF_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print('Done')

    return df_raw, protein_df
|
|
|
|
|
def screen_dataset_MD_old(
    file_path,
    PKeys=None,
    CKeys=None,
):
    """Legacy variant of :func:`screen_dataset_MD` (CSV input only).

    Differences from the current version: takes a single `file_path`, does
    not rename 'sample_FORCEpN_data', and keeps that original column name
    throughout.

    Parameters
    ----------
    file_path : str
        Path to the raw CSV file.
    PKeys : dict
        Expects 'arr_key', 'data_dir', 'min_AA_seq_len', 'max_AA_seq_len',
        'max_Force_cap'.
    CKeys : dict
        Expects 'SlientRun' (1 => save figures instead of showing them).

    Returns
    -------
    (df_raw, protein_df) : the unfiltered DataFrame and the screened one.
    """
    arr_key = PKeys['arr_key']
    store_path = PKeys['data_dir']
    # 'SlientRun' is the project's (misspelled) key name; 1 = headless mode.
    IF_SaveFig = CKeys['SlientRun']
    min_AASeq_len = PKeys['min_AA_seq_len']
    max_AASeq_len = PKeys['max_AA_seq_len']
    max_used_Smo_F = PKeys['max_Force_cap']

    print('=============================================')
    print('1. read in the csv file...')
    print('=============================================')
    df_raw = pd.read_csv(file_path)
    print(df_raw.keys())

    # Convert space-separated numeric strings into float ndarrays.
    for this_key in arr_key:
        df_raw[this_key] = df_raw[this_key].apply(lambda x: np.array(list(map(float, x.split(" ")))))

    # Overview plot of all raw force-extension traces.
    # NOTE(review): the first figure is immediately replaced by plt.subplots().
    fig = plt.figure(figsize=(24,16),dpi=200)
    fig, ax0 = plt.subplots()
    for ii in range(len( df_raw )):
        if df_raw['seq_len'][ii]<=6400:
            ax0.plot(
                df_raw['sample_NormPullGap_data'][ii],
                df_raw['sample_FORCEpN_data'][ii],
                alpha=0.1,
            )
            ax0.scatter(
                df_raw['NPullGap_for_MaxSmoF'][ii],
                df_raw['Max_Smo_Force'][ii],
            )
        else:
            # Overly long sequences are only reported, not plotted.
            print(df_raw['pdb_id'][ii])

    plt.xlabel('Normalized distance btw pulling ends')
    plt.ylabel('Force (pF)')
    outname = store_path+'CSV_0_SMD_sim_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print('=============================================')
    print('2. screen the entries...')
    print('=============================================')

    # Percentage of missing data per column, for a quick sanity check.
    df_isnull = pd.DataFrame(
        round(
            (df_raw.isnull().sum().sort_values(ascending=False)/df_raw.shape[0])*100,
            1
        )
    ).reset_index()
    df_isnull.style.format({'% of Missing Data': lambda x:'{:.1%}'.format(abs(x))})
    cm = sns.light_palette("skyblue", as_cmap=True)
    df_isnull = df_isnull.style.background_gradient(cmap=cm)
    print('Check null...')
    print( df_isnull )

    print('Working on a dataframe with useful keywords')
    # Keep only the columns the downstream pipeline consumes.
    protein_df = pd.DataFrame().assign(
        pdb_id=df_raw['pdb_id'],
        AA=df_raw['AA'],
        seq_len=df_raw['seq_len'],
        Max_Smo_Force=df_raw['Max_Smo_Force'],
        NPullGap_for_MaxSmoF=df_raw['NPullGap_for_MaxSmoF'],
        sample_FORCEpN_data=df_raw['sample_FORCEpN_data'],
        sample_NormPullGap_data=df_raw['sample_NormPullGap_data'],
        ini_gap=df_raw['ini_gap'],
        Fmax_pN_BSDB=df_raw['Fmax_pN_BSDB'],
        pull_data=df_raw['pull_data'],
        forc_data=df_raw['forc_data'],
        Int_Smo_ForcPull=df_raw['Int_Ene'],
    )

    print('a. screen using sequence length...')
    print('original sequences #: ', len(protein_df))

    # max_AASeq_len-2 reserves room for special tokens added later
    # (presumably BOS/EOS — confirm against the tokenizer usage).
    protein_df.drop(
        protein_df[protein_df['seq_len']>max_AASeq_len-2].index,
        inplace = True
    )
    protein_df.drop(
        protein_df[protein_df['seq_len'] <min_AASeq_len].index,
        inplace = True
    )
    protein_df=protein_df.reset_index(drop=True)
    print('used sequences #: ', len(protein_df))

    print('b. screen using force values...')
    print('original sequences #: ', len(protein_df))

    # Cap the maximum smoothed force to drop outliers.
    protein_df.drop(
        protein_df[protein_df['Max_Smo_Force']>max_used_Smo_F].index,
        inplace = True
    )
    protein_df=protein_df.reset_index(drop=True)
    print('afterwards, sequences #: ', len(protein_df))

    # Sequence-length distribution: screened vs. full record.
    fig = plt.figure()
    # NOTE(review): sns.distplot is deprecated in seaborn >= 0.11.
    sns.distplot(
        protein_df['seq_len'],
        bins=50,kde=False,
        rug=False,norm_hist=False)
    sns.distplot(
        df_raw['seq_len'],
        bins=50,kde=False,
        rug=False,norm_hist=False)

    plt.legend(['Selected','Full recrod'])
    plt.xlabel('AA length')
    outname = store_path+'CSV_1_AALen_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    # Max-force distribution: screened vs. full record.
    fig = plt.figure()
    sns.distplot(protein_df['Max_Smo_Force'],kde=True, rug=False,norm_hist=False)
    sns.distplot(df_raw['Max_Smo_Force'],kde=True, rug=False,norm_hist=False)

    plt.legend(['Selected','Full recrod'])
    plt.xlabel('Max. Force (pN) from MD')
    outname = store_path+'CSV_2_MaxSmoF_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print('Check selected in SMD records...')

    # Re-plot the traces that survived screening.
    fig = plt.figure(figsize=(24,16),dpi=200)
    fig, ax0 = plt.subplots()
    for ii in range(len( protein_df )):
        if protein_df['seq_len'][ii]<=6400:
            ax0.plot(
                protein_df['sample_NormPullGap_data'][ii],
                protein_df['sample_FORCEpN_data'][ii],
                alpha=0.1,
            )
            ax0.scatter(
                protein_df['NPullGap_for_MaxSmoF'][ii],
                protein_df['Max_Smo_Force'][ii],
            )
        else:
            print(protein_df['pdb_id'][ii])

    plt.xlabel('Normalized distance btw pulling ends')
    plt.ylabel('Force (pF)')
    outname = store_path+'CSV_3_Screened_SMD_sim_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    # Integrated-energy distribution: screened vs. full record.
    # NOTE(review): file name reuses 'MaxSmoF' although this plots Int_Ene.
    fig = plt.figure()
    sns.distplot(protein_df['Int_Smo_ForcPull'],kde=True, rug=False,norm_hist=False)
    sns.distplot(df_raw['Int_Ene'],kde=True, rug=False,norm_hist=False)

    plt.legend(['Selected','Full recrod'])
    plt.xlabel('Integrated energy (pN*Angstrom) from smoothed MD hist.')
    outname = store_path+'CSV_4_MaxSmoF_Dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print('Done')

    return df_raw, protein_df
|
|
|
|
|
|
|
|
|
|
|
def load_data_set_from_df_SMD(
    protein_df,
    PKeys=None,
    CKeys=None,
):
    """Build force -> AA-sequence DataLoaders from a screened protein DataFrame.

    X is either the max smoothed force (1 value) or the padded force path
    (PKeys['X_Key'] selects which), divided by PKeys['Xnormfac']. Y is the
    amino-acid sequence tokenized with a char-level Keras tokenizer, padded
    to 'max_AA_seq_len', and divided by PKeys['ynormfac'] in the datasets.

    Parameters
    ----------
    protein_df : pd.DataFrame
        Must contain 'Max_Smo_Force', 'sample_FORCEpN_data', and 'AA'.
    PKeys : dict
        Expects 'data_dir', 'X_Key', 'max_AA_seq_len', 'Xnormfac',
        'ynormfac', 'tokenizer_X', 'tokenizer_y', 'batch_size',
        'testset_ratio', 'maxdata'.
    CKeys : dict
        Expects 'SlientRun' (1 => save figures instead of showing them).

    Returns
    -------
    (train_loader, train_loader_noshuffle, test_loader,
     tokenizer_y, tokenizer_X)
    """
    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']
    X_Key = PKeys['X_Key']
    max_AA_len = PKeys['max_AA_seq_len']
    norm_fac_force = PKeys['Xnormfac']
    AA_seq_normfac = PKeys['ynormfac']
    # tokenizer_X is passed through unchanged; only tokenizer_y may be fitted here.
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']
    maxdata=PKeys['maxdata']

    print("======================================================")
    print("1. work on X data")
    print("======================================================")
    X1 = []
    X2 = []
    for i in range(len(protein_df)):
        # X1: scalar max force per record (kept as a length-1 list per row).
        X1.append( [ protein_df['Max_Smo_Force'][i] ] )
        # X2: full force path, zero-padded to max_AA_len.
        X2.append(
            pad_a_np_arr(protein_df['sample_FORCEpN_data'][i],0,max_AA_len)
        )
    X1=np.array(X1)
    X2=np.array(X2)
    print('Max_F shape: ', X1.shape)
    print('SMD_F_Path shape', X2.shape)

    # Select which representation becomes the model input.
    if X_Key=='Max_Smo_Force':
        X = X1.copy()
    else:
        X = X2.copy()
        print('Use '+X_Key)

    print('Normalized factor for the force: ', norm_fac_force)
    X = X/norm_fac_force

    # Distribution of (unnormalized) Fmax values.
    # NOTE(review): sns.distplot is deprecated in seaborn >= 0.11.
    fig = plt.figure()
    sns.distplot(
        X1[:,0],bins=50,kde=False,
        rug=False,norm_hist=False,
        axlabel='Normalized Fmax')

    plt.ylabel('Counts')
    outname = store_path+'CSV_4_AfterNorm_SMD_Fmax.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print("======================================================")
    print("2. work on Y data")
    print("======================================================")

    seqs = protein_df.AA.values

    # Fit a char-level tokenizer only if none was supplied.
    if tokenizer_y==None:
        tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
        tokenizer_y.fit_on_texts(seqs)

    y_data = tokenizer_y.texts_to_sequences(seqs)

    # Right-pad / right-truncate token sequences to a fixed length.
    y_data= sequence.pad_sequences(
        y_data, maxlen=max_AA_len,
        padding='post', truncating='post')

    # Token-code histogram (20 AAs + padding/specials => bins up to 22).
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
        x='AA code', bins=np.array([i-0.5 for i in range(0,20+3,1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 21)

    outname=store_path+'CSV_5_DataSet_AACode_dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print ("#################################")
    print ("DICTIONARY y_data")
    dictt=tokenizer_y.get_config()
    print (dictt)
    num_words = len(tokenizer_y.word_index) + 1
    print ("################## y max token: ",num_words )

    # Round-trip check: decode tokens back to sequences and compare lengths.
    print ("TEST REVERSE: ")
    y_data_reversed=tokenizer_y.sequences_to_texts (y_data)

    for iii in range (len(y_data_reversed)):
        y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")

    print ("Element 0", y_data_reversed[0])
    print ("Number of y samples",len (y_data_reversed) )

    for iii in [0,2,6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print ("Len 0 as example: ", len (y_data_reversed[0]) )
    print ("CHeck ori: ", len (seqs[0]) )
    print ("Len 2 as example: ", len (y_data_reversed[2]) )
    print ("CHeck ori: ", len (seqs[2]) )

    # Optionally truncate the dataset to at most `maxdata` samples.
    if maxdata<y_data.shape[0]:
        print ('select subset...', maxdata )
        X=X[:maxdata]
        y_data=y_data[:maxdata]
        print ("new shapes (X, y_data): ", X.shape, y_data.shape)

    # Fixed random_state so the split is reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # Targets are scaled by ynormfac inside the datasets.
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).float()/AA_seq_normfac
    )

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test).float()/AA_seq_normfac
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    # Unshuffled train loader, e.g. for deterministic evaluation passes.
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, \
        test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
def load_data_set_from_df_SMD_pLM(
    protein_df,
    PKeys=None,
    CKeys=None,
):
    """Build force -> ESM-token DataLoaders from a screened protein DataFrame.

    Like :func:`load_data_set_from_df_SMD`, but Y sequences are tokenized
    with a pretrained ESM-2 alphabet (selected by PKeys['ESM-2_Model'])
    instead of a Keras tokenizer, and are NOT normalized by 'ynormfac'.

    Parameters
    ----------
    protein_df : pd.DataFrame
        Must contain 'Max_Smo_Force', 'sample_FORCE_data', and 'AA'.
    PKeys : dict
        Expects 'data_dir', 'X_Key', 'max_AA_seq_len', 'Xnormfac',
        'ynormfac', 'tokenizer_X', 'tokenizer_y', 'batch_size',
        'testset_ratio', 'maxdata', 'ESM-2_Model'.
    CKeys : dict
        Expects 'SlientRun' (1 => save figures instead of showing them).

    Returns
    -------
    (train_loader, train_loader_noshuffle, test_loader,
     tokenizer_y, tokenizer_X) — tokenizers are passed through unchanged.
    """
    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']
    X_Key = PKeys['X_Key']
    max_AA_len = PKeys['max_AA_seq_len']
    norm_fac_force = PKeys['Xnormfac']
    AA_seq_normfac = PKeys['ynormfac']
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']
    maxdata=PKeys['maxdata']

    print("======================================================")
    print("1. work on X data: Normalized ForcePath")
    print("======================================================")
    X1 = []
    X2 = []

    for i in range(len(protein_df)):
        # X1: scalar max force per record (length-1 list per row).
        X1.append( [ protein_df['Max_Smo_Force'][i] ] )
        # X2: full force path, zero-padded to max_AA_len.
        X2.append(
            pad_a_np_arr_esm(protein_df['sample_FORCE_data'][i],0,max_AA_len)
        )
    X1=np.array(X1)
    X2=np.array(X2)
    print('Max_F shape: ', X1.shape)
    print('SMD_F_Path shape', X2.shape)

    # Select which representation becomes the model input.
    if X_Key=='Max_Smo_Force':
        X = X1.copy()
    else:
        X = X2.copy()
        print('Use '+X_Key)

    print('Normalized factor for the force: ', norm_fac_force)
    X = X/norm_fac_force

    print("tokenizer_X=None")

    # Distribution of raw Fmax values.
    # NOTE(review): sns.distplot is deprecated in seaborn >= 0.11.
    fig = plt.figure()
    sns.distplot(
        X1[:,0],bins=50,kde=False,
        rug=False,norm_hist=False,
        axlabel='Fmax')

    plt.ylabel('Counts')
    outname = store_path+'CSV_4_AfterNorm_SMD_Fmax.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print("======================================================")
    print("2. work on Y data: AA Sequence")
    print("======================================================")

    seqs = protein_df.AA.values

    print("pLM model: ", PKeys['ESM-2_Model'])

    # Download/load the requested pretrained ESM-2 model and its alphabet.
    # NOTE(review): if the key matches none of these, esm_alphabet/len_toks
    # are never assigned and the code below raises NameError.
    if PKeys['ESM-2_Model']=='esm2_t33_650M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model']=='esm2_t12_35M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model']=='esm2_t36_3B_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
        len_toks=len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model']=='esm2_t30_150M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
    else:
        print("protein language model is not defined.")

    print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
    print("# of tokens in AA alphabet: ", len_toks)

    # -2 leaves room for the BOS/EOS tokens the converter adds.
    esm_batch_converter = esm_alphabet.get_batch_converter(
        truncation_seq_length=PKeys['max_AA_seq_len']-2
    )
    esm_model.eval()

    # Batch converter expects (label, sequence) pairs; labels unused here.
    seqs_ext=[]
    for i in range(len(seqs)):
        seqs_ext.append(
            (" ", seqs[i])
        )

    _, y_strs, y_data = esm_batch_converter(seqs_ext)
    # Per-sequence token counts (non-padding positions).
    y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)

    print ("y_data.dim: ", y_data.dtype)

    # Token-code histogram (ESM alphabet has 33 tokens).
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
        x='AA code',
        bins=np.array([i-0.5 for i in range(0,33+3,1)]),
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)

    outname=store_path+'CSV_5_DataSet_AACode_dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print ("#################################")
    print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
    print ("################## y max token: ",len_toks )

    # Round-trip check via the project's ESM decoding helper
    # (decode_many_ems_token_rec is defined elsewhere in the project).
    print ("TEST REVERSE: ")
    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)

    print ("Element 0", y_data_reversed[0])
    print ("Number of y samples",len (y_data_reversed) )

    for iii in [0,2,6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print ("Len 0 as example: ", len (y_data_reversed[0]) )
    print ("CHeck ori: ", len (seqs[0]) )
    print ("Len 2 as example: ", len (y_data_reversed[2]) )
    print ("CHeck ori: ", len (seqs[2]) )

    # Optionally truncate the dataset to at most `maxdata` samples.
    if maxdata<y_data.shape[0]:
        print ('select subset...', maxdata )
        X=X[:maxdata]
        y_data=y_data[:maxdata]
        print ("new shapes (X, y_data): ", X.shape, y_data.shape)

    # Fixed random_state so the split is reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # y_* are already torch tensors from the batch converter; keep as-is
    # (ESM token ids, no ynormfac scaling in this path).
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float(),
        y_train,
    )

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float(),
        y_test,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    # Unshuffled train loader, e.g. for deterministic evaluation passes.
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, \
        test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
def load_data_set_from_df_from_pLM_to_SMD(
    protein_df,
    PKeys=None,
    CKeys=None,
):
    """Build AA-sequence -> force DataLoaders (inverse direction of the
    other loaders: the tokenized sequence is the INPUT, the force the target).

    Sequence tokenization is either a char-level Keras tokenizer
    (PKeys['ESM-2_Model'] == 'trivial') or a pretrained ESM-2 alphabet.
    Note the swapped normalization keys: force uses 'ynormfac', sequence
    uses 'Xnormfac'.

    Parameters
    ----------
    protein_df : pd.DataFrame
        Must contain 'Max_Smo_Force', 'sample_FORCE_data', and 'AA'.
    PKeys : dict
        Expects 'data_dir', 'X_Key', 'max_AA_seq_len', 'Xnormfac',
        'ynormfac', 'tokenizer_X', 'tokenizer_y', 'batch_size',
        'testset_ratio', 'maxdata', 'ESM-2_Model'.
    CKeys : dict
        Expects 'SlientRun' (1 => save figures instead of showing them).

    Returns
    -------
    (train_loader, train_loader_noshuffle, test_loader,
     tokenizer_X, tokenizer_y) — note the tokenizer order differs from
    the SMD->sequence loaders.
    """
    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']
    X_Key = PKeys['X_Key']
    max_AA_len = PKeys['max_AA_seq_len']
    # Direction is inverted here: the force (now the TARGET) is normalized
    # by 'ynormfac' and the sequence by 'Xnormfac'.
    norm_fac_force = PKeys['ynormfac']
    AA_seq_normfac = PKeys['Xnormfac']
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']
    maxdata=PKeys['maxdata']

    print("======================================================")
    print("1. work on Y data: Normalized ForcePath")
    print("======================================================")
    X1 = []
    X2 = []

    for i in range(len(protein_df)):
        # X1: scalar max force per record (length-1 list per row).
        X1.append( [ protein_df['Max_Smo_Force'][i] ] )
        # X2: full force path, zero-padded to max_AA_len.
        X2.append(
            pad_a_np_arr_esm(protein_df['sample_FORCE_data'][i],0,max_AA_len)
        )
    X1=np.array(X1)
    X2=np.array(X2)
    print('Max_F shape: ', X1.shape)
    print('SMD_F_Path shape', X2.shape)

    # Select which force representation is used (named X locally, but it
    # ends up as the TARGET of the returned datasets).
    if X_Key=='Max_Smo_Force':
        X = X1.copy()
    else:
        X = X2.copy()
        print('Use '+X_Key)

    print('Normalized factor for the force: ', norm_fac_force)
    X = X/norm_fac_force

    print("tokenizer_X=None")

    # Distribution of raw Fmax values.
    # NOTE(review): sns.distplot is deprecated in seaborn >= 0.11.
    fig = plt.figure()
    sns.distplot(
        X1[:,0],bins=50,kde=False,
        rug=False,norm_hist=False,
        axlabel='Fmax')

    plt.ylabel('Counts')
    outname = store_path+'CSV_4_AfterNorm_SMD_Fmax.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print("======================================================")
    print("2. work on Y data: AA Sequence")
    print("======================================================")

    seqs = protein_df.AA.values
    if PKeys['ESM-2_Model']=='trivial':
        print("Plain tokenizer of AA sequence is used...")

        # Fit a char-level tokenizer only if none was supplied.
        if tokenizer_y==None:
            tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
            tokenizer_y.fit_on_texts(seqs)

        y_data = tokenizer_y.texts_to_sequences(seqs)

        # Two-step padding: right-pad to max_AA_len-1 then left-pad by one,
        # yielding a leading 0 slot (presumably mimicking a BOS position —
        # confirm against the model's input convention).
        y_data= sequence.pad_sequences(
            y_data, maxlen=max_AA_len-1,
            padding='post', truncating='post',
            value=0.0,
        )

        y_data= sequence.pad_sequences(
            y_data, maxlen=max_AA_len,
            padding='pre', truncating='pre',
            value=0.0,
        )

        len_toks = len(tokenizer_y.word_index) + 1

    else:

        print("pLM model: ", PKeys['ESM-2_Model'])

        # Download/load the requested pretrained ESM-2 model and alphabet.
        # NOTE(review): an unrecognized key leaves esm_alphabet/len_toks
        # unassigned and the code below raises NameError.
        if PKeys['ESM-2_Model']=='esm2_t33_650M_UR50D':
            esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
            len_toks=len(esm_alphabet.all_toks)
        elif PKeys['ESM-2_Model']=='esm2_t12_35M_UR50D':
            esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
            len_toks=len(esm_alphabet.all_toks)
        elif PKeys['ESM-2_Model']=='esm2_t36_3B_UR50D':
            esm_model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
            len_toks=len(esm_alphabet.all_toks)
        elif PKeys['ESM-2_Model']=='esm2_t30_150M_UR50D':
            esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
            len_toks=len(esm_alphabet.all_toks)
        else:
            print("protein language model is not defined.")

        print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
        print("# of tokens in AA alphabet: ", len_toks)

        # -2 leaves room for the BOS/EOS tokens the converter adds.
        esm_batch_converter = esm_alphabet.get_batch_converter(
            truncation_seq_length=PKeys['max_AA_seq_len']-2
        )
        esm_model.eval()

        # Batch converter expects (label, sequence) pairs; labels unused.
        seqs_ext=[]
        for i in range(len(seqs)):
            seqs_ext.append(
                (" ", seqs[i])
            )

        _, y_strs, y_data = esm_batch_converter(seqs_ext)
        # Per-sequence token counts (non-padding positions).
        y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)

        # The converter pads only to the longest sequence in the batch;
        # extend with padding_idx so every batch matches max_AA_seq_len.
        current_seq_len = y_data.shape[1]
        print("current seq batch len: ", current_seq_len)
        missing_num_pad = PKeys['max_AA_seq_len']-current_seq_len
        if missing_num_pad>0:
            print("extra padding is added to match the target seq input length...")
            y_data = F.pad(
                y_data,
                (0, missing_num_pad),
                "constant", esm_alphabet.padding_idx
            )
        else:
            print("No extra padding is needed")

    print ("y_data.dim: ", y_data.shape)
    print ("y_data.type: ", y_data.type)

    # Token-code histogram (ESM alphabet has 33 tokens).
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
        x='AA code',
        bins=np.array([i-0.5 for i in range(0,33+3,1)]),
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)

    outname=store_path+'CSV_5_DataSet_AACode_dist.jpg'
    if IF_SaveFig==1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print ("#################################")
    print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
    print ("################## y max token: ",len_toks )

    # Round-trip check: decode tokens back into sequences per tokenizer path.
    print ("TEST REVERSE: ")

    if PKeys['ESM-2_Model']=='trivial':
        y_data_reversed=tokenizer_y.sequences_to_texts (y_data)
        for iii in range (len(y_data_reversed)):
            y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
    else:
        # Project helper defined elsewhere in the codebase.
        y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)

    print ("Element 0", y_data_reversed[0])
    print ("Number of y samples",len (y_data_reversed) )

    for iii in [0,2,6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print ("Len 0 as example: ", len (y_data_reversed[0]) )
    print ("CHeck ori: ", len (seqs[0]) )
    print ("Len 2 as example: ", len (y_data_reversed[2]) )
    print ("CHeck ori: ", len (seqs[2]) )

    # Optionally truncate the dataset to at most `maxdata` samples.
    if maxdata<y_data.shape[0]:
        print ('select subset...', maxdata )
        X=X[:maxdata]
        y_data=y_data[:maxdata]
        print ("new shapes (X, y_data): ", X.shape, y_data.shape)

    # Fixed random_state so the split is reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # Datasets put the SEQUENCE first (input) and the force second (target).
    if PKeys['ESM-2_Model']=='trivial':
        # Keras-tokenized path: y_* are numpy; scale by the sequence normfac.
        train_dataset = RegressionDataset(
            torch.from_numpy(y_train).float()/AA_seq_normfac,
            torch.from_numpy(X_train).float(),
        )

        test_dataset = RegressionDataset(
            torch.from_numpy(y_test).float()/AA_seq_normfac,
            torch.from_numpy(X_test).float(),
        )
    else:
        # ESM path: y_* are already torch tensors of token ids; no scaling.
        train_dataset = RegressionDataset(
            y_train,
            torch.from_numpy(X_train).float(),
        )

        test_dataset = RegressionDataset(
            y_test,
            torch.from_numpy(X_test).float(),
        )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    # Unshuffled train loader, e.g. for deterministic evaluation passes.
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, \
        test_loader, tokenizer_X, tokenizer_y
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_data_set_SS_InSeqToOuSeq(
    file_path,
    PKeys=None,
    CKeys=None,
    ):
    """Build seq2seq DataLoaders: secondary-structure string (X) -> AA sequence (y).

    Reads a CSV of protein records, screens them by sequence length,
    character-tokenizes both the secondary-structure annotation (input X)
    and the amino-acid sequence (target y) with Keras tokenizers, pads both
    to ``max_AA_seq_len``, normalizes by ``Xnormfac``/``ynormfac``, and wraps
    the train/test split in torch ``DataLoader`` objects. Diagnostic
    distribution plots are written to ``PKeys['data_dir']`` in silent mode,
    or shown interactively otherwise.

    Parameters
    ----------
    file_path : str
        Path to the CSV. Must provide 'Seq', 'Sequence' and 'Secstructure'
        columns (NOTE(review): lengths are measured on 'Seq' while sequences
        are read from 'Sequence' -- confirm both columns agree).
    PKeys : dict
        Problem keys: 'min_AA_seq_len', 'max_AA_seq_len', 'batch_size',
        'testset_ratio', 'Xnormfac', 'ynormfac', 'tokenizer_X',
        'tokenizer_y' (either may be None to fit a fresh one), 'maxdata',
        'data_dir'.
    CKeys : dict
        Control keys; ``CKeys['SlientRun'] == 1`` saves figures instead of
        showing them.

    Returns
    -------
    tuple
        (train_loader, train_loader_noshuffle, test_loader,
        tokenizer_y, tokenizer_X)
    """
    # -- unpack configuration -------------------------------------------------
    min_AASeq_len = PKeys['min_AA_seq_len']
    max_AASeq_len = PKeys['max_AA_seq_len']
    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']
    Xnormfac = PKeys['Xnormfac']
    ynormfac = PKeys['ynormfac']
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    maxdata = PKeys['maxdata']

    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']

    def _finish_fig(outname):
        # Save the current figure in silent mode; otherwise show it. Always close.
        if IF_SaveFig == 1:
            plt.savefig(outname, dpi=200)
        else:
            plt.show()
        plt.close()

    # -- load and summarize the raw CSV ---------------------------------------
    protein_df = pd.read_csv(file_path)
    protein_df.describe()

    # Missing-data report (styled table; only visible in notebook runs).
    df_isnull = pd.DataFrame(
        round(
            (protein_df.isnull().sum().sort_values(ascending=False)/protein_df.shape[0])*100,
            1
        )
    ).reset_index()
    df_isnull.columns = ['Columns', '% of Missing Data']
    df_isnull.style.format({'% of Missing Data': lambda x: '{:.1%}'.format(abs(x))})
    cm = sns.light_palette("skyblue", as_cmap=True)
    df_isnull = df_isnull.style.background_gradient(cmap=cm)
    df_isnull

    # -- screen records by sequence length ------------------------------------
    protein_df['Seq_Len'] = protein_df.apply(lambda x: len(x['Seq']), axis=1)

    # Reserve 2 positions within max_AASeq_len (presumably begin/end tokens --
    # TODO confirm against the downstream model).
    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_AASeq_len-2].index, inplace=True)
    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_AASeq_len].index, inplace=True)
    protein_df = protein_df.reset_index(drop=True)

    seqs = protein_df.Sequence.values
    lengths = [len(s) for s in seqs]

    print('After the screening using seq_len: ')
    print(protein_df.shape)
    print(protein_df.head(6))

    # Distribution of sequence lengths after screening.
    fig_handle = sns.distplot(
        lengths, bins=25,
        kde=False, rug=False,
        norm_hist=True,
        axlabel='Length')
    _finish_fig(store_path+'0_DataSet_AA_Len_dist.jpg')

    # -- tokenize X: secondary-structure strings ------------------------------
    X_v = protein_df.Secstructure.values
    if tokenizer_X is None:
        tokenizer_X = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
        tokenizer_X.fit_on_texts(X_v)

    X = tokenizer_X.texts_to_sequences(X_v)
    X = sequence.pad_sequences(
        X, maxlen=max_AASeq_len,
        padding='post', truncating='post')

    fig_handle = sns.histplot(
        data=pd.DataFrame({'SecStr code': np.array(X).flatten()}),
        x='SecStr code', bins=np.array([i-0.5 for i in range(0, 8+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 9)
    _finish_fig(store_path+'1_DataSet_SecStrCode_dist.jpg')

    print("#################################")
    print("DICTIONARY X")
    dictt = tokenizer_X.get_config()
    print(dictt)
    num_wordsX = len(tokenizer_X.word_index) + 1
    print("################## X max token: ", num_wordsX)

    X = np.array(X)
    print("sample X data", X[0])

    # -- tokenize y: amino-acid sequences -------------------------------------
    seqs = protein_df.Sequence.values

    if tokenizer_y is None:
        tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
        tokenizer_y.fit_on_texts(seqs)

    y_data = tokenizer_y.texts_to_sequences(seqs)
    y_data = sequence.pad_sequences(
        y_data, maxlen=max_AASeq_len,
        padding='post', truncating='post')

    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
        x='AA code', bins=np.array([i-0.5 for i in range(0, 20+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 21)
    _finish_fig(store_path+'2_DataSet_AACode_dist.jpg')

    print("#################################")
    print("DICTIONARY y_data")
    dictt = tokenizer_y.get_config()
    print(dictt)
    num_words = len(tokenizer_y.word_index) + 1
    print("################## y max token: ", num_words)

    print("y data shape: ", y_data.shape)
    print("X data shape: ", X.shape)

    # -- sanity check: decode y back to text and compare with the originals ---
    print("TEST REVERSE: ")

    y_data_reversed = tokenizer_y.sequences_to_texts(y_data)
    for iii in range(len(y_data_reversed)):
        # sequences_to_texts emits lowercase space-separated chars; undo that.
        y_data_reversed[iii] = y_data_reversed[iii].upper().strip().replace(" ", "")

    print("Element 0", y_data_reversed[0])
    print("Number of y samples", len(y_data_reversed))

    for iii in [0, 2, 6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print("Len 0 as example: ", len(y_data_reversed[0]))
    print("CHeck ori: ", len(seqs[0]))
    print("Len 2 as example: ", len(y_data_reversed[2]))
    print("CHeck ori: ", len(seqs[2]))

    # -- optionally cap the dataset size --------------------------------------
    if maxdata < y_data.shape[0]:
        print('select subset...', maxdata)
        X = X[:maxdata]
        y_data = y_data[:maxdata]
        print("new shapes: ", X.shape, y_data.shape)

    # Fixed random_state keeps the split reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # Inputs scaled by Xnormfac, targets by ynormfac.
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float()/Xnormfac,
        torch.from_numpy(y_train).float()/ynormfac)

    # Training-target token distribution (plotted un-normalized).
    fig_handle = sns.distplot(
        torch.from_numpy(y_train.flatten()),
        bins=42, kde=False,
        rug=False, norm_hist=False, axlabel='y labels')
    fig = fig_handle.get_figure()
    _finish_fig(store_path+'3_DataSet_Norm_YTrain_dist.jpg')

    # Training-input token distribution (plotted un-normalized).
    fig_handle = sns.distplot(
        torch.from_numpy(X_train.flatten()),
        bins=25, kde=False,
        rug=False, norm_hist=False, axlabel='X labels')
    fig = fig_handle.get_figure()
    _finish_fig(store_path+'3_DataSet_Norm_XTrain_dist.jpg')

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float()/Xnormfac,
        torch.from_numpy(y_test).float()/ynormfac)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
from PD_pLMProbXDiff.UtilityPack import decode_one_ems_token_rec,decode_many_ems_token_rec |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_data_set_SS_InSeqToOuSeq_pLM(
    file_path,
    PKeys=None,
    CKeys=None,
    ):
    """Build seq2seq DataLoaders with ESM-2 tokenized targets.

    Variant of ``load_data_set_SS_InSeqToOuSeq``: inputs X are still
    character-tokenized secondary-structure strings (Keras tokenizer), but
    targets y are amino-acid sequences tokenized by a pretrained ESM-2
    alphabet/batch-converter (torch tensor of token ids, kept un-normalized).

    Parameters
    ----------
    file_path : str
        Path to the CSV with 'Seq', 'Sequence' and 'Secstructure' columns
        (NOTE(review): lengths are measured on 'Seq' while sequences are
        read from 'Sequence' -- confirm both columns agree).
    PKeys : dict
        Problem keys; in addition to the usual length/batch/split keys it
        must provide 'ESM-2_Model' naming a supported ESM-2 checkpoint.
        'ynormfac' is unpacked for interface symmetry but y is returned as
        raw ESM token ids.
    CKeys : dict
        Control keys; ``CKeys['SlientRun'] == 1`` saves figures instead of
        showing them.

    Returns
    -------
    tuple
        (train_loader, train_loader_noshuffle, test_loader,
        tokenizer_y, tokenizer_X) -- tokenizer_y is whatever was passed in
        (the ESM alphabet replaces it for the y side).

    Raises
    ------
    ValueError
        If ``PKeys['ESM-2_Model']`` is not a supported model name.
    """
    # -- unpack configuration -------------------------------------------------
    min_AASeq_len = PKeys['min_AA_seq_len']
    max_AASeq_len = PKeys['max_AA_seq_len']
    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']
    Xnormfac = PKeys['Xnormfac']
    ynormfac = PKeys['ynormfac']   # unused here: y stays as ESM token ids
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    maxdata = PKeys['maxdata']

    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']

    def _finish_fig(outname):
        # Save the current figure in silent mode; otherwise show it. Always close.
        if IF_SaveFig == 1:
            plt.savefig(outname, dpi=200)
        else:
            plt.show()
        plt.close()

    # -- load and summarize the raw CSV ---------------------------------------
    protein_df = pd.read_csv(file_path)
    protein_df.describe()

    # Missing-data report (styled table; only visible in notebook runs).
    df_isnull = pd.DataFrame(
        round(
            (protein_df.isnull().sum().sort_values(ascending=False)/protein_df.shape[0])*100,
            1
        )
    ).reset_index()
    df_isnull.columns = ['Columns', '% of Missing Data']
    df_isnull.style.format({'% of Missing Data': lambda x: '{:.1%}'.format(abs(x))})
    cm = sns.light_palette("skyblue", as_cmap=True)
    df_isnull = df_isnull.style.background_gradient(cmap=cm)
    df_isnull

    # -- screen records by sequence length ------------------------------------
    protein_df['Seq_Len'] = protein_df.apply(lambda x: len(x['Seq']), axis=1)

    # Reserve 2 positions within max_AASeq_len (ESM adds BOS/EOS tokens --
    # see the truncation_seq_length below).
    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_AASeq_len-2].index, inplace=True)
    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_AASeq_len].index, inplace=True)
    protein_df = protein_df.reset_index(drop=True)

    seqs = protein_df.Sequence.values
    lengths = [len(s) for s in seqs]

    print('After the screening using seq_len: ')
    print(protein_df.shape)
    print(protein_df.head(6))

    # Distribution of sequence lengths after screening.
    fig_handle = sns.distplot(
        lengths, bins=25,
        kde=False, rug=False,
        norm_hist=True,
        axlabel='Length')
    _finish_fig(store_path+'0_DataSet_AA_Len_dist.jpg')

    # -- tokenize X: secondary-structure strings ------------------------------
    X_v = protein_df.Secstructure.values
    if tokenizer_X is None:
        tokenizer_X = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
        tokenizer_X.fit_on_texts(X_v)

    X = tokenizer_X.texts_to_sequences(X_v)

    # Prepend a 0 to every X sequence before padding -- presumably to keep X
    # aligned with the ESM begin-of-sequence token on the y side
    # (NOTE(review): confirm this alignment against the model).
    X1 = []
    for this_X in X:
        X1.append([0] + this_X)
    print(X1[0])

    X = sequence.pad_sequences(
        X1, maxlen=max_AASeq_len,
        padding='post', truncating='post')

    fig_handle = sns.histplot(
        data=pd.DataFrame({'SecStr code': np.array(X).flatten()}),
        x='SecStr code', bins=np.array([i-0.5 for i in range(0, 8+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 9)
    _finish_fig(store_path+'1_DataSet_SecStrCode_dist.jpg')

    print("#################################")
    print("DICTIONARY X")
    dictt = tokenizer_X.get_config()
    print(dictt)
    num_wordsX = len(tokenizer_X.word_index) + 1
    print("################## X max token: ", num_wordsX)

    X = np.array(X)
    print("sample X data", X[0])

    # -- tokenize y with a pretrained ESM-2 alphabet --------------------------
    seqs = protein_df.Sequence.values

    if PKeys['ESM-2_Model'] == 'esm2_t33_650M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model'] == 'esm2_t12_35M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    else:
        # Fail fast: the original silently fell through and later crashed
        # with a NameError on esm_alphabet.
        raise ValueError(
            "Unsupported ESM-2 model: {}".format(PKeys['ESM-2_Model']))

    print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
    print("# of tokens in AA alphabet: ", len_toks)

    esm_batch_converter = esm_alphabet.get_batch_converter(
        truncation_seq_length=PKeys['max_AA_seq_len']-2
    )
    esm_model.eval()

    # The batch converter expects (label, sequence) pairs.
    seqs_ext = []
    for i in range(len(seqs)):
        seqs_ext.append(
            (" ", seqs[i])
        )

    _, _, y_data = esm_batch_converter(seqs_ext)

    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA token (esm)': np.array(y_data).flatten()}),
        x='AA token (esm)', bins=np.array([i-0.5 for i in range(0, 33+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)
    _finish_fig(store_path+'2_DataSet_AACode_dist.jpg')

    print("#################################")
    print("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
    print("################## y max token: ", len_toks)

    print("y data shape: ", y_data.shape)
    print("X data shape: ", X.shape)

    # -- sanity check: decode y back to text and compare with the originals ---
    print("TEST REVERSE: ")

    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)

    print("Element 0", y_data_reversed[0])
    print("Number of y samples", len(y_data_reversed))

    for iii in [0, 2, 6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print("Len 0 as example: ", len(y_data_reversed[0]))
    print("CHeck ori: ", len(seqs[0]))
    print("Len 2 as example: ", len(y_data_reversed[2]))
    print("CHeck ori: ", len(seqs[2]))

    # -- optionally cap the dataset size --------------------------------------
    if maxdata < y_data.shape[0]:
        print('select subset...', maxdata)
        X = X[:maxdata]
        y_data = y_data[:maxdata]
        print("new shapes: ", X.shape, y_data.shape)

    # Fixed random_state keeps the split reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # X is normalized; y stays as raw ESM token ids (already a torch tensor).
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float()/Xnormfac,
        y_train)

    # Training-target token distribution.
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA token (esm)': np.array(y_train).flatten()}),
        x='AA token (esm)', bins=np.array([i-0.5 for i in range(0, 33+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)
    _finish_fig(store_path+'3_DataSet_Norm_YTrain_dist.jpg')

    # Training-input token distribution.
    fig_handle = sns.histplot(
        data=pd.DataFrame({'SecStr code': np.array(X_train).flatten()}),
        x='SecStr code', bins=np.array([i-0.5 for i in range(0, 8+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 9)
    _finish_fig(store_path+'3_DataSet_Norm_XTrain_dist.jpg')

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float()/Xnormfac,
        y_test)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
|
|
def load_data_set_seq2seq_SecStr_ModelA (
    file_path,
    PKeys=None,
    CKeys=None,
    ):
    """Build DataLoaders mapping secondary-structure fractions (X) -> AA sequence (y).

    Reads a CSV of protein records, screens by sequence length, assembles
    an 8-component input vector per record from the secondary-structure
    fraction columns (AH, BS, T, UNSTRUCTURED, BETABRIDGE, 310HELIX,
    PIHELIX, BEND), character-tokenizes the amino-acid sequences as targets,
    and returns torch DataLoaders over the normalized train/test split.
    Distribution plots are always saved to ``PKeys['data_dir']`` (and also
    shown when not in silent mode).

    Parameters
    ----------
    file_path : str
        Path to the CSV; must carry 'Seq', 'Sequence' and the eight
        secondary-structure fraction columns listed above.
    PKeys : dict
        Problem keys: lengths, batch size, split ratio, 'ynormfac',
        tokenizers, 'maxdata', 'data_dir'. 'Xnormfac' is unpacked for
        interface symmetry but X is used un-normalized here.
    CKeys : dict
        Control keys; ``CKeys['SlientRun'] == 1`` suppresses interactive
        display of figures.

    Returns
    -------
    tuple
        (train_loader, train_loader_noshuffle, test_loader,
        tokenizer_y, tokenizer_X)
    """
    # -- unpack configuration -------------------------------------------------
    min_length = PKeys['min_AA_seq_len']
    max_length = PKeys['max_AA_seq_len']
    batch_size_ = PKeys['batch_size']

    maxdata = PKeys['maxdata']
    tokenizer_y = PKeys['tokenizer_y']
    split = PKeys['testset_ratio']
    ynormfac = PKeys['ynormfac']

    Xnormfac = PKeys['Xnormfac']       # unused: X enters the dataset un-normalized
    tokenizer_X = PKeys['tokenizer_X']

    IF_SaveFig = CKeys['SlientRun']
    save_dir = PKeys['data_dir']

    def _save_and_maybe_show(outname):
        # These diagnostic figures are always written to disk; show them
        # too when not running silently. Always close.
        plt.savefig(outname, dpi=200)
        if IF_SaveFig != 1:
            plt.show()
        plt.close()

    # -- load and summarize the raw CSV ---------------------------------------
    protein_df = pd.read_csv(file_path)
    protein_df

    protein_df.describe()

    # Missing-data report (styled table; only visible in notebook runs).
    df_isnull = pd.DataFrame(round((protein_df.isnull().sum().sort_values(ascending=False)/protein_df.shape[0])*100,1)).reset_index()
    df_isnull.columns = ['Columns', '% of Missing Data']
    df_isnull.style.format({'% of Missing Data': lambda x: '{:.1%}'.format(abs(x))})
    cm = sns.light_palette("skyblue", as_cmap=True)
    df_isnull = df_isnull.style.background_gradient(cmap=cm)
    df_isnull

    # -- screen records by sequence length ------------------------------------
    # NOTE(review): length comes from 'Seq' while sequences are read from
    # 'Sequence' below -- confirm the CSV carries both consistently.
    protein_df['Seq_Len'] = protein_df.apply(lambda x: len(x['Seq']), axis=1)

    print("Screen the sequence length...")
    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_length-2].index, inplace=True)
    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_length].index, inplace=True)

    print("# of data point: ", protein_df.shape)
    print(protein_df.head(6))
    protein_df = protein_df.reset_index(drop=True)

    seqs = protein_df.Sequence.values
    lengths = [len(s) for s in seqs]

    print(protein_df.shape)
    print(protein_df.head(6))

    min_length_measured = min(lengths)
    max_length_measured = max(lengths)

    print("Measured seq length: min and max")
    print(min_length_measured, max_length_measured)

    fig_handle = sns.distplot(
        lengths,
        bins=50,
        kde=False,
        rug=False,
        norm_hist=False,
        axlabel='Length'
    )
    fig = fig_handle.get_figure()
    _save_and_maybe_show(save_dir+'CSV_0_measured_AA_len_Dist.jpg')

    print("keys in df: ", protein_df.keys())

    # -- build X: one 8-vector of secondary-structure fractions per record ----
    SS_KEYS = ['AH', 'BS', 'T', 'UNSTRUCTURED', 'BETABRIDGE',
               '310HELIX', 'PIHELIX', 'BEND']
    X = []
    for i in range(len(seqs)):
        X.append([protein_df[key][i] for key in SS_KEYS])

    X = np.array(X)
    print("sample X data", X[0])

    # One distribution plot per secondary-structure component
    # (CSV_1_AH ... CSV_8_BEND, same names/order as before).
    for col, key in enumerate(SS_KEYS):
        fig_handle = sns.distplot(
            X[:, col], bins=50, kde=False,
            rug=False, norm_hist=False, axlabel=key)
        fig = fig_handle.get_figure()
        plt.ylabel("Counts")
        _save_and_maybe_show(save_dir+'CSV_%d_%s_Dist.jpg' % (col+1, key))

    # -- tokenize y: amino-acid sequences -------------------------------------
    seqs = protein_df.Sequence.values

    if tokenizer_y is None:
        tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
        tokenizer_y.fit_on_texts(seqs)

    y_data = tokenizer_y.texts_to_sequences(seqs)
    y_data = sequence.pad_sequences(y_data, maxlen=max_length, padding='post', truncating='post')

    print("#################################")
    print("DICTIONARY y_data")
    dictt = tokenizer_y.get_config()
    print(dictt)
    num_words = len(tokenizer_y.word_index) + 1
    print("################## max token: ", num_words)

    # -- sanity check: decode y back to text and compare with the originals ---
    print("TEST REVERSE: ")
    print("y data shape: ", y_data.shape)
    y_data_reversed = tokenizer_y.sequences_to_texts(y_data)

    for iii in range(len(y_data_reversed)):
        # sequences_to_texts emits lowercase space-separated chars; undo that.
        y_data_reversed[iii] = y_data_reversed[iii].upper().strip().replace(" ", "")

    print("Element 0", y_data_reversed[0])
    print("Number of y samples", len(y_data_reversed))
    print("Original: ", y_data[:3, :])
    print("Original Seq : ", seqs[:3])
    print("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
    print("Len 0 as example: ", len(y_data_reversed[0]))
    print("Len 2 as example: ", len(y_data_reversed[2]))

    # -- optionally cap the dataset size --------------------------------------
    if maxdata < y_data.shape[0]:
        print('select subset...', maxdata)
        X = X[:maxdata]
        y_data = y_data[:maxdata]
        print("new shapes, X and y:\n ", X.shape, y_data.shape)

    # Fixed random_state keeps the split reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=split,
        random_state=235
    )

    # X kept as-is (fractions); targets normalized by ynormfac.
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).float()/ynormfac
    )

    # Training-target token distribution (un-normalized).
    fig_handle = sns.distplot(
        torch.from_numpy(y_train.flatten()),
        bins=42, kde=False,
        rug=False, norm_hist=False, axlabel='y: AA labels')
    fig = fig_handle.get_figure()
    outname = save_dir+'CSV_9_YTrain_dist.jpg'
    if IF_SaveFig == 1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test).float()/ynormfac
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size_,
        shuffle=True
    )
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size_,
        shuffle=False
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size_
    )

    return train_loader, train_loader_noshuffle, test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
|
|
def load_data_set_seq2seq_SecStr_ModelA_pLM (
    file_path,
    PKeys=None,
    CKeys=None,
    ):
    """Build DataLoaders: secondary-structure fractions (X) -> ESM-2 token targets (y).

    Variant of ``load_data_set_seq2seq_SecStr_ModelA``: the 8-component
    secondary-structure fraction vector is unchanged as input, but targets
    are amino-acid sequences tokenized by a pretrained ESM-2
    alphabet/batch-converter and kept as raw token ids (no ynormfac
    scaling).

    Parameters
    ----------
    file_path : str
        Path to the CSV; must carry 'Seq', 'Sequence' and the eight
        secondary-structure fraction columns (AH, BS, T, UNSTRUCTURED,
        BETABRIDGE, 310HELIX, PIHELIX, BEND).
    PKeys : dict
        Problem keys; must include 'ESM-2_Model' naming a supported
        ESM-2 checkpoint. 'Xnormfac'/'ynormfac' are unpacked for interface
        symmetry but unused here.
    CKeys : dict
        Control keys; ``CKeys['SlientRun'] == 1`` suppresses interactive
        display of figures.

    Returns
    -------
    tuple
        (train_loader, train_loader_noshuffle, test_loader,
        tokenizer_y, tokenizer_X)

    Raises
    ------
    ValueError
        If ``PKeys['ESM-2_Model']`` is not a supported model name.
    """
    # -- unpack configuration -------------------------------------------------
    min_length = PKeys['min_AA_seq_len']
    max_length = PKeys['max_AA_seq_len']
    batch_size_ = PKeys['batch_size']

    maxdata = PKeys['maxdata']
    tokenizer_y = PKeys['tokenizer_y']
    split = PKeys['testset_ratio']
    ynormfac = PKeys['ynormfac']       # unused here: y stays as ESM token ids

    Xnormfac = PKeys['Xnormfac']       # unused: X enters the dataset un-normalized
    tokenizer_X = PKeys['tokenizer_X']

    IF_SaveFig = CKeys['SlientRun']
    save_dir = PKeys['data_dir']

    def _save_and_maybe_show(outname):
        # These diagnostic figures are always written to disk; show them
        # too when not running silently. Always close.
        plt.savefig(outname, dpi=200)
        if IF_SaveFig != 1:
            plt.show()
        plt.close()

    def _save_or_show(outname):
        # Save in silent mode; otherwise show interactively. Always close.
        if IF_SaveFig == 1:
            plt.savefig(outname, dpi=200)
        else:
            plt.show()
        plt.close()

    # -- load and summarize the raw CSV ---------------------------------------
    protein_df = pd.read_csv(file_path)
    protein_df

    protein_df.describe()

    # Missing-data report (styled table; only visible in notebook runs).
    df_isnull = pd.DataFrame(round((protein_df.isnull().sum().sort_values(ascending=False)/protein_df.shape[0])*100,1)).reset_index()
    df_isnull.columns = ['Columns', '% of Missing Data']
    df_isnull.style.format({'% of Missing Data': lambda x: '{:.1%}'.format(abs(x))})
    cm = sns.light_palette("skyblue", as_cmap=True)
    df_isnull = df_isnull.style.background_gradient(cmap=cm)
    df_isnull

    # -- screen records by sequence length ------------------------------------
    # Reserve 2 positions within max_length for the ESM BOS/EOS tokens.
    protein_df['Seq_Len'] = protein_df.apply(lambda x: len(x['Seq']), axis=1)

    print("Screen the sequence length...")
    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_length-2].index, inplace=True)
    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_length].index, inplace=True)

    print("# of data point: ", protein_df.shape)
    print(protein_df.head(6))
    protein_df = protein_df.reset_index(drop=True)

    seqs = protein_df.Sequence.values
    lengths = [len(s) for s in seqs]

    print(protein_df.shape)
    print(protein_df.head(6))

    min_length_measured = min(lengths)
    max_length_measured = max(lengths)

    print("Measured seq length: min and max")
    print(min_length_measured, max_length_measured)

    fig_handle = sns.distplot(
        lengths,
        bins=50,
        kde=False,
        rug=False,
        norm_hist=False,
        axlabel='Length'
    )
    fig = fig_handle.get_figure()
    _save_and_maybe_show(save_dir+'CSV_0_measured_AA_len_Dist.jpg')

    print("keys in df: ", protein_df.keys())

    # -- build X: one 8-vector of secondary-structure fractions per record ----
    SS_KEYS = ['AH', 'BS', 'T', 'UNSTRUCTURED', 'BETABRIDGE',
               '310HELIX', 'PIHELIX', 'BEND']
    X = []
    for i in range(len(seqs)):
        X.append([protein_df[key][i] for key in SS_KEYS])

    X = np.array(X)
    print("X.shape: ", X.shape, "[batch, text_len=8]")
    print("sample X data", X[0])

    # One distribution plot per secondary-structure component
    # (CSV_1_AH ... CSV_8_BEND, same names/order as before).
    for col, key in enumerate(SS_KEYS):
        fig_handle = sns.distplot(
            X[:, col], bins=50, kde=False,
            rug=False, norm_hist=False, axlabel=key)
        fig = fig_handle.get_figure()
        plt.ylabel("Counts")
        _save_and_maybe_show(save_dir+'CSV_%d_%s_Dist.jpg' % (col+1, key))

    # -- tokenize y with a pretrained ESM-2 alphabet --------------------------
    seqs = protein_df.Sequence.values

    if PKeys['ESM-2_Model'] == 'esm2_t33_650M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model'] == 'esm2_t12_35M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    else:
        # Fail fast: the original silently fell through and later crashed
        # with a NameError on esm_alphabet.
        raise ValueError(
            "Unsupported ESM-2 model: {}".format(PKeys['ESM-2_Model']))

    print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
    print("# of tokens in AA alphabet: ", len_toks)

    esm_batch_converter = esm_alphabet.get_batch_converter(
        truncation_seq_length=PKeys['max_AA_seq_len']-2
    )
    esm_model.eval()

    # The batch converter expects (label, sequence) pairs.
    seqs_ext = []
    for i in range(len(seqs)):
        seqs_ext.append(
            (" ", seqs[i])
        )

    _, _, y_data = esm_batch_converter(seqs_ext)

    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA token (esm)': np.array(y_data).flatten()}),
        x='AA token (esm)', bins=np.array([i-0.5 for i in range(0, 33+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)
    _save_or_show(save_dir+'CSV_9_AACode_dist.jpg')

    print("#################################")
    print("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
    print("################## y max token: ", len_toks)

    print("y data shape: ", y_data.shape)
    print("X data shape: ", X.shape)

    # -- sanity check: decode y back to text and compare with the originals ---
    print("TEST REVERSE: ")

    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)

    print("Element 0", y_data_reversed[0])
    print("Number of y samples", len(y_data_reversed))

    for iii in [0, 2, 6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print("Len 0 as example: ", len(y_data_reversed[0]))
    print("CHeck ori: ", len(seqs[0]))
    print("Len 2 as example: ", len(y_data_reversed[2]))
    print("CHeck ori: ", len(seqs[2]))

    # -- optionally cap the dataset size --------------------------------------
    if maxdata < y_data.shape[0]:
        print('select subset...', maxdata)
        X = X[:maxdata]
        y_data = y_data[:maxdata]
        print("new shapes, X and y:\n ", X.shape, y_data.shape)

    # Fixed random_state keeps the split reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=split,
        random_state=235
    )

    # X kept as-is (fractions); y stays as raw ESM token ids (torch tensor).
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float(),
        y_train,
    )

    # Training-target token distribution.
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA token (esm)': np.array(y_train).flatten()}),
        x='AA token (esm)', bins=np.array([i-0.5 for i in range(0, 33+3, 1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)
    _save_or_show(save_dir+'CSV_10_TrainSet_AACode_dist.jpg')

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float(),
        y_test,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size_,
        shuffle=True
    )
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size_,
        shuffle=False
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size_
    )

    return train_loader, train_loader_noshuffle, test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
def load_data_set_text2seq_MD_ModelA (
    protein_df,
    PKeys=None,
    CKeys=None,
):
    """Build DataLoaders for a (force, energy) -> AA-sequence task (Model A).

    X: two scalar columns of ``protein_df`` named by ``PKeys['X_Key']``,
    divided element-wise by ``PKeys['Xnormfac']``.
    y: AA sequences (``protein_df.AA``) char-tokenized with a Keras
    ``Tokenizer``, post-padded/truncated to ``PKeys['max_AA_seq_len']``,
    then scaled by ``1/PKeys['ynormfac']`` so they can be regressed.

    Side effects: prints diagnostics and either shows or saves two
    distribution plots under ``PKeys['data_dir']`` depending on
    ``CKeys['SlientRun']``.

    Returns:
        (train_loader, train_loader_noshuffle, test_loader,
         tokenizer_y, tokenizer_X)
    """
    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']        # 1 -> save figures; else show them

    X_Key = PKeys['X_Key']                 # two column names: [force, energy]

    max_AA_len = PKeys['max_AA_seq_len']
    norm_fac_force = PKeys['Xnormfac'][0]
    norm_fac_energ = PKeys['Xnormfac'][1]

    AA_seq_normfac = PKeys['ynormfac']     # divisor applied to y token ids
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    # PEP 8: compare to None with `is`, not `==`.
    assert tokenizer_X is None, "tokenizer_X should be None"
    assert tokenizer_y is None, "Check tokenizer_y"

    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']

    maxdata = PKeys['maxdata']

    print("======================================================")
    print("1. work on X data")
    print("======================================================")

    X = []
    for i in range(len(protein_df)):
        X.append(
            [protein_df[X_Key[0]][i],
             protein_df[X_Key[1]][i]]
        )
    X = np.array(X)
    print('Input tokenized X.dim: ', X.shape)
    print('Use ', X_Key)

    print("Normalize the X input: ...")
    print('Normalized factor for the force: ', norm_fac_force)
    print('Normalized factor for the energy: ', norm_fac_energ)

    X[:, 0] = X[:, 0] / norm_fac_force
    X[:, 1] = X[:, 1] / norm_fac_energ

    # Distribution of the normalized inputs.
    # NOTE: sns.distplot is deprecated (removed in seaborn >= 0.14);
    # histplot with count statistics is the supported equivalent of
    # distplot(kde=False, rug=False, norm_hist=False).
    fig = plt.figure()
    sns.histplot(X[:, 0], bins=50)
    sns.histplot(X[:, 1], bins=50)

    plt.legend(['Normalized Fmax','Normalized Unfolding Ene.'])
    plt.ylabel('Counts')
    outname = store_path+'CSV_5_AfterNorm_SMD_FmaxEne.jpg'
    if IF_SaveFig == 1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print("======================================================")
    print("2. work on Y data")
    print("======================================================")

    seqs = protein_df.AA.values

    if tokenizer_y is None:
        # char_level tokenizer over AA letters; filters strip punctuation.
        tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
        tokenizer_y.fit_on_texts(seqs)

    y_data = tokenizer_y.texts_to_sequences(seqs)

    y_data = sequence.pad_sequences(
        y_data, maxlen=max_AA_len,
        padding='post', truncating='post')

    # Token-id histogram (20 AAs + special codes).
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
        x='AA code', bins=np.array([i-0.5 for i in range(0,20+3,1)])
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 21)

    outname = store_path+'CSV_5_DataSet_AACode_dist.jpg'
    if IF_SaveFig == 1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print ("#################################")
    print ("DICTIONARY y_data")
    dictt = tokenizer_y.get_config()
    print (dictt)
    num_words = len(tokenizer_y.word_index) + 1
    print ("################## y max token: ", num_words)

    # Round-trip sanity check: decode ids back to sequences and compare.
    print ("TEST REVERSE: ")
    y_data_reversed = tokenizer_y.sequences_to_texts(y_data)

    for iii in range(len(y_data_reversed)):
        y_data_reversed[iii] = y_data_reversed[iii].upper().strip().replace(" ", "")

    print ("Element 0", y_data_reversed[0])
    print ("Number of y samples", len(y_data_reversed))

    for iii in [0, 2, 6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print ("Len 0 as example: ", len(y_data_reversed[0]))
    print ("CHeck ori: ", len(seqs[0]))
    print ("Len 2 as example: ", len(y_data_reversed[2]))
    print ("CHeck ori: ", len(seqs[2]))

    # Optionally keep only the first `maxdata` records.
    if maxdata < y_data.shape[0]:
        print ('select subset...', maxdata)
        X = X[:maxdata]
        y_data = y_data[:maxdata]
        print ("new shapes (X, y_data): ", X.shape, y_data.shape)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # Token ids are divided by ynormfac so the model can treat them as
    # continuous regression targets.
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).float()/AA_seq_normfac
    )

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test).float()/AA_seq_normfac
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, \
           test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
def load_data_set_text2seq_MD_ModelA_pLM (
    protein_df,
    PKeys=None,
    CKeys=None,
):
    """Build DataLoaders for a (force, energy) -> AA-sequence task where the
    sequences are tokenized with an ESM-2 protein-language-model alphabet.

    X: two scalar columns of ``protein_df`` named by ``PKeys['X_Key']``,
    divided element-wise by ``PKeys['Xnormfac']``.
    y: AA sequences (``protein_df.AA``) encoded by the ESM-2 batch converter
    selected via ``PKeys['ESM-2_Model']``, truncated to
    ``PKeys['max_AA_seq_len']`` (minus BOS/EOS); token ids are kept as-is
    (no ynormfac scaling, unlike the non-pLM variant).

    Side effects: prints diagnostics, loads the pretrained ESM-2 model, and
    shows or saves two plots under ``PKeys['data_dir']`` depending on
    ``CKeys['SlientRun']``.

    Raises:
        ValueError: if ``PKeys['ESM-2_Model']`` names an unsupported model.

    Returns:
        (train_loader, train_loader_noshuffle, test_loader,
         tokenizer_y, tokenizer_X)  # both tokenizers remain None here
    """
    store_path = PKeys['data_dir']
    IF_SaveFig = CKeys['SlientRun']        # 1 -> save figures; else show them

    X_Key = PKeys['X_Key']                 # two column names: [force, energy]

    max_AA_len = PKeys['max_AA_seq_len']
    norm_fac_force = PKeys['Xnormfac'][0]
    norm_fac_energ = PKeys['Xnormfac'][1]

    # NOTE(review): read but not applied below — esm token ids stay raw.
    AA_seq_normfac = PKeys['ynormfac']
    tokenizer_X = PKeys['tokenizer_X']
    tokenizer_y = PKeys['tokenizer_y']
    # PEP 8: compare to None with `is`, not `==`.
    assert tokenizer_X is None, "tokenizer_X should be None"
    assert tokenizer_y is None, "Check tokenizer_y"

    batch_size = PKeys['batch_size']
    TestSet_ratio = PKeys['testset_ratio']

    maxdata = PKeys['maxdata']

    print("======================================================")
    print("1. work on X data")
    print("======================================================")

    X = []
    for i in range(len(protein_df)):
        X.append(
            [protein_df[X_Key[0]][i],
             protein_df[X_Key[1]][i]]
        )
    X = np.array(X)
    print('Input tokenized X.dim: ', X.shape)
    print('Use ', X_Key)

    print("Normalize the X input: ...")
    print('Normalized factor for the force: ', norm_fac_force)
    print('Normalized factor for the energy: ', norm_fac_energ)

    X[:, 0] = X[:, 0] / norm_fac_force
    X[:, 1] = X[:, 1] / norm_fac_energ

    # Distribution of the normalized inputs.
    # NOTE: sns.distplot is deprecated (removed in seaborn >= 0.14);
    # histplot with count statistics is the supported equivalent of
    # distplot(kde=False, rug=False, norm_hist=False).
    fig = plt.figure()
    sns.histplot(X[:, 0], bins=50)
    sns.histplot(X[:, 1], bins=50)

    plt.legend(['Normalized Fmax','Normalized Unfolding Ene.'])
    plt.ylabel('Counts')
    outname = store_path+'CSV_5_AfterNorm_SMD_FmaxEne.jpg'
    if IF_SaveFig == 1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print("======================================================")
    print("2. work on Y data: AA Sequence with pLM")
    print("======================================================")

    seqs = protein_df.AA.values

    # Select the pretrained ESM-2 model/alphabet.
    if PKeys['ESM-2_Model'] == 'esm2_t33_650M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    elif PKeys['ESM-2_Model'] == 'esm2_t12_35M_UR50D':
        esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
    else:
        # Fail fast: the original only printed here and later crashed with a
        # confusing NameError on len_toks / esm_alphabet.
        raise ValueError("protein language model is not defined.")
    len_toks = len(esm_alphabet.all_toks)

    print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
    print("# of tokens in AA alphabet: ", len_toks)

    # -2 leaves room for the BOS/EOS tokens the converter prepends/appends.
    esm_batch_converter = esm_alphabet.get_batch_converter(
        truncation_seq_length=PKeys['max_AA_seq_len']-2
    )
    esm_model.eval()

    # Batch-converter expects (label, sequence) pairs.
    seqs_ext = []
    for i in range(len(seqs)):
        seqs_ext.append(
            (" ", seqs[i])
        )

    _, y_strs, y_data = esm_batch_converter(seqs_ext)
    y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)

    print ("y_data: ", y_data.dtype)

    # Token-id histogram over the 33-token esm alphabet.
    fig_handle = sns.histplot(
        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
        x='AA code',
        bins=np.array([i-0.5 for i in range(0,33+3,1)]),
    )
    fig = fig_handle.get_figure()
    fig_handle.set_xlim(-1, 33+1)

    outname = store_path+'CSV_5_DataSet_AACode_dist.jpg'
    if IF_SaveFig == 1:
        plt.savefig(outname, dpi=200)
    else:
        plt.show()
    plt.close()

    print ("#################################")
    print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
    print ("################## y max token: ", len_toks)

    # Round-trip sanity check: decode esm ids back to sequences and compare.
    print ("TEST REVERSE: ")
    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)

    print ("Element 0", y_data_reversed[0])
    print ("Number of y samples", len(y_data_reversed))

    for iii in [0, 2, 6]:
        print("Ori and REVERSED SEQ: ", iii)
        print(seqs[iii])
        print(y_data_reversed[iii])

    print ("Len 0 as example: ", len(y_data_reversed[0]))
    print ("CHeck ori: ", len(seqs[0]))
    print ("Len 2 as example: ", len(y_data_reversed[2]))
    print ("CHeck ori: ", len(seqs[2]))

    # Optionally keep only the first `maxdata` records.
    if maxdata < y_data.shape[0]:
        print ('select subset...', maxdata)
        X = X[:maxdata]
        y_data = y_data[:maxdata]
        print ("new shapes (X, y_data): ", X.shape, y_data.shape)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_data,
        test_size=TestSet_ratio,
        random_state=235)

    # esm token ids stay integer tensors (already torch tensors from the
    # batch converter); only X needs converting from numpy.
    train_dataset = RegressionDataset(
        torch.from_numpy(X_train).float(),
        y_train
    )

    test_dataset = RegressionDataset(
        torch.from_numpy(X_test).float(),
        y_test,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)
    train_loader_noshuffle = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size)

    return train_loader, train_loader_noshuffle, \
           test_loader, tokenizer_y, tokenizer_X
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|