# Python file for making embeddings from a FusOn-pLM model (or baseline pLMs) for any dataset
from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.utils.data_cleaning import find_invalid_chars
from fuson_plm.utils.constants import VALID_AAS
from fuson_plm.training.model import FusOnpLM
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import logging
import torch
import pickle
import os
import pandas as pd
import numpy as np
def validate_sequence_col(df, seq_col):
# if column isn't there, error
    if seq_col not in df.columns:
raise Exception("Error: provided sequence column does not exist in the input dataframe")
# if column contains invalid characters, error
    all_invalid_chars = set().union(*df[seq_col].apply(lambda x: find_invalid_chars(x, VALID_AAS)))
if len(all_invalid_chars)>0:
raise Exception(f"Error: invalid characters {all_invalid_chars} found in the sequence column")
# make sure there are no duplicates
sequences = df[seq_col]
if len(set(sequences))<len(sequences): log_update("\tWARNING: input data has duplicate sequences")
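# Example usage (illustrative sketch; the dataframe below is a stand-in, not project data):
#   demo_df = pd.DataFrame({'aa_seq': ['MKTAYIA', 'MKTAYIA', 'GLVPRGS']})
#   validate_sequence_col(demo_df, 'aa_seq')  # logs a duplicate-sequence warning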
def load_fuson_model(ckpt_path):
# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load model
model = AutoModel.from_pretrained(ckpt_path) # initialize model
tokenizer = AutoTokenizer.from_pretrained(ckpt_path) # initialize tokenizer
# Model to device and in eval mode
model.to(device)
model.eval() # disables dropout for deterministic results
return model, tokenizer, device
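# Example usage (illustrative sketch; the checkpoint path is hypothetical):
#   model, tokenizer, device = load_fuson_model("../../training/checkpoints/best_model/checkpoint_epoch_1")
#   print(model.config.hidden_size)  # embedding dimension of the loaded encoder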
def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000):
# Correct save path to pickle if necessary
if savepath is not None:
        if not savepath.endswith('.pkl'): savepath += '.pkl'
if print_updates: log_update(f"Dataset contains {len(sequences)} sequences.")
    # If no max length was passed, default to the longest sequence in the dataset (+2 for BOS and EOS tokens)
    if max_length is None: max_length = max(len(s) for s in sequences) + 2
    # Initialize an empty dict to store the FusOn-pLM embeddings
embedding_dict = {}
    # Iterate through the sequences
    for i, sequence in enumerate(sequences):
# Get the embeddings
with torch.no_grad():
# Tokenize the input sequence
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
# The embeddings are in the last_hidden_state tensor
embedding = outputs.last_hidden_state
# remove extra dimension
embedding = embedding.squeeze(0)
# remove BOS and EOS tokens
embedding = embedding[1:-1, :]
            # Move the embedding to CPU and convert it to a numpy array
embedding = embedding.cpu().numpy()
# Average (if necessary)
if average:
embedding = embedding.mean(0)
# Add to dictionary
embedding_dict[sequence] = embedding
# Save individual embedding (if necessary)
        if savepath is not None and not save_at_end:
with open(savepath, 'ab+') as f:
d = {sequence: embedding}
pickle.dump(d, f)
# Print update (if necessary)
if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")
# Dump all at once at the end (if necessary)
    if savepath is not None:
        # If saving everything at the end, dump the full dictionary in one write
        if save_at_end:
with open(savepath, 'wb') as f:
pickle.dump(embedding_dict, f)
# If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
else:
redump_pickle_dictionary(savepath)
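# Example usage (illustrative sketch; the checkpoint path, sequence, and save path are hypothetical):
#   model, tokenizer, device = load_fuson_model("path/to/fuson_checkpoint")
#   get_fuson_embeddings(model, tokenizer, ["MKTAYIAKQR"], device,
#                        average=True, savepath="demo_embeddings.pkl", print_updates=True)
#   with open("demo_embeddings.pkl", "rb") as f:
#       emb_dict = pickle.load(f)  # {sequence: mean-pooled embedding vector}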
def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path=None, average=True, overwrite=True, print_updates=False, max_length=2000):
    # Check for pre-existing embeddings; overwrite or skip based on the overwrite flag
    if os.path.exists(path_to_output):
        if overwrite:
            log_update(f"WARNING: embeddings already exist at {path_to_output} and will be overwritten")
        else:
            log_update(f"WARNING: embeddings already exist at {path_to_output}. Skipping.")
            return None
dataset = pd.read_csv(path_to_file)
# Make sure the sequence column is valid
validate_sequence_col(dataset, seq_col)
sequences = dataset[seq_col].unique().tolist() # ensure all entries are unique
### If FusOn-pLM: make fusion embeddings
if model_type=='fuson_plm':
        if not os.path.exists(fuson_ckpt_path): raise Exception("FusOn-pLM ckpt path does not exist")
# Load model
        try:
            model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
        except Exception as e:
            raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}") from e
        # Generate embeddings
        try:
            get_fuson_embeddings(model, tokenizer, sequences, device, average=average,
                                 print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                 max_length=max_length)
        except Exception as e:
            raise Exception("Could not generate FusOn-pLM embeddings") from e
    elif model_type == 'esm2_t33_650M_UR50D':
# Load model
        try:
            model, tokenizer, device = load_esm2_type(model_type)
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
# Generate embeddings
        try:
            get_esm_embeddings(model, tokenizer, sequences, device, average=average,
                               print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                               max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e
if model_type=="prot_t5_xl_half_uniref50_enc":
# Load model
        try:
            model, tokenizer, device = load_prott5()
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
# Generate embeddings
        try:
            get_prott5_embeddings(model, tokenizer, sequences, device, average=average,
                                  print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                  max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e
def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False, max_length=None):
    # make a directory for embeddings inside the benchmarking dataset folder if one doesn't already exist
os.makedirs('embeddings',exist_ok=True)
    # Tag the embedding type: sequence-averaged or full 2D (per-residue)
    emb_type_tag = 'average' if average else '2D'
    all_embedding_paths = dict() # dictionary mapping each embedding path to its model type, model name, and epoch
# make the embedding files. Put them in an embedding directory
if benchmark_fusonplm:
os.makedirs('embeddings/fuson_plm',exist_ok=True)
log_update(f"\nMaking Fuson-PLM embeddings")
# make subdirs for all the
if type(fuson_ckpts)==dict:
for model_name, epoch_list in fuson_ckpts.items():
os.makedirs(f'embeddings/fuson_plm/{model_name}',exist_ok=True)
for epoch in epoch_list:
# Assemble ckpt path and throw error if it doesn't exist
fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
                    if not os.path.exists(fuson_ckpt_path): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
# Make output directory and output embedding path
embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
os.makedirs(embedding_output_dir,exist_ok=True)
# Make dictionary item
model_type = 'fuson_plm'
all_embedding_paths[embedding_output_path] = {
'model_type': model_type,
'model': model_name,
'epoch': epoch
}
# Create embeddings (or skip if they're already made)
log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
embed_dataset(input_data_path, embedding_output_path,
seq_col=seq_col, model_type=model_type,
fuson_ckpt_path=fuson_ckpt_path, average=average,
overwrite=overwrite,print_updates=True,
max_length=max_length)
        elif fuson_ckpts == "FusOn-pLM":
model_name = "best"
os.makedirs(f'embeddings/fuson_plm/{model_name}',exist_ok=True)
# Assemble ckpt path and throw error if it doesn't exist
fuson_ckpt_path = "../../.." # go back to the FusOn-pLM directory to find the best ckpt
            if not os.path.exists(fuson_ckpt_path): raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")
# Make output directory and output embedding path
embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
os.makedirs(embedding_output_dir,exist_ok=True)
# Make dictionary item
model_type = 'fuson_plm'
all_embedding_paths[embedding_output_path] = {
'model_type': model_type,
'model': model_name,
'epoch': None
}
# Create embeddings (or skip if they're already made)
log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
embed_dataset(input_data_path, embedding_output_path,
seq_col=seq_col, model_type=model_type,
fuson_ckpt_path=fuson_ckpt_path, average=average,
overwrite=overwrite,print_updates=True,
max_length=max_length)
        else:
            raise Exception(f"Error. fuson_ckpts should be a dict or the string 'FusOn-pLM', got {type(fuson_ckpts)}")
    # make the ESM-2 embedding files. Put them in an embedding directory
if benchmark_esm:
os.makedirs('embeddings/esm2_t33_650M_UR50D',exist_ok=True)
# make output path
embedding_output_path = f'embeddings/esm2_t33_650M_UR50D/{input_fname}_{emb_type_tag}_embeddings.pkl'
        # Make dictionary item
model_type = 'esm2_t33_650M_UR50D'
all_embedding_paths[embedding_output_path] = {
'model_type': model_type,
'model': model_type,
'epoch': np.nan
}
log_update(f"\nMaking ESM-2-650M embeddings for {input_data_path} and saving results to {embedding_output_path}...")
embed_dataset(input_data_path, embedding_output_path,
seq_col=seq_col, model_type=model_type,
fuson_ckpt_path = None, average=average,
overwrite=overwrite,print_updates=True,
max_length=max_length)
if benchmark_prott5:
os.makedirs('embeddings/prot_t5_xl_half_uniref50_enc',exist_ok=True)
# make output path
embedding_output_path = f'embeddings/prot_t5_xl_half_uniref50_enc/{input_fname}_{emb_type_tag}_embeddings.pkl'
        # Make dictionary item
model_type = 'prot_t5_xl_half_uniref50_enc'
all_embedding_paths[embedding_output_path] = {
'model_type': model_type,
'model': model_type,
'epoch': np.nan
}
log_update(f"\nMaking ProtT5-XL-UniRef50 embeddings for {input_data_path} and saving results to {embedding_output_path}...")
embed_dataset(input_data_path, embedding_output_path,
seq_col=seq_col, model_type=model_type,
fuson_ckpt_path = None, average=average,
overwrite=overwrite,print_updates=True,
max_length=max_length)
if benchmark_fo_puncta_ml:
        embedding_output_path = 'FOdb_physicochemical_embeddings.pkl'
# Make dictionary item
all_embedding_paths[embedding_output_path] = {
'model_type': 'fo_puncta_ml',
'model': 'fo_puncta_ml',
'epoch': np.nan
}
return all_embedding_paths
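# Example usage (illustrative sketch; model names, epochs, and file paths are hypothetical):
#   all_embedding_paths = embed_dataset_for_benchmark(
#       fuson_ckpts={'my_model': [5, 10]},
#       input_data_path='splits/test_df.csv', input_fname='test_df',
#       average=True, seq_col='aa_seq',
#       benchmark_fusonplm=True, benchmark_esm=True, max_length=2000)
#   # -> {embedding_path: {'model_type': ..., 'model': ..., 'epoch': ...}, ...}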