import os |
from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator |
import numpy as np |
import math |
import matplotlib.pyplot as plt |
from Bio.PDB import PDBParser |
from Bio.PDB.DSSP import DSSP |
from Bio.PDB import PDBList |
import torch |
from einops import rearrange |
import esm |
def create_path(this_path):
    """Make sure the directory ``this_path`` exists.

    Args:
        this_path: Directory to create. Missing parent directories are
            created as well.

    Returns:
        int: ``1`` if the path was created by this call, ``2`` if something
        already existed at ``this_path``.
    """
    if os.path.exists(this_path):
        print('The given path already exists!')
        return 2

    print('Creating the given path...')
    # makedirs instead of mkdir so nested output folders work; exist_ok
    # tolerates another process creating the directory in the meantime.
    os.makedirs(this_path, exist_ok=True)
    print('Done.')
    return 1
def params (model):
    """Print the total and the trainable parameter count of ``model``.

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).
    """
    n_all = 0
    n_trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        n_all += size
        if tensor.requires_grad:
            n_trainable += size
    print ("Total parameters: ", n_all," trainable parameters: ", n_trainable)
def prepare_UNet_keys(write_dict):
    """Expand a partial UNet option dict into one holding every known key.

    Args:
        write_dict: Mapping of option name to value. Names that are not in
            the known key list are reported on stdout and dropped.

    Returns:
        dict: Every known key, in the fixed order below, mapped to the
        supplied value or to ``None`` when it was not supplied.
    """
    known_keys = (
        'dim', 'text_embed_dim', 'num_resnet_blocks', 'cond_dim',
        'num_image_tokens', 'num_time_tokens', 'learned_sinu_pos_emb_dim',
        'out_dim', 'dim_mults', 'cond_images_channels', 'channels',
        'channels_out', 'attn_dim_head', 'attn_heads', 'ff_mult',
        'lowres_cond', 'layer_attns', 'layer_attns_depth',
        'layer_attns_add_text_cond', 'attend_at_middle',
        'layer_cross_attns', 'use_linear_attn', 'use_linear_cross_attn',
        'cond_on_text', 'max_text_len', 'init_dim', 'resnet_groups',
        'init_conv_kernel_size', 'init_cross_embed',
        'init_cross_embed_kernel_sizes', 'cross_embed_downsample',
        'cross_embed_downsample_kernel_sizes', 'attn_pool_text',
        'attn_pool_num_latents', 'dropout', 'memory_efficient',
        'init_conv_to_final_conv_residual', 'use_global_context_attn',
        'scale_skip_connection', 'final_resnet_block',
        'final_conv_kernel_size', 'cosine_sim_attn', 'self_cond',
        'combine_upsample_fmaps', 'pixel_shuffle_upsample',
        'beginning_and_final_conv_present',
    )
    prepared = dict.fromkeys(known_keys)
    for name, value in write_dict.items():
        if name not in prepared:
            print("Wrong key found: ", name)
            continue
        prepared[name] = value
    return prepared
def prepare_ModelB_keys(write_dict):
    """Expand a partial Model-B option dict into one holding every known key.

    Args:
        write_dict: Mapping of option name to value. Names that are not in
            the known key list are reported on stdout and dropped.

    Returns:
        dict: Every known key, in the fixed order below, mapped to the
        supplied value or to ``None`` when it was not supplied.
    """
    known_keys = (
        'timesteps', 'dim', 'pred_dim', 'loss_type', 'elucidated',
        'padding_idx', 'cond_dim', 'text_embed_dim', 'input_tokens',
        'sequence_embed', 'embed_dim_position', 'max_text_len',
        'cond_images_channels', 'max_length', 'device',
    )
    prepared = {name: None for name in known_keys}
    for name, value in write_dict.items():
        if name in prepared:
            prepared[name] = value
        else:
            print("Wrong key found: ", name)
    return prepared
def modify_keys(old_dict,write_dict):
    """Return a copy of ``old_dict`` with existing keys overridden.

    Args:
        old_dict: Base mapping; it is copied, not modified.
        write_dict: Overrides. Keys absent from ``old_dict`` are reported
            on stdout and ignored.

    Returns:
        A copy of ``old_dict`` carrying the accepted overrides.
    """
    updated = old_dict.copy()
    for name, value in write_dict.items():
        if name not in old_dict:
            print("Alien key found: ", name)
            continue
        updated[name] = value
    return updated
