|
|
|
|
|
|
|
import os |
|
from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator |
|
import numpy as np |
|
import math |
|
import matplotlib.pyplot as plt |
|
|
|
from Bio.PDB import PDBParser |
|
from Bio.PDB.DSSP import DSSP |
|
from Bio.PDB import PDBList |
|
|
|
import torch |
|
from einops import rearrange |
|
import esm |
|
|
|
|
|
def create_path(this_path):
    """Create the directory `this_path` if it does not exist; return 1 if created, 2 if it already existed."""
|
if not os.path.exists(this_path): |
|
print('Creating the given path...') |
|
os.mkdir (this_path) |
|
path_stat = 1 |
|
print('Done.') |
|
else: |
|
print('The given path already exists!') |
|
path_stat = 2 |
|
return path_stat |
|
|
|
|
|
|
|
|
|
def params (model):
    """Print the total and trainable parameter counts of a PyTorch model."""
|
pytorch_total_params = sum(p.numel() for p in model.parameters()) |
|
pytorch_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
|
print ("Total parameters: ", pytorch_total_params," trainable parameters: ", pytorch_total_params_trainable) |
|
|
|
|
|
|
|
def prepare_UNet_keys(write_dict):
    """Build a dict with every expected UNet config key (default None), filled from `write_dict`;
    unknown keys are reported."""
|
|
|
Full_Keys=['dim', 'text_embed_dim', 'num_resnet_blocks', 'cond_dim', 'num_image_tokens', 'num_time_tokens', 'learned_sinu_pos_emb_dim', 'out_dim', 'dim_mults', 'cond_images_channels', 'channels', 'channels_out', 'attn_dim_head', 'attn_heads', 'ff_mult', 'lowres_cond', 'layer_attns', 'layer_attns_depth', 'layer_attns_add_text_cond', 'attend_at_middle', 'layer_cross_attns', 'use_linear_attn', 'use_linear_cross_attn', 'cond_on_text', 'max_text_len', 'init_dim', 'resnet_groups', 'init_conv_kernel_size', 'init_cross_embed', 'init_cross_embed_kernel_sizes', 'cross_embed_downsample', 'cross_embed_downsample_kernel_sizes', 'attn_pool_text', 'attn_pool_num_latents', 'dropout', 'memory_efficient', 'init_conv_to_final_conv_residual', 'use_global_context_attn', 'scale_skip_connection', 'final_resnet_block', 'final_conv_kernel_size', 'cosine_sim_attn', 'self_cond', 'combine_upsample_fmaps', 'pixel_shuffle_upsample', 'beginning_and_final_conv_present'] |
|
|
|
PKeys={} |
|
for key in Full_Keys: |
|
PKeys[key]=None |
|
|
|
for write_key in write_dict.keys(): |
|
if write_key in PKeys.keys(): |
|
PKeys[write_key]=write_dict[write_key] |
|
else: |
|
print("Wrong key found: ", write_key) |
|
|
|
return PKeys |
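
# Example usage (sketch; `partial_unet_cfg` is a hypothetical partial config dict):
#
#   partial_unet_cfg = {'dim': 128, 'dim_mults': (1, 2, 4), 'channels': 8}
#   unet_kwargs = prepare_UNet_keys(partial_unet_cfg)
#   # unet_kwargs now holds every expected UNet key, with unspecified ones set to None;
#   # any key outside the expected list would be reported as "Wrong key found".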
|
|
|
def prepare_ModelB_keys(write_dict):
    """Build a dict with every expected Model B config key (default None), filled from `write_dict`;
    unknown keys are reported."""
|
Full_Keys=['timesteps', 'dim', 'pred_dim', 'loss_type', 'elucidated', 'padding_idx', 'cond_dim', 'text_embed_dim', 'input_tokens', 'sequence_embed', 'embed_dim_position', 'max_text_len', 'cond_images_channels', 'max_length', 'device'] |
|
|
|
PKeys={} |
|
for key in Full_Keys: |
|
PKeys[key]=None |
|
|
|
for write_key in write_dict.keys(): |
|
if write_key in PKeys.keys(): |
|
PKeys[write_key]=write_dict[write_key] |
|
else: |
|
print("Wrong key found: ", write_key) |
|
|
|
return PKeys |
|
|
|
def modify_keys(old_dict,write_dict):
    """Return a copy of `old_dict` updated with values from `write_dict`;
    keys that are not already in `old_dict` are reported and ignored."""
|
new_dict = old_dict.copy() |
|
for w_key in write_dict.keys(): |
|
if w_key in old_dict.keys(): |
|
new_dict[w_key]=write_dict[w_key] |
|
else: |
|
print("Alien key found: ", w_key) |
|
return new_dict |
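
# Example usage (sketch; the values below are made up for illustration):
#
#   defaults = {'dim': 64, 'max_text_len': 512}
#   updated  = modify_keys(defaults, {'dim': 128, 'batch': 4})
#   # updated == {'dim': 128, 'max_text_len': 512}; 'batch' is reported as an alien key and ignored.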
|
|
|
|
|
|
|
|
|
def mixing_two_FORCE_for_AA_Len(NGap1,Force1,NGap2,Force2,LenAA,mix_fac):
    """Blend two force-extension records (each defined on normalized extension in [0, 1]) with
    weight `mix_fac` using monotone PCHIP interpolation, and resample the blend at LenAA+1 points."""
|
N = np.amax([len(NGap1), len(NGap2)]) |
|
N_Base = math.ceil(N*2) |
|
fun_PI_0 = PchipInterpolator(NGap1,Force1) |
|
fun_PI_1 = PchipInterpolator(NGap2,Force2) |
|
xx=np.linspace(0,1,N_Base) |
|
yy=fun_PI_0(xx)*mix_fac+fun_PI_1(xx)*(1-mix_fac) |
|
fun_PI = PchipInterpolator(xx,yy) |
|
|
|
x1=np.linspace(0,1,LenAA+1) |
|
y1=fun_PI(x1) |
|
return fun_PI, x1, y1 |
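
# Example usage (sketch; the two force records below are synthetic, not real SMD data):
#
#   NGap1, Force1 = np.linspace(0, 1, 20), np.linspace(0, 1, 20)**2
#   NGap2, Force2 = np.linspace(0, 1, 30), np.sqrt(np.linspace(0, 1, 30))
#   fun_PI, x1, y1 = mixing_two_FORCE_for_AA_Len(NGap1, Force1, NGap2, Force2, LenAA=64, mix_fac=0.5)
#   # y1 has 65 entries: the 50/50 blend of the two curves resampled for a 64-residue sequence.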
|
|
|
|
|
|
|
|
|
def get_Model_A_error (fname, cond, plotit=True, ploterror=False):
    """Compare the DSSP secondary-structure content of the PDB file `fname` with the target
    fractions `cond` (order: H, E, T, ~, B, G, I, S), optionally plot the comparison, and
    return the per-class absolute error."""
|
|
|
sec_structure,sec_structure_3state, sequence=get_DSSP_result (fname) |
|
sscount=[] |
|
length = len (sec_structure) |
|
sscount.append (sec_structure.count('H')/length) |
|
sscount.append (sec_structure.count('E')/length) |
|
sscount.append (sec_structure.count('T')/length) |
|
sscount.append (sec_structure.count('~')/length) |
|
sscount.append (sec_structure.count('B')/length) |
|
sscount.append (sec_structure.count('G')/length) |
|
sscount.append (sec_structure.count('I')/length) |
|
sscount.append (sec_structure.count('S')/length) |
|
sscount=np.asarray (sscount) |
|
|
|
error=np.abs(sscount-cond) |
|
print ("Abs error per SS structure type (H, E, T, ~, B, G, I S): ", error) |
|
|
|
if ploterror: |
|
fig, ax = plt.subplots(1, 1, figsize=(6,3)) |
|
plt.plot (error, 'o-', label='Error over SS type') |
|
plt.legend() |
|
plt.ylabel ('SS content') |
|
plt.show() |
|
|
|
x=np.linspace (0, 7, 8) |
|
|
|
sslabels=['H','E','T','~','B','G','I','S'] |
|
|
|
fig, ax = plt.subplots(1, 1, figsize=(6,3)) |
|
|
|
ax.bar(x-0.15, cond, width=0.3, color='b', align='center') |
|
ax.bar(x+0.15, sscount, width=0.3, color='r', align='center') |
|
|
|
ax.set_ylim([0, 1]) |
|
|
|
plt.xticks(range(len(sslabels)), sslabels, size='medium') |
|
plt.legend (['GT','Prediction']) |
|
|
|
plt.ylabel ('SS content') |
|
plt.show() |
|
|
|
|
|
|
|
sscount=[] |
|
length = len (sec_structure) |
|
sscount.append (sec_structure_3state.count('H')/length) |
|
sscount.append (sec_structure_3state.count('E')/length) |
|
sscount.append (sec_structure_3state.count('~')/length) |
|
cond_p=[np.sum([cond[0],cond[5], cond[6]]), np.sum ([cond[1], cond[4]]), np.sum([cond[2],cond[3],cond[7]]) ] |
|
|
|
print ("cond 3type: ",cond_p) |
|
sscount=np.asarray (sscount) |
|
|
|
error3=np.abs(sscount-cond_p) |
|
print ("Abs error per 3-type SS structure type (C, H, E): ", error) |
|
|
|
if ploterror: |
|
fig, ax = plt.subplots(1, 1, figsize=(6,3)) |
|
|
|
plt.plot (error3, 'o-', label='Error over SS type') |
|
plt.legend() |
|
plt.ylabel ('SS content') |
|
plt.show() |
|
|
|
|
|
x=np.linspace (0,2, 3) |
|
|
|
sslabels=['H','E', '~' ] |
|
|
|
|
|
fig, ax = plt.subplots(1, 1, figsize=(6,3)) |
|
|
|
|
|
ax.bar(x-0.15, cond_p, width=0.3, color='b', align='center') |
|
ax.bar(x+0.15, sscount, width=0.3, color='r', align='center') |
|
|
|
ax.set_ylim([0, 1]) |
|
|
|
plt.xticks(range(len(sslabels)), sslabels, size='medium') |
|
plt.legend (['GT','Prediction']) |
|
|
|
plt.ylabel ('SS content') |
|
plt.show() |
|
|
|
return error |
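
# Example usage (sketch; 'predicted.pdb' is a hypothetical structure file and the DSSP
# executable must be installed for Bio.PDB.DSSP to work):
#
#   target_cond = np.array([0.4, 0.2, 0.1, 0.2, 0.0, 0.05, 0.0, 0.05])  # H, E, T, ~, B, G, I, S
#   err = get_Model_A_error('predicted.pdb', target_cond, plotit=True, ploterror=False)
#   # err is the absolute deviation of the DSSP-derived SS fractions from target_cond.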
|
|
|
def get_DSSP_result (fname):
    """Run DSSP on a PDB file and return (8-state SS string, 3-state SS string, amino-acid sequence)."""
|
pdb_list = [fname] |
|
|
|
|
|
p = PDBParser() |
|
for i in pdb_list: |
|
structure = p.get_structure(i, fname) |
|
|
|
model = structure[0] |
|
|
|
dssp = DSSP(model, fname, file_type='PDB' ) |
|
|
|
sequence = '' |
|
sec_structure = '' |
|
for z in range(len(dssp)): |
|
a_key = list(dssp.keys())[z] |
|
sequence += dssp[a_key][1] |
|
        sec_structure += dssp[a_key][2]
|
    # DSSP marks coil as '-'; use '~' instead
    sec_structure = sec_structure.replace('-', '~')

    # Collapse the 8-state DSSP assignment into 3 states:
    # helices (H, G, I) -> 'H', strands/bridges (E, B) -> 'E', everything else -> '~'
    sec_structure_3state = sec_structure
    sec_structure_3state = sec_structure_3state.replace('T', '~')
    sec_structure_3state = sec_structure_3state.replace('B', 'E')
    sec_structure_3state = sec_structure_3state.replace('G', 'H')
    sec_structure_3state = sec_structure_3state.replace('I', 'H')
    sec_structure_3state = sec_structure_3state.replace('S', '~')
    return sec_structure, sec_structure_3state, sequence
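
# Example usage (sketch; 'example.pdb' is a hypothetical file and the DSSP executable must be on PATH):
#
#   ss8, ss3, seq = get_DSSP_result('example.pdb')
#   print(len(seq), ss8[:20], ss3[:20])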
|
|
|
|
|
def string_diff (seq1, seq2):
    """Number of mismatched positions between two sequences plus their length difference."""
|
return sum(1 for a, b in zip(seq1, seq2) if a != b) + abs(len(seq1) - len(seq2)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_one_ems_token_rec(this_token, esm_alphabet):
    """Decode a single ESM token tensor into an amino-acid string, using the tokens between
    the last <cls> marker and the first <eos> marker."""
|
|
|
|
|
|
|
|
|
id_b=(this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] |
|
id_e=(this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] |
|
|
|
|
|
if len(id_e)==0: |
|
|
|
id_e=len(this_token) |
|
else: |
|
id_e=id_e[0] |
|
if len(id_b)==0: |
|
id_b=0 |
|
else: |
|
id_b=id_b[-1] |
|
|
|
this_seq = [] |
|
|
|
for ii in range(id_b+1,id_e,1): |
|
|
|
this_seq.append( |
|
esm_alphabet.get_tok(this_token[ii]) |
|
) |
|
|
|
this_seq = "".join(this_seq) |
|
|
|
|
|
|
|
|
|
return this_seq |
|
|
|
|
|
def decode_many_ems_token_rec(batch_tokens, esm_alphabet): |
|
rev_y_seq = [] |
|
for jj in range(len(batch_tokens)): |
|
|
|
this_seq = decode_one_ems_token_rec( |
|
batch_tokens[jj], esm_alphabet |
|
) |
|
rev_y_seq.append(this_seq) |
|
return rev_y_seq |
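
# Example usage (sketch; assumes an ESM-2 alphabet and a batch of token tensors):
#
#   _, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
#   batch_converter = esm_alphabet.get_batch_converter()
#   _, _, batch_tokens = batch_converter([("seq1", "MKTAYIAK"), ("seq2", "GSHMERW")])
#   recovered = decode_many_ems_token_rec(batch_tokens, esm_alphabet)
#   # recovered reproduces the input sequences (padding tokens past <eos> are ignored).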
|
|
|
|
|
# Token ids in the ESM-2 alphabet that do not correspond to the 20 standard amino acids
# (special tokens and non-canonical symbols); their logits are suppressed before argmax decoding.
uncomm_idx_list = [0, 1, 2, 3, 24, 25, 26, 27, 28, 29, 30, 31, 32]
|
|
|
|
|
def decode_one_ems_token_rec_for_folding(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model):
    """Decode a single sequence from per-position logits: positions outside the <cls>/<eos> span
    are dropped and logits of non-amino-acid tokens are suppressed before taking the argmax."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
id_b_0=(this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] |
|
id_e_0=(this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] |
|
|
|
|
|
|
|
|
|
|
|
id_b = 0 |
|
|
|
if len(id_e_0)==0: |
|
id_e=len(this_token) |
|
else: |
|
id_e=id_e_0[0] |
|
|
|
if id_e<=id_b+1: |
|
if len(id_e_0)>1: |
|
id_e=id_e_0[1] |
|
else: |
|
            id_e=len(this_token)
|
print("start at: ", id_b) |
|
print("end at: ", id_e) |
|
|
|
|
|
    # keep only positions strictly between <cls> and <eos>, suppress logits of
    # non-amino-acid tokens, and take the per-position argmax
    use_logits = this_logits[id_b+1:id_e]
    use_logits[:,uncomm_idx_list]=-float('inf')
    use_token = use_logits.max(1).indices
|
|
|
|
|
|
|
this_seq = [] |
|
|
|
|
|
for ii in range(len(use_token)): |
|
|
|
|
|
|
|
this_seq.append( |
|
esm_alphabet.get_tok(use_token[ii]) |
|
) |
|
|
|
this_seq = "".join(this_seq) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return this_seq |
|
|
|
|
|
def decode_many_ems_token_rec_for_folding( |
|
batch_tokens, |
|
batch_logits, |
|
esm_alphabet, |
|
esm_model): |
|
|
|
rev_y_seq = [] |
|
for jj in range(len(batch_tokens)): |
|
|
|
this_seq = decode_one_ems_token_rec_for_folding( |
|
batch_tokens[jj], |
|
batch_logits[jj], |
|
esm_alphabet, |
|
esm_model, |
|
) |
|
rev_y_seq.append(this_seq) |
|
return rev_y_seq |
|
|
|
|
|
def convert_into_logits(esm_model, result): |
|
repre=rearrange( |
|
result, |
|
'b l c -> b c l' |
|
) |
|
with torch.no_grad(): |
|
logits=esm_model.lm_head(repre) |
|
|
|
return logits |
|
|
|
|
|
def convert_into_tokens(model, result, pLM_Model_Name):
    """Rearrange a (batch, channels, length) representation to (batch, length, channels),
    apply the ESM lm_head to obtain logits, and return per-position argmax token ids and the logits."""
|
if pLM_Model_Name=='esm2_t33_650M_UR50D' \ |
|
or pLM_Model_Name=='esm2_t36_3B_UR50D' \ |
|
or pLM_Model_Name=='esm2_t30_150M_UR50D' \ |
|
or pLM_Model_Name=='esm2_t12_35M_UR50D' : |
|
|
|
repre=rearrange( |
|
result, |
|
'b c l -> b l c' |
|
) |
|
with torch.no_grad(): |
|
logits=model.lm_head(repre) |
|
|
|
tokens=logits.max(2).indices |
|
|
|
    else:
        print("pLM_Model is not defined...")
        # avoid returning undefined names when an unsupported model name is passed
        raise ValueError(f"Unsupported pLM_Model_Name: {pLM_Model_Name}")

    return tokens,logits
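
# Example usage (sketch; `result` stands in for a (batch, channels, length) embedding produced upstream):
#
#   esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
#   result = torch.randn(2, 480, 16)   # 480 = embedding dim of esm2_t12_35M_UR50D
#   tokens, logits = convert_into_tokens(esm_model, result, 'esm2_t12_35M_UR50D')
#   # tokens: (2, 16) argmax token ids; logits: (2, 16, vocab_size).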
|
|
|
def convert_into_tokens_using_prob(prob_result, pLM_Model_Name): |
|
if pLM_Model_Name=='esm2_t33_650M_UR50D' \ |
|
or pLM_Model_Name=='esm2_t36_3B_UR50D' \ |
|
or pLM_Model_Name=='esm2_t30_150M_UR50D' \ |
|
or pLM_Model_Name=='esm2_t12_35M_UR50D' : |
|
|
|
repre=rearrange( |
|
prob_result, |
|
'b c l -> b l c' |
|
) |
|
|
|
|
|
logits = repre |
|
|
|
tokens=logits.max(2).indices |
|
|
|
    else:
        print("pLM_Model is not defined...")
        # avoid returning undefined names when an unsupported model name is passed
        raise ValueError(f"Unsupported pLM_Model_Name: {pLM_Model_Name}")

    return tokens,logits
|
|
|
|
|
|
|
def read_mask_from_input(
    tokenized_data=None,
    mask_value=None,
    seq_data=None,
    max_seq_length=None,
):
    """Build a boolean mask of valid (non-padding) positions, either from raw sequences
    (`seq_data` with `max_seq_length`) or from `tokenized_data` compared against `mask_value`."""
|
|
|
|
|
|
|
    if seq_data is not None:
|
|
|
n_seq = len(seq_data) |
|
mask = torch.zeros(n_seq, max_seq_length) |
|
for ii in range(n_seq): |
|
this_len = len(seq_data[ii]) |
|
mask[ii,1:1+this_len]=1 |
|
mask = mask==1 |
|
|
|
    elif tokenized_data is not None:
|
n_seq = len(tokenized_data) |
|
mask = tokenized_data!=mask_value |
|
|
|
for ii in range(n_seq): |
|
|
|
id_1 = (mask[ii]==True).nonzero(as_tuple=True)[0] |
|
|
|
|
|
mask[ii,1:id_1[0]]=True |
|
|
|
return mask |
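
# Example usage (sketch; the sequences below are made up):
#
#   mask = read_mask_from_input(seq_data=['MKTAYIAK', 'GSH'], max_seq_length=12)
#   # mask has shape (2, 12); positions 1..len(seq) are True (position 0 is reserved for <cls>).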
|
|
|
|
|
def read_one_len_from_padding_vec(
    in_np_array,
    padding_val=0.0,
):
    """Return the effective length of a padded 1-D array: the index of its last
    non-padding entry plus one."""
|
mask = in_np_array!=padding_val |
|
id_list_all_1 = mask.nonzero()[0] |
|
vec_len = id_list_all_1[-1]+1 |
|
|
|
return vec_len |
|
|
|
|
|
|
|
def decode_one_ems_token_rec_for_folding_with_mask(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model,
    this_mask,
):
    """Like decode_one_ems_token_rec_for_folding, but valid positions are selected with an
    explicit boolean mask instead of the <cls>/<eos> markers."""
|
|
|
|
|
|
|
|
|
use_logits = this_logits |
|
use_logits[:,uncomm_idx_list]=-float('inf') |
|
use_token = use_logits.max(1).indices |
|
|
|
print(use_token) |
|
use_token = use_token[this_mask==True] |
|
|
|
|
|
this_seq = [] |
|
|
|
|
|
for ii in range(len(use_token)): |
|
|
|
|
|
|
|
this_seq.append( |
|
esm_alphabet.get_tok(use_token[ii]) |
|
) |
|
|
|
this_seq = "".join(this_seq) |
|
|
|
return this_seq |
|
|
|
def decode_many_ems_token_rec_for_folding_with_mask( |
|
batch_tokens, |
|
batch_logits, |
|
esm_alphabet, |
|
esm_model, |
|
mask): |
|
|
|
rev_y_seq = [] |
|
for jj in range(len(batch_tokens)): |
|
|
|
this_seq = decode_one_ems_token_rec_for_folding_with_mask( |
|
batch_tokens[jj], |
|
batch_logits[jj], |
|
esm_alphabet, |
|
esm_model, |
|
mask[jj] |
|
) |
|
rev_y_seq.append(this_seq) |
|
return rev_y_seq |
|
|
|
|
|
|
|
|
|
from scipy import interpolate |
|
|
|
def interpolate_and_resample_ForcPath(y0,seq_len1):
    """Interpolate a force path sampled at len(y0) evenly spaced points on [0, 1]
    and resample it at seq_len1+1 points."""
    seq_len0 = len(y0)-1
    # linspace instead of arange avoids floating-point end-point issues
    x0 = np.linspace(0., 1., seq_len0+1)
    f = interpolate.interp1d(x0, y0)

    x1 = np.linspace(0., 1., seq_len1+1)
    y1 = f(x1)
|
|
|
resu = {} |
|
resu['y1']=y1 |
|
resu['x1']=x1 |
|
resu['x0']=x0 |
|
return resu |
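
# Example usage (sketch; y0 is a made-up force path sampled at 11 points on [0, 1]):
#
#   y0 = np.linspace(0., 2., 11)
#   resu = interpolate_and_resample_ForcPath(y0, 32)
#   # resu['y1'] is the same path resampled at 33 evenly spaced points.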
|
|
|
def mix_two_ForcPath(y0,y1,seq_len2):
    """Interpolate two force paths defined on [0, 1] and combine them on a common
    grid of seq_len2+1 evenly spaced points."""
    seq_len0 = len(y0)-1
    x0 = np.linspace(0., 1., seq_len0+1)
    seq_len1 = len(y1)-1
    x1 = np.linspace(0., 1., seq_len1+1)
    f0 = interpolate.interp1d(x0, y0)
    f1 = interpolate.interp1d(x1, y1)

    x2 = np.linspace(0., 1., seq_len2+1)
    y2 = (f0(x2)+f1(x2))/1.   # note: plain sum of the two interpolated paths (no averaging)
|
|
|
resu={} |
|
resu['y2']=y2 |
|
resu['x2']=x2 |
|
resu['x1']=x1 |
|
resu['x0']=x0 |
|
return resu |
|
|
|
|
|
|
|
|
|
|
|
|
def load_in_pLM(pLM_Model_Name,device):
    """Load the requested pretrained ESM-2 model (or a trivial placeholder) in eval mode on `device`
    and return (model, alphabet, representation-layer index, vocabulary size)."""
|
|
|
|
|
if pLM_Model_Name=='trivial': |
|
pLM_Model=None |
|
esm_alphabet=None |
|
len_toks=0 |
|
esm_layer=0 |
|
|
|
elif pLM_Model_Name=='esm2_t33_650M_UR50D': |
|
|
|
esm_layer=33 |
|
pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() |
|
len_toks=len(esm_alphabet.all_toks) |
|
pLM_Model.eval() |
|
        pLM_Model.to(device)
|
|
|
elif pLM_Model_Name=='esm2_t36_3B_UR50D': |
|
|
|
esm_layer=36 |
|
pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() |
|
len_toks=len(esm_alphabet.all_toks) |
|
pLM_Model.eval() |
|
        pLM_Model.to(device)
|
|
|
elif pLM_Model_Name=='esm2_t30_150M_UR50D': |
|
|
|
esm_layer=30 |
|
pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() |
|
len_toks=len(esm_alphabet.all_toks) |
|
pLM_Model.eval() |
|
        pLM_Model.to(device)
|
|
|
elif pLM_Model_Name=='esm2_t12_35M_UR50D': |
|
|
|
esm_layer=12 |
|
pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() |
|
len_toks=len(esm_alphabet.all_toks) |
|
pLM_Model.eval() |
|
        pLM_Model.to(device)
|
|
|
    else:
        print("pLM model is missing...")
        # avoid returning undefined names when an unknown model name is passed
        raise ValueError(f"Unknown pLM_Model_Name: {pLM_Model_Name}")
|
|
|
return pLM_Model, esm_alphabet, esm_layer, len_toks |
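
# Example usage (sketch; downloads pretrained weights on first use):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   pLM_Model, esm_alphabet, esm_layer, len_toks = load_in_pLM('esm2_t12_35M_UR50D', device)
#   # esm_layer == 12 (the representation layer to read out); len_toks == 33 for the ESM-2 alphabet.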