from tensorflow.keras.preprocessing import text, sequence
from tensorflow.keras.preprocessing.text import Tokenizer
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import seaborn as sns
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from torchvision import datasets, transforms, models
from functools import partial, wraps
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import RobustScaler
from matplotlib.ticker import MaxNLocator
import esm

# ============================================================
# convert csv into df
# ============================================================

class RegressionDataset(Dataset):
    def __init__(self, X_data, y_data):
        self.X_data = X_data
        self.y_data = y_data

    def __getitem__(self, index):
        return self.X_data[index], self.y_data[index]

    def __len__(self):
        return len(self.X_data)

# padding a list at the end using a given value
def pad_a_np_arr(x0, add_x, n_len):
    n0 = len(x0)
    # print(n0)
    x1 = x0.copy()
    if n0 < n_len:
        for ii in range(n_len - n0):
            x1.append(add_x)
    else:
        print('No padding is needed')
    return x1

# padding a list for ESM-style tokens: put one pad value at the
# beginning, then pad at the end up to n_len
def pad_a_np_arr_esm(x0, add_x, n_len):
    n0 = len(x0)
    # print(n0)
    x1 = x0.copy()
    # x1 = [add_x]+x1 # somehow, this one doesn't work
    x1.insert(0, add_x)
    # print(x1)
    # print('x1 len: ',len(x1) )
    n0 = len(x1)
    if n0 < n_len:
        for ii in range(n_len - n0):
            x1.append(add_x)
    else:
        print('No padding is needed')
    return x1
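# ------------------------------------------------------------------
# Illustrative sanity check of the two padding helpers (added for
# documentation; not part of the original pipeline). Assumes plain
# Python lists and a pad value of 0: the "_esm" variant reserves
# slot 0, mirroring the <cls> position in ESM tokenization, before
# right-padding to n_len.
# ------------------------------------------------------------------
if __name__ == "__main__":
    assert pad_a_np_arr([3, 1, 2], 0, 6) == [3, 1, 2, 0, 0, 0]
    assert pad_a_np_arr_esm([3, 1, 2], 0, 6) == [0, 3, 1, 2, 0, 0]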
# ============================================================
# filter the loaded dataframe, then tokenize y with ESM
# ============================================================
# drop entries whose sequences are too long or too short for the
# target window, or whose peak smoothed force exceeds the cap
protein_df.drop(
    protein_df[protein_df['seq_len'] > max_AASeq_len - 2].index, inplace=True)
protein_df.drop(
    protein_df[protein_df['seq_len'] < min_AASeq_len].index, inplace=True)
protein_df.drop(
    protein_df[protein_df['Max_Smo_Force'] > max_used_Smo_F].index, inplace=True)

# need to save 2 positions for <cls> and <eos>
esm_batch_converter = esm_alphabet.get_batch_converter(
    truncation_seq_length=PKeys['max_AA_seq_len'] - 2
)
esm_model.eval()  # disables dropout for deterministic results

# prepare seqs for the "esm_batch_converter...":
# add dummy labels
seqs_ext = []
for i in range(len(seqs)):
    seqs_ext.append((" ", seqs[i]))
# batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
_, y_strs, y_data = esm_batch_converter(seqs_ext)
y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)
# print(batch_tokens.shape)
print("y_data.dtype: ", y_data.dtype)

# # --
# # tokenizer_y = None
# if tokenizer_y==None:
#     tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
#     tokenizer_y.fit_on_texts(seqs)
# #y_data = tokenizer_y.texts_to_sequences(y_data)
# y_data = tokenizer_y.texts_to_sequences(seqs)
# y_data = sequence.pad_sequences(
#     y_data, maxlen=max_AA_len,
#     padding='post', truncating='post')

# distribution of the AA token codes
fig_handle = sns.histplot(
    data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
    x='AA code',
    bins=np.array([i - 0.5 for i in range(0, 33 + 3, 1)]),
    # np.array([i-0.5 for i in range(0,20+3,1)])
    # binwidth=1,
)
fig = fig_handle.get_figure()
fig_handle.set_xlim(-1, 33 + 1)
# fig_handle.set_ylim(0, 100000)
outname = store_path + 'CSV_5_DataSet_AACode_dist.jpg'
if IF_SaveFig == 1:
    plt.savefig(outname, dpi=200)
else:
    plt.show()
plt.close()

# -----------------------------------------------------------
# print ("#################################")
# print ("DICTIONARY y_data")
# dictt=tokenizer_y.get_config()
# print (dictt)
# num_words = len(tokenizer_y.word_index) + 1
# print ("################## y max token: ",num_words )
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
print("#################################")
print("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
print("################## y max token: ", len_toks)

# reverse
print("TEST REVERSE: ")
# # --------------------------------------------------------------
# y_data_reversed=tokenizer_y.sequences_to_texts (y_data)
# for iii in range (len(y_data_reversed)):
#     y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# assume y_data is reversible
y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)
print("Element 0", y_data_reversed[0])
print("Number of y samples", len(y_data_reversed))
for iii in [0, 2, 6]:
    print("Ori and REVERSED SEQ: ", iii)
    print(seqs[iii])
    print(y_data_reversed[iii])
# print ("Original: ", y_data[:3,:])
# print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
print("Len 0 as example: ", len(y_data_reversed[0]))
print("Check ori: ", len(seqs[0]))
print("Len 2 as example: ", len(y_data_reversed[2]))
print("Check ori: ", len(seqs[2]))

# optionally keep only a subset of the data
if maxdata < y_data.shape[0]:
    print('select subset...')
    y_data = y_data[:maxdata]
    y_data_reversed = y_data_reversed[:maxdata]

# ============================================================
# tokenize y: trivial char-level tokenizer vs. a pLM (ESM-2)
# ============================================================
if PKeys['ESM-2_Model'] == 'trivial':
    if tokenizer_y == None:
        tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n')
        tokenizer_y.fit_on_texts(seqs)
    y_data = tokenizer_y.texts_to_sequences(seqs)
    # pad y: aaa -> 0aaa0 (add one 0 at the beginning)
    # # --
    # y_data= sequence.pad_sequences(
    #     y_data, maxlen=max_AA_len,
    #     padding='post', truncating='post')
    # ++
    y_data = sequence.pad_sequences(
        y_data, maxlen=max_AA_len - 1,
        padding='post', truncating='post', value=0.0,
    )
    # add one 0 at the beginning
    y_data = sequence.pad_sequences(
        y_data, maxlen=max_AA_len,
        padding='pre', truncating='pre', value=0.0,
    )
    len_toks = len(tokenizer_y.word_index) + 1
else:
    # ++ for pLM: esm
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    print("pLM model: ", PKeys['ESM-2_Model'])
    if PKeys['ESM-2_Model'] == 'esm2_t33_650M_UR50D':
        # print('Debug block')
        # embed dim: 1280
        esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model'] == 'esm2_t12_35M_UR50D':
        # embed dim: 480
        esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model'] == 'esm2_t36_3B_UR50D':
        # embed dim: 2560
        esm_model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    elif PKeys['ESM-2_Model'] == 'esm2_t30_150M_UR50D':
        # embed dim: 640
        esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        len_toks = len(esm_alphabet.all_toks)
    else:
        print("protein language model is not defined.")
        # pass
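    # --------------------------------------------------------------
    # Illustrative alternative (added; not in the original): the
    # if/elif ladder above can collapse into a lookup table, since the
    # fair-esm loaders are zero-argument callables that each return
    # (model, alphabet):
    #
    #   _esm2_loaders = {
    #       'esm2_t33_650M_UR50D': esm.pretrained.esm2_t33_650M_UR50D,  # dim 1280
    #       'esm2_t12_35M_UR50D':  esm.pretrained.esm2_t12_35M_UR50D,   # dim 480
    #       'esm2_t36_3B_UR50D':   esm.pretrained.esm2_t36_3B_UR50D,    # dim 2560
    #       'esm2_t30_150M_UR50D': esm.pretrained.esm2_t30_150M_UR50D,  # dim 640
    #   }
    #   esm_model, esm_alphabet = _esm2_loaders[PKeys['ESM-2_Model']]()
    #   len_toks = len(esm_alphabet.all_toks)
    # --------------------------------------------------------------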
    # for check
    print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
    print("# of tokens in AA alphabet: ", len_toks)

    # need to save 2 positions for <cls> and <eos>
    esm_batch_converter = esm_alphabet.get_batch_converter(
        truncation_seq_length=PKeys['max_AA_seq_len'] - 2
    )
    esm_model.eval()  # disables dropout for deterministic results

    # prepare seqs for the "esm_batch_converter...":
    # add dummy labels
    seqs_ext = []
    for i in range(len(seqs)):
        seqs_ext.append((" ", seqs[i]))
    # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
    _, y_strs, y_data = esm_batch_converter(seqs_ext)
    y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)

    # NEED to check the size of y_data: if the batch holds only shorter
    # sequences, extra padding with the pad token, int(1), is needed
    current_seq_len = y_data.shape[1]
    print("current seq batch len: ", current_seq_len)
    missing_num_pad = PKeys['max_AA_seq_len'] - current_seq_len
    if missing_num_pad > 0:
        print("extra padding is added to match the target seq input length...")
        # padding is needed
        y_data = F.pad(
            y_data,
            (0, missing_num_pad),
            "constant",
            esm_alphabet.padding_idx,
        )
    else:
        print("No extra padding is needed")
    # ----------------------------------------------------------------------------------
    # print(batch_tokens.shape)
    print("y_data.dim: ", y_data.shape)
    print("y_data.type: ", y_data.type())

# distribution of the AA token codes (same check as above)
fig_handle = sns.histplot(
    data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
    x='AA code',
    bins=np.array([i - 0.5 for i in range(0, 33 + 3, 1)]),
)
fig = fig_handle.get_figure()
fig_handle.set_xlim(-1, 33 + 1)
# fig_handle.set_ylim(0, 100000)
outname = store_path + 'CSV_5_DataSet_AACode_dist.jpg'
if IF_SaveFig == 1:
    plt.savefig(outname, dpi=200)
else:
    plt.show()
plt.close()

print("#################################")
print("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
print("################## y max token: ", len_toks)

# reverse
print("TEST REVERSE: ")
if PKeys['ESM-2_Model'] == 'trivial':
    # --------------------------------------------------------------
    y_data_reversed = tokenizer_y.sequences_to_texts(y_data)
    for iii in range(len(y_data_reversed)):
        y_data_reversed[iii] = y_data_reversed[iii].upper().strip().replace(" ", "")
else:
    # for ESM models
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # assume y_data is reversible
    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)

print("Element 0", y_data_reversed[0])
print("Number of y samples", len(y_data_reversed))
for iii in [0, 2, 6]:
    print("Ori and REVERSED SEQ: ", iii)
    print(seqs[iii])
    print(y_data_reversed[iii])
# print ("Original: ", y_data[:3,:])
# print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
print("Len 0 as example: ", len(y_data_reversed[0]))
print("Check ori: ", len(seqs[0]))
print("Len 2 as example: ", len(y_data_reversed[2]))
print("Check ori: ", len(seqs[2]))
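# ------------------------------------------------------------------
# Hedged sketch (added): decode_many_ems_token_rec is called above but
# is not defined in this file. A minimal decoder consistent with how it
# is used (token ids -> AA strings, dropping <cls>/<eos>/<pad>) could
# look like the following; the real helper may differ.
def decode_esm_tokens_sketch(tok_batch, alphabet):
    # special token ids that should not appear in the recovered sequence
    specials = {alphabet.cls_idx, alphabet.eos_idx, alphabet.padding_idx}
    decoded = []
    for row in tok_batch:
        decoded.append("".join(
            alphabet.get_tok(int(t)) for t in row if int(t) not in specials
        ))
    return decoded
# ------------------------------------------------------------------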
samples",len (y_data_reversed) ) for iii in [0,2,6]: print("Ori and REVERSED SEQ: ", iii) print(seqs[iii]) print(y_data_reversed[iii]) # print ("Original: ", y_data[:3,:]) # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3]) print ("Len 0 as example: ", len (y_data_reversed[0]) ) print ("CHeck ori: ", len (seqs[0]) ) print ("Len 2 as example: ", len (y_data_reversed[2]) ) print ("CHeck ori: ", len (seqs[2]) ) if maxdatamax_AASeq_len-2].index, inplace = True) protein_df.drop(protein_df[protein_df['Seq_Len'] max_AASeq_len-2].index, inplace = True) protein_df.drop(protein_df[protein_df['Seq_Len'] max_length-2].index, inplace = True) protein_df.drop(protein_df[protein_df['Seq_Len'] max_length-2].index, inplace = True) protein_df.drop(protein_df[protein_df['Seq_Len'] and esm_batch_converter = esm_alphabet.get_batch_converter( truncation_seq_length=PKeys['max_AA_seq_len']-2 ) esm_model.eval() # disables dropout for deterministic results # prepare seqs for the "esm_batch_converter..." # add dummy labels seqs_ext=[] for i in range(len(seqs)): seqs_ext.append( (" ", seqs[i]) ) # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext) _, y_strs, y_data = esm_batch_converter(seqs_ext) y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1) # print(batch_tokens.shape) print ("y_data: ", y_data.dtype) # # -- # # tokenizer_y = None # if tokenizer_y==None: # tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' ) # tokenizer_y.fit_on_texts(seqs) # #y_data = tokenizer_y.texts_to_sequences(y_data) # y_data = tokenizer_y.texts_to_sequences(seqs) # y_data= sequence.pad_sequences( # y_data, maxlen=max_AA_len, # padding='post', truncating='post') fig_handle = sns.histplot( data=pd.DataFrame({'AA code': np.array(y_data).flatten()}), x='AA code', bins=np.array([i-0.5 for i in range(0,33+3,1)]), # np.array([i-0.5 for i in range(0,20+3,1)]) # binwidth=1, ) fig = fig_handle.get_figure() fig_handle.set_xlim(-1, 33+1) # fig_handle.set_ylim(0, 100000) outname=store_path+'CSV_5_DataSet_AACode_dist.jpg' if IF_SaveFig==1: plt.savefig(outname, dpi=200) else: plt.show() plt.close() # # -------------------------------------------------------- # print ("#################################") # print ("DICTIONARY y_data") # dictt=tokenizer_y.get_config() # print (dictt) # num_words = len(tokenizer_y.word_index) + 1 # print ("################## y max token: ",num_words ) # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ print ("#################################") print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model']) print ("################## y max token: ",len_toks ) #revere print ("TEST REVERSE: ") # # ---------------------------------------------------------------- # y_data_reversed=tokenizer_y.sequences_to_texts (y_data) # for iii in range (len(y_data_reversed)): # y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # assume y_data is reversiable y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet) print ("Element 0", y_data_reversed[0]) print ("Number of y samples",len (y_data_reversed) ) for iii in [0,2,6]: print("Ori and REVERSED SEQ: ", iii) print(seqs[iii]) print(y_data_reversed[iii]) # print ("Original: ", y_data[:3,:]) # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3]) print ("Len 0 as example: ", len (y_data_reversed[0]) ) print ("CHeck ori: ", len (seqs[0]) ) print ("Len 2 as example: ", len (y_data_reversed[2]) ) print ("CHeck ori: ", len 
# optionally keep only a subset of the data
if maxdata < y_data.shape[0]:
    print('select subset...')
    y_data = y_data[:maxdata]
    y_data_reversed = y_data_reversed[:maxdata]

# ============================================================
# final length filter
# ============================================================
protein_df.drop(protein_df[protein_df['Seq_Len'] > max_length - 2].index, inplace=True)
# protein_df.drop(protein_df[protein_df['Seq_Len']
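# ------------------------------------------------------------------
# Illustrative usage sketch (added; not part of the original pipeline):
# once X/y tensors are built, the RegressionDataset defined at the top
# of this file plugs straight into a PyTorch DataLoader.
if __name__ == "__main__":
    _X = torch.randn(8, 4)   # toy features
    _y = torch.randn(8, 1)   # toy targets
    _loader = DataLoader(RegressionDataset(_X, _y), batch_size=4, shuffle=True)
    for _xb, _yb in _loader:
        print(_xb.shape, _yb.shape)  # torch.Size([4, 4]) torch.Size([4, 1])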