def mixing_two_FORCE_for_AA_Len(NGap1,Force1,NGap2,Force2,LenAA,mix_fac):
    """Blend two force curves and resample the blend at ``LenAA + 1`` points.

    Both curves are PCHIP-interpolated onto a shared grid over [0, 1], mixed
    as ``mix_fac * curve1 + (1 - mix_fac) * curve2``, and the mixture is
    interpolated again so that it can be evaluated anywhere.

    Args:
        NGap1, Force1: Abscissae and values of the first curve.
        NGap2, Force2: Abscissae and values of the second curve.
        LenAA: The blend is sampled at ``LenAA + 1`` evenly spaced points.
        mix_fac: Weight of the first curve.

    Returns:
        tuple: ``(fun_PI, x1, y1)`` - the interpolator of the blend, the
        sample positions in [0, 1] and the blend evaluated there.
    """
    # NOTE(review): the shared grid spans [0, 1], so the abscissae are
    # presumably normalised to that range - confirm against callers.
    n_grid = math.ceil(2 * max(len(NGap1), len(NGap2)))
    grid = np.linspace(0, 1, n_grid)

    curve_a = PchipInterpolator(NGap1, Force1)(grid)
    curve_b = PchipInterpolator(NGap2, Force2)(grid)
    blended = curve_a*mix_fac + curve_b*(1-mix_fac)

    fun_PI = PchipInterpolator(grid, blended)
    x1 = np.linspace(0, 1, LenAA+1)
    return fun_PI, x1, fun_PI(x1)
def get_Model_A_error (fname, cond, plotit=True, ploterror=False):
    """Compare the DSSP secondary-structure content of a structure to a target.

    The structure is analysed with ``get_DSSP_result``; the fraction of
    residues per 8-state class and per reduced 3-state class is compared with
    the target fractions in ``cond``.

    Args:
        fname: Path of the structure file passed to ``get_DSSP_result``.
        cond: Eight target fractions ordered (H, E, T, ~, B, G, I, S).
        plotit: When true (default), show bar charts of target ('GT') versus
            measured ('Prediction') content for the 8- and 3-state classes.
        ploterror: When true, additionally show line plots of the errors.

    Returns:
        numpy.ndarray: Absolute error per 8-state class. The 3-state error is
        only printed and plotted.
    """
    sec_structure, sec_structure_3state, sequence = get_DSSP_result (fname)
    length = len (sec_structure)

    # ---- 8-state comparison -------------------------------------------
    labels8 = ['H','E','T','~','B','G','I','S']
    sscount = np.asarray ([sec_structure.count(code)/length for code in labels8])
    error = np.abs(sscount-cond)
    print ("Abs error per SS structure type (H, E, T, ~, B, G, I S): ", error)
    if ploterror:
        _plot_ss_error (error)
    if plotit:
        _plot_ss_bars (labels8, cond, sscount)

    # ---- 3-state comparison -------------------------------------------
    labels3 = ['H','E','~']
    sscount3 = np.asarray ([sec_structure_3state.count(code)/length for code in labels3])
    # Fold the 8 target fractions the same way the 3-state string is built:
    # H+G+I -> H, E+B -> E, T+~+S -> ~.
    cond_p = [np.sum([cond[0],cond[5], cond[6]]),
              np.sum ([cond[1], cond[4]]),
              np.sum([cond[2],cond[3],cond[7]])]
    print ("cond 3type: ",cond_p)
    error3 = np.abs(sscount3-cond_p)
    # Report the 3-state error (the 8-state one used to be printed here) and
    # label it in the order it is actually computed.
    print ("Abs error per 3-type SS structure type (H, E, ~): ", error3)
    if ploterror:
        _plot_ss_error (error3)
    if plotit:
        _plot_ss_bars (labels3, cond_p, sscount3)

    return error


def _plot_ss_error (values):
    """Show a line plot of per-class absolute errors."""
    fig, ax = plt.subplots(1, 1, figsize=(6,3))
    plt.plot (values, 'o-', label='Error over SS type')
    plt.legend()
    plt.ylabel ('SS content')
    plt.show()


