import logging
import os
import pickle

import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel

from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.utils.data_cleaning import find_invalid_chars
from fuson_plm.utils.constants import VALID_AAS
from fuson_plm.training.model import FusOnpLM
|
|
def validate_sequence_col(df, seq_col): |
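    """
    Validate the sequence column of an input dataframe: raise if seq_col is missing or if
    any sequence contains characters outside VALID_AAS, and log a warning if the column
    holds duplicate sequences.
    """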
    if seq_col not in df.columns:
        raise KeyError(f"Error: provided sequence column '{seq_col}' does not exist in the input dataframe")

    # Flag any characters that fall outside the valid amino-acid alphabet
    df['invalid_chars'] = df[seq_col].apply(lambda x: find_invalid_chars(x, VALID_AAS))
    all_invalid_chars = set().union(*df['invalid_chars'])
    df = df.drop(columns=['invalid_chars'])
    if len(all_invalid_chars) > 0:
        raise ValueError(f"Error: invalid characters {all_invalid_chars} found in the sequence column")

    # Duplicates are allowed but worth flagging, since downstream embedding runs on unique sequences
    sequences = df[seq_col]
    if len(set(sequences)) < len(sequences):
        log_update("\tWARNING: input data has duplicate sequences")
|
|
def load_fuson_model(ckpt_path): |
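    """
    Load a FusOn-pLM checkpoint and its tokenizer, move the model to GPU if one is
    available, and put it in eval mode.

    Returns:
        model, tokenizer, device
    """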
    # Suppress warnings about unused weights when loading the checkpoint
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_update(f"Using device: {device}")

    model = AutoModel.from_pretrained(ckpt_path)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    model.to(device)
    model.eval()

    return model, tokenizer, device
|
|
def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000): |
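    """
    Compute FusOn-pLM embeddings for a list of sequences.

    Args:
        average (bool): if True, mean-pool each sequence's per-residue embeddings into a single vector
        savepath (str): optional .pkl output path; embeddings are written incrementally unless save_at_end=True
        max_length (int): tokenizer truncation length; if None, it is set to fit the longest sequence
    """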
    # Normalize the save path to a .pkl extension
    if savepath is not None:
        if not savepath.endswith('.pkl'):
            savepath += '.pkl'

    if print_updates:
        log_update(f"Dataset contains {len(sequences)} sequences.")

    # If no max_length is given, fit the longest sequence plus the two special tokens (CLS and EOS)
    max_seq_len = max(len(s) for s in sequences)
    if max_length is None:
        max_length = max_seq_len + 2

    embedding_dict = {}

    for i, sequence in enumerate(sequences):
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # Remove the batch dimension, then the CLS and EOS token positions
            embedding = embedding.squeeze(0)
            embedding = embedding[1:-1, :]
            embedding = embedding.cpu().numpy()

            # Mean-pool over the length dimension for a fixed-size sequence embedding
            if average:
                embedding = embedding.mean(0)

            embedding_dict[sequence] = embedding

            # Dump each embedding as it is computed so progress survives an interruption
            if savepath is not None and not save_at_end:
                with open(savepath, 'ab+') as f:
                    pickle.dump({sequence: embedding}, f)

            if print_updates:
                log_update(f"sequence {i+1}: {sequence[0:10]}...")

    if savepath is not None:
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # Consolidate the incrementally dumped entries into a single pickled dictionary
            redump_pickle_dictionary(savepath)
|
|
def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path=None, average=True, overwrite=True, print_updates=False, max_length=2000):
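    """
    Embed every unique sequence in the seq_col column of the CSV at path_to_file and save
    the embeddings to path_to_output. Supported model_type values: 'fuson_plm' (requires
    fuson_ckpt_path), 'esm2_t33_650M_UR50D', and 'prot_t5_xl_half_uniref50_enc'.
    """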
    if os.path.exists(path_to_output):
        if overwrite:
            log_update(f"WARNING: embeddings already exist at {path_to_output} and will be overwritten")
        else:
            log_update(f"WARNING: embeddings already exist at {path_to_output}. Skipping.")
            return None

    dataset = pd.read_csv(path_to_file)
    validate_sequence_col(dataset, seq_col)

    # Embed each unique sequence exactly once
    sequences = dataset[seq_col].unique().tolist()

    if model_type == 'fuson_plm':
        if fuson_ckpt_path is None or not os.path.exists(fuson_ckpt_path):
            raise Exception("FusOn-pLM ckpt path does not exist")

        try:
            model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
        except Exception as e:
            raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}") from e

        try:
            get_fuson_embeddings(model, tokenizer, sequences, device, average=average,
                                 print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                 max_length=max_length)
        except Exception as e:
            raise Exception("Could not generate FusOn-pLM embeddings") from e

    elif model_type == 'esm2_t33_650M_UR50D':
        try:
            model, tokenizer, device = load_esm2_type(model_type)
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e

        try:
            get_esm_embeddings(model, tokenizer, sequences, device, average=average,
                               print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                               max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e

    elif model_type == "prot_t5_xl_half_uniref50_enc":
        try:
            model, tokenizer, device = load_prott5()
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e

        try:
            get_prott5_embeddings(model, tokenizer, sequences, device, average=average,
                                  print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                  max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e

    else:
        raise ValueError(f"Unrecognized model_type: {model_type}")
|
|
def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False, max_length=None):
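    """
    Generate all embeddings needed for a benchmark: FusOn-pLM checkpoints, ESM-2-650M,
    ProtT5, and/or precomputed FOdb physicochemical features, depending on the benchmark_*
    flags. Returns a dict mapping each embedding output path to metadata
    ({'model_type', 'model', 'epoch'}).
    """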
    os.makedirs('embeddings', exist_ok=True)

    # Tag output files by embedding type: per-sequence averages vs. full 2D (length x dim) arrays
    emb_type_tag = 'average' if average else '2D'

    # Maps each embedding output path to metadata about the model that produced it
    all_embedding_paths = dict()

    if benchmark_fusonplm:
        os.makedirs('embeddings/fuson_plm', exist_ok=True)

        log_update("\nMaking FusOn-pLM embeddings")

        if isinstance(fuson_ckpts, dict):
            # Dict format: {model_name: [epochs]}; embed with each listed checkpoint
            for model_name, epoch_list in fuson_ckpts.items():
                os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)
                for epoch in epoch_list:
                    fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
                    if not os.path.exists(fuson_ckpt_path):
                        raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")

                    embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
                    embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
                    os.makedirs(embedding_output_dir, exist_ok=True)

                    model_type = 'fuson_plm'
                    all_embedding_paths[embedding_output_path] = {
                        'model_type': model_type,
                        'model': model_name,
                        'epoch': epoch
                    }

                    log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
                    embed_dataset(input_data_path, embedding_output_path,
                                  seq_col=seq_col, model_type=model_type,
                                  fuson_ckpt_path=fuson_ckpt_path, average=average,
                                  overwrite=overwrite, print_updates=True,
                                  max_length=max_length)

        elif fuson_ckpts == "FusOn-pLM":
            # String format: use the final (best) FusOn-pLM checkpoint, assumed to sit at the repository root
            model_name = "best"
            os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)

            fuson_ckpt_path = "../../.."
            if not os.path.exists(fuson_ckpt_path):
                raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")

            embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
            embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
            os.makedirs(embedding_output_dir, exist_ok=True)

            model_type = 'fuson_plm'
            all_embedding_paths[embedding_output_path] = {
                'model_type': model_type,
                'model': model_name,
                'epoch': None
            }

            log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
            embed_dataset(input_data_path, embedding_output_path,
                          seq_col=seq_col, model_type=model_type,
                          fuson_ckpt_path=fuson_ckpt_path, average=average,
                          overwrite=overwrite, print_updates=True,
                          max_length=max_length)

        else:
            raise Exception('Error. fuson_ckpts should be a dict or the string "FusOn-pLM"')

    if benchmark_esm:
        os.makedirs('embeddings/esm2_t33_650M_UR50D', exist_ok=True)

        embedding_output_path = f'embeddings/esm2_t33_650M_UR50D/{input_fname}_{emb_type_tag}_embeddings.pkl'

        model_type = 'esm2_t33_650M_UR50D'
        all_embedding_paths[embedding_output_path] = {
            'model_type': model_type,
            'model': model_type,
            'epoch': np.nan
        }

        log_update(f"\nMaking ESM-2-650M embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      seq_col=seq_col, model_type=model_type,
                      fuson_ckpt_path=None, average=average,
                      overwrite=overwrite, print_updates=True,
                      max_length=max_length)

    if benchmark_prott5:
        os.makedirs('embeddings/prot_t5_xl_half_uniref50_enc', exist_ok=True)

        embedding_output_path = f'embeddings/prot_t5_xl_half_uniref50_enc/{input_fname}_{emb_type_tag}_embeddings.pkl'

        model_type = 'prot_t5_xl_half_uniref50_enc'
        all_embedding_paths[embedding_output_path] = {
            'model_type': model_type,
            'model': model_type,
            'epoch': np.nan
        }

        log_update(f"\nMaking ProtT5-XL-UniRef50 embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      seq_col=seq_col, model_type=model_type,
                      fuson_ckpt_path=None, average=average,
                      overwrite=overwrite, print_updates=True,
                      max_length=max_length)

    if benchmark_fo_puncta_ml:
        # Precomputed FOdb physicochemical features; no embedding step needs to run here
        embedding_output_path = 'FOdb_physicochemical_embeddings.pkl'
        all_embedding_paths[embedding_output_path] = {
            'model_type': 'fo_puncta_ml',
            'model': 'fo_puncta_ml',
            'epoch': np.nan
        }

    return all_embedding_paths
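

# Example usage: a minimal sketch of how this module's entry point might be called from a
# benchmark directory. The checkpoint names, epochs, file paths, and column name below are
# hypothetical; substitute the real ones for your benchmark.
#
#     all_embedding_paths = embed_dataset_for_benchmark(
#         fuson_ckpts={"my_model_name": [4, 9]},      # hypothetical {model_name: [epochs]}
#         input_data_path="splits/splits.csv",        # hypothetical input CSV
#         input_fname="splits",
#         average=True,
#         seq_col="aa_seq",
#         benchmark_fusonplm=True,
#         benchmark_esm=True,
#     )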