def _plot_ss_bars (labels, target, measured):
    """Show side-by-side bars of target ('GT') and measured ('Prediction') content."""
    x = np.linspace (0, len(labels)-1, len(labels))
    fig, ax = plt.subplots(1, 1, figsize=(6,3))
    ax.bar(x-0.15, target, width=0.3, color='b', align='center')
    ax.bar(x+0.15, measured, width=0.3, color='r', align='center')
    ax.set_ylim([0, 1])
    plt.xticks(range(len(labels)), labels, size='medium')
    plt.legend (['GT','Prediction'])
    plt.ylabel ('SS content')
    plt.show()
def get_DSSP_result (fname):
    """Run DSSP on the first model of a PDB file.

    Args:
        fname: Path of the PDB file; also used as the structure id.

    Returns:
        tuple: ``(sec_structure, sec_structure_3state, sequence)`` as strings
        with one character per DSSP record. ``sec_structure`` holds the DSSP
        codes with ``'-'`` rewritten to ``'~'``; ``sec_structure_3state``
        additionally folds T and S into ``'~'``, B into ``'E'`` and G and I
        into ``'H'``.
    """
    structure = PDBParser().get_structure(fname, fname)
    dssp = DSSP(structure[0], fname, file_type='PDB')

    # Fetch every record once; rebuilding list(dssp.keys()) for each residue
    # made the original loop quadratic in the chain length.
    records = [dssp[key] for key in dssp.keys()]
    # Field 1 of a record goes into the sequence, field 2 into the SS string.
    sequence = ''.join(record[1] for record in records)
    sec_structure = ''.join(record[2] for record in records).replace('-', '~')

    # Only the codes that actually change are listed; H, E and '~' map to
    # themselves and any other code passes through untouched.
    to_3state = str.maketrans({'T': '~', 'S': '~', 'B': 'E', 'G': 'H', 'I': 'H'})
    sec_structure_3state = sec_structure.translate(to_3state)

    return sec_structure,sec_structure_3state, sequence
def string_diff (seq1, seq2):
    """Count the positions at which two sequences differ.

    Positions present in only one of the sequences (length mismatch) each
    count as one difference.
    """
    mismatches = 0
    for left, right in zip(seq1, seq2):
        if left != right:
            mismatches += 1
    return mismatches + abs(len(seq1) - len(seq2))
import esm |
def decode_one_ems_token_rec(this_token, esm_alphabet):
    """Decode one row of token ids into a string.

    Tokens strictly between the last ``cls_idx`` occurrence and the first
    ``eos_idx`` occurrence are decoded. Without an end token decoding runs to
    the end of the row; without a start token position 0 is treated as the
    start marker, so decoding begins at position 1.

    Args:
        this_token: 1-D tensor of token ids.
        esm_alphabet: Object providing ``cls_idx``, ``eos_idx`` and
            ``get_tok(index)``.

    Returns:
        str: Concatenation of the decoded tokens.
    """
    cls_hits = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
    eos_hits = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    stop = eos_hits[0] if len(eos_hits) != 0 else len(this_token)
    start = cls_hits[-1] if len(cls_hits) != 0 else 0

    return "".join(
        esm_alphabet.get_tok(this_token[pos])
        for pos in range(start+1, stop, 1)
    )
def decode_many_ems_token_rec(batch_tokens, esm_alphabet):
    """Decode every row of ``batch_tokens`` with ``decode_one_ems_token_rec``.

    Returns:
        list: One decoded string per row, in row order.
    """
    return [
        decode_one_ems_token_rec(batch_tokens[row], esm_alphabet)
        for row in range(len(batch_tokens))
    ]
# Alphabet indices that are never emitted when decoding logits for folding.
# NOTE(review): presumably the special tokens and non-standard residue codes
# of the ESM alphabet - confirm against esm_alphabet.all_toks.
uncomm_idx_list = [0, 1, 2, 3, 24, 25, 26, 27, 28, 29, 30, 31, 32]


def decode_one_ems_token_rec_for_folding(
        this_token,
        this_logits,
        esm_alphabet,
        esm_model):
    """Decode one sequence from logits, excluding the ``uncomm_idx_list`` tokens.

    ``this_token`` is used only to locate the end of the sequence; the emitted
    characters are the arg-max of ``this_logits`` after the excluded columns
    have been set to ``-inf``.

    Args:
        this_token: 1-D tensor of token ids for the row.
        this_logits: 2-D tensor of logits, one row per position and one
            column per alphabet entry. It is not modified.
        esm_alphabet: Object providing ``eos_idx`` and ``get_tok(index)``.
        esm_model: Unused; kept for interface compatibility.

    Returns:
        str: Decoded sequence for positions ``1 .. end-1``.
    """
    eos_hits = (this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    # Position 0 is always treated as the start marker.
    id_b = 0
    if len(eos_hits) == 0:
        id_e = len(this_token)
    else:
        id_e = eos_hits[0]
        if id_e <= id_b+1:
            # The first end token leaves no room for a sequence: use the
            # second one if there is one, otherwise the full length.
            id_e = eos_hits[1] if len(eos_hits) > 1 else len(this_token)
    print("start at: ", id_b)
    print("end at: ", id_e)

    # Slicing returns a view, so mask a clone; writing -inf in place would
    # silently overwrite the caller's logits.
    use_logits = this_logits[id_b+1:id_e].clone()
    use_logits[:, uncomm_idx_list] = -float('inf')
    use_token = use_logits.max(1).indices

    return "".join(esm_alphabet.get_tok(tok) for tok in use_token)
def decode_many_ems_token_rec_for_folding(
        batch_tokens,
        batch_logits,
        esm_alphabet,
        esm_model):
    """Decode every row of a batch with ``decode_one_ems_token_rec_for_folding``.

    Returns:
        list: One decoded string per row of ``batch_tokens``, in row order.
    """
    return [
        decode_one_ems_token_rec_for_folding(
            batch_tokens[row],
            batch_logits[row],
            esm_alphabet,
            esm_model,
        )
        for row in range(len(batch_tokens))
    ]
def convert_into_logits(esm_model, result):
    """Apply ``esm_model.lm_head`` to ``result`` with its last two axes swapped.

    Args:
        esm_model: Object exposing a callable ``lm_head``.
        result: 3-D array/tensor; axes 1 and 2 are exchanged before the head
            is applied.

    Returns:
        The output of ``lm_head``, computed without gradient tracking.
    """
    # The axis letters are just labels: this is the same permutation that
    # convert_into_tokens spells as 'b c l -> b l c'.
    swapped = rearrange(result, 'b l c -> b c l')
    with torch.no_grad():
        return esm_model.lm_head(swapped)
def convert_into_tokens(model, result, pLM_Model_Name):
    """Turn channel-first embeddings into token ids via the model's LM head.

    Args:
        model: Object exposing a callable ``lm_head``.
        result: 3-D tensor laid out as (batch, channel, length).
        pLM_Model_Name: Name of the protein language model; must be one of
            the supported ESM-2 names listed below.

    Returns:
        tuple: ``(tokens, logits)`` where ``logits`` is the LM-head output for
        the (batch, length, channel) view of ``result`` and ``tokens`` is its
        arg-max over the last axis.

    Raises:
        ValueError: If ``pLM_Model_Name`` is not supported. (This used to
            print a message and then fail with ``UnboundLocalError``.)
    """
    supported = (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    )
    if pLM_Model_Name not in supported:
        raise ValueError(
            "pLM_Model is not defined: {!r}".format(pLM_Model_Name)
        )

    repre = rearrange(result, 'b c l -> b l c')
    with torch.no_grad():
        logits = model.lm_head(repre)
    tokens = logits.max(2).indices
    return tokens,logits
def convert_into_tokens_using_prob(prob_result, pLM_Model_Name):
    """Turn channel-first per-token scores into token ids.

    Args:
        prob_result: 3-D tensor laid out as (batch, channel, length), holding
            one score per alphabet entry along the channel axis.
        pLM_Model_Name: Name of the protein language model; must be one of
            the supported ESM-2 names listed below.

    Returns:
        tuple: ``(tokens, logits)`` where ``logits`` is ``prob_result`` viewed
        as (batch, length, channel) and ``tokens`` is its arg-max over the
        last axis.

    Raises:
        ValueError: If ``pLM_Model_Name`` is not supported. (This used to
            print a message and then fail with ``UnboundLocalError``.)
    """
    supported = (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    )
    if pLM_Model_Name not in supported:
        raise ValueError(
            "pLM_Model is not defined: {!r}".format(pLM_Model_Name)
        )

    logits = rearrange(prob_result, 'b c l -> b l c')
    tokens = logits.max(2).indices
    return tokens,logits
def read_mask_from_input(
        tokenized_data=None,
        mask_value=None,
        seq_data=None,
        max_seq_length=None,
):
    """Build a boolean mask marking the sequence positions of each row.

    Exactly one of the two inputs is used; ``seq_data`` takes precedence.

    Args:
        tokenized_data: 2-D tensor of token ids. Positions different from
            ``mask_value`` are True; in addition, positions 1 up to the first
            such position are set to True.
        mask_value: Padding value used with ``tokenized_data``.
        seq_data: Sequence of sequences. Row ``i`` gets ``len(seq_data[i])``
            True entries starting at position 1 (position 0 stays False).
        max_seq_length: Row width of the mask built from ``seq_data``.

    Returns:
        torch.Tensor: Boolean mask with one row per input sequence.

    Raises:
        ValueError: If neither ``seq_data`` nor ``tokenized_data`` is given.
    """
    # 'is not None' rather than '!= None': the latter is ambiguous for
    # array-like inputs.
    if seq_data is not None:
        mask = torch.zeros(len(seq_data), max_seq_length, dtype=torch.bool)
        for row, seq in enumerate(seq_data):
            mask[row, 1:1+len(seq)] = True
        return mask

    if tokenized_data is not None:
        mask = tokenized_data != mask_value
        for row in range(len(tokenized_data)):
            kept = mask[row].nonzero(as_tuple=True)[0]
            if len(kept) == 0:
                # Row holds only padding: leave it all False instead of
                # failing on kept[0].
                continue
            mask[row, 1:kept[0]] = True
        return mask

    raise ValueError("Either seq_data or tokenized_data must be provided.")
def read_one_len_from_padding_vec(
        in_np_array,
        padding_val=0.0,
):
    """Return the length of a padded vector up to its last non-padding entry.

    Args:
        in_np_array: Array-like of values, padded with ``padding_val``.
        padding_val: Value that marks padding.

    Returns:
        int: Index of the last entry different from ``padding_val`` plus one,
        or ``0`` when every entry is padding.
    """
    # asarray lets plain lists through, which have no .nonzero().
    values = np.asarray(in_np_array)
    kept = (values != padding_val).nonzero()[0]
    if kept.size == 0:
        # Nothing but padding: report length 0 instead of failing on kept[-1].
        return 0
    return kept[-1]+1
def decode_one_ems_token_rec_for_folding_with_mask(
        this_token,
        this_logits,
        esm_alphabet,
        esm_model,
        this_mask,
):
    """Decode the masked positions of one row from its logits.

    The columns listed in the module-level ``uncomm_idx_list`` are excluded
    before taking the per-position arg-max; only positions where
    ``this_mask`` is true are decoded.

    Args:
        this_token: Unused; kept for interface compatibility.
        this_logits: 2-D tensor of logits, one row per position and one
            column per alphabet entry. It is not modified.
        esm_alphabet: Object providing ``get_tok(index)``.
        esm_model: Unused; kept for interface compatibility.
        this_mask: Per-position mask selecting the positions to decode.

    Returns:
        str: Decoded sequence of the selected positions.
    """
    # Mask a clone: writing -inf in place would overwrite the caller's logits.
    use_logits = this_logits.clone()
    use_logits[:, uncomm_idx_list] = -float('inf')
    use_token = use_logits.max(1).indices
    print(use_token)
    # '== True' is deliberate: it turns a 0/1 integer mask into a boolean one,
    # so it selects positions instead of being used as gather indices.
    use_token = use_token[this_mask==True]

    return "".join(esm_alphabet.get_tok(tok) for tok in use_token)
def decode_many_ems_token_rec_for_folding_with_mask(
        batch_tokens,
        batch_logits,
        esm_alphabet,
        esm_model,
        mask):
    """Decode every row of a batch with the masked folding decoder.

    Row ``i`` of ``batch_tokens``, ``batch_logits`` and ``mask`` is passed to
    ``decode_one_ems_token_rec_for_folding_with_mask``.

    Returns:
        list: One decoded string per row of ``batch_tokens``, in row order.
    """
    return [
        decode_one_ems_token_rec_for_folding_with_mask(
            batch_tokens[row],
            batch_logits[row],
            esm_alphabet,
            esm_model,
            mask[row]
        )
        for row in range(len(batch_tokens))
    ]
from scipy import interpolate |
def interpolate_and_resample_ForcPath(y0,seq_len1):
    """Linearly resample a curve given on a uniform [0, 1] grid.

    Args:
        y0: Curve values, taken to sit at ``len(y0)`` evenly spaced points
            from 0 to 1.
        seq_len1: Number of intervals of the new grid; the curve is
            resampled at ``seq_len1 + 1`` evenly spaced points.

    Returns:
        dict: ``'y1'`` (resampled values), ``'x1'`` (new grid) and ``'x0'``
        (original grid).
    """
    seq_len0 = len(y0)-1
    # linspace instead of arange with a float step: arange may return one
    # point too many or miss 1.0 by a rounding error, which makes interp1d
    # reject the shapes or the out-of-range query.
    x0 = np.linspace(0., 1., seq_len0+1)
    f = interpolate.interp1d(x0, y0)
    x1 = np.linspace(0., 1., int(seq_len1)+1)
    y1 = f(x1)

    return {'y1': y1, 'x1': x1, 'x0': x0}
def mix_two_ForcPath(y0,y1,seq_len2):
    """Resample two curves onto a common grid and add them.

    Each curve is taken to sit on its own uniform grid from 0 to 1; both are
    linearly interpolated at ``seq_len2 + 1`` evenly spaced points.

    Args:
        y0: Values of the first curve.
        y1: Values of the second curve.
        seq_len2: Number of intervals of the common grid.

    Returns:
        dict: ``'y2'`` (combined values), ``'x2'`` (common grid), ``'x1'``
        and ``'x0'`` (grids of the second and first curve).
    """
    # linspace instead of arange with a float step: arange may return one
    # point too many or miss 1.0 by a rounding error, which makes interp1d
    # reject the shapes or the out-of-range query.
    x0 = np.linspace(0., 1., len(y0))
    x1 = np.linspace(0., 1., len(y1))
    f0 = interpolate.interp1d(x0, y0)
    f1 = interpolate.interp1d(x1, y1)

    x2 = np.linspace(0., 1., int(seq_len2)+1)
    # NOTE(review): the divisor is 1., so this is the sum of the two curves,
    # not their average - confirm that this is intended.
    y2 = (f0(x2)+f1(x2))/1.

    return {'y2': y2, 'x2': x2, 'x1': x1, 'x0': x0}
import esm |
def load_in_pLM(pLM_Model_Name,device):
    """Load a pretrained ESM-2 protein language model by name.

    Args:
        pLM_Model_Name: ``'trivial'`` (no model) or one of the ESM-2 names
            listed in the table below.
        device: Device the loaded model is moved to.

    Returns:
        tuple: ``(pLM_Model, esm_alphabet, esm_layer, len_toks)``. For
        ``'trivial'`` this is ``(None, None, 0, 0)``; otherwise the model (in
        eval mode, on ``device``), its alphabet, the layer number associated
        with the model name and the number of alphabet tokens.

    Raises:
        ValueError: If ``pLM_Model_Name`` is not recognised. (This used to
            print a message and then fail with ``UnboundLocalError``.)
    """
    if pLM_Model_Name == 'trivial':
        return None, None, 0, 0

    # Model name -> layer number reported back to the caller. The name is
    # also the name of the loader function in esm.pretrained.
    esm2_layers = {
        'esm2_t33_650M_UR50D': 33,
        'esm2_t36_3B_UR50D': 36,
        'esm2_t30_150M_UR50D': 30,
        'esm2_t12_35M_UR50D': 12,
    }
    if pLM_Model_Name not in esm2_layers:
        raise ValueError(
            "pLM model is missing: {!r}".format(pLM_Model_Name)
        )

    esm_layer = esm2_layers[pLM_Model_Name]
    pLM_Model, esm_alphabet = getattr(esm.pretrained, pLM_Model_Name)()
    len_toks = len(esm_alphabet.all_toks)
    pLM_Model.eval()
    pLM_Model.to(device)

    return pLM_Model, esm_alphabet, esm_layer, len_toks