from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import os
from typing import Optional, Dict, Sequence
import transformers
from peft import PeftModel
import torch
from dataclasses import dataclass, field
from huggingface_hub import hf_hub_download
import json
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import spaces
from llama_customized_models import LlamaForCausalLMWithNumericalEmbedding
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from torch.utils.data.dataloader import DataLoader
from torch.nn import functional as F
import importlib
from rdkit import RDLogger, Chem
# Suppress RDKit INFO messages
RDLogger.DisableLog('rdApp.*')
DEFAULT_PAD_TOKEN = "[PAD]"
device_map = "cuda"
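# Normalization statistics for the supported property conditions; presumably the per-property
# means/stds of the fine-tuning (GuacaMol) data, used to standardize user-supplied targets
# before they are passed to the model.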
means = {"qed": 0.5559003125710424, "logp": 3.497542110420217, "sas": 2.889429694406497, "tpsa": 80.19717097706841}
stds = {"qed": 0.21339854620824716, "logp": 1.7923582437824368, "sas": 0.8081188219568571, "tpsa": 38.212259443049554}
def phrase_df(df):
    """Build a results table from the generated molecules: for each SMILES, record the
    conditioned property values and the properties measured on the generated molecule."""
    metric_calculator = importlib.import_module("metric_calculator")
new_df = []
# iterate over the dataframe
for i in range(len(df)):
sub_df = dict()
# get the SMILES
smiles = df.iloc[i]['SMILES']
# get the property names
property_names = df.iloc[i]['property_names']
# get the non normalized properties
non_normalized_properties = df.iloc[i]['non_normalized_properties']
sub_df['SMILES'] = smiles
        # record the conditioned value and the measured value for each requested property
for j in range(len(property_names)):
# get the property name
property_name = property_names[j]
# get the non normalized property
non_normalized_property = non_normalized_properties[j]
sub_df[f'{property_name}_condition'] = non_normalized_property
if smiles == "":
sub_df[f'{property_name}_measured'] = np.nan
else:
property_eval_func_name = f"compute_{property_name}"
property_eval_func = getattr(metric_calculator, property_eval_func_name)
sub_df[f'{property_name}_measured'] = property_eval_func(Chem.MolFromSmiles(smiles))
new_df.append(sub_df)
new_df = pd.DataFrame(new_df)
return new_df
@dataclass
class DataCollatorForCausalLMEval(object):
tokenizer: transformers.PreTrainedTokenizer
source_max_len: int
target_max_len: int
molecule_target_aug_prob: float
molecule_start_str: str
scaffold_aug_prob: float
scaffold_start_str: str
property_start_str: str
property_inner_sep: str
property_inter_sep: str
end_str: str
ignore_index: int
has_scaffold: bool
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# Extract elements
        # map each supported property to its special token in the fine-tuned vocabulary
        prop_token_map = {
'qed': '<qed>',
'logp': '<logp>',
'sas': '<SAS>',
'tpsa': '<TPSA>'
}
sources = []
props_list = []
non_normalized_props_list = []
prop_names_list = []
props_index_list = []
temperature_list = []
scaffold_list = []
for example in instances:
prop_names = example['property_name']
prop_values = example['property_value']
non_normalized_prop_values = example['non_normalized_property_value']
temperature = example['temperature']
            # build the conditioning prompt from the requested properties (and the scaffold, if enabled)
props_str = ""
scaffold_str = ""
props = []
            non_normalized_props = []
props_index = []
if self.has_scaffold:
scaffold = example['scaffold_smiles'].strip()
scaffold_str = f"{self.scaffold_start_str}{scaffold}{self.end_str}"
props_str = f"{self.property_start_str}"
for i, prop in enumerate(prop_names):
prop = prop.lower()
props_str += f"{prop_token_map[prop]}{self.property_inner_sep}{self.molecule_start_str}{self.property_inter_sep}"
props.append(prop_values[i])
                non_normalized_props.append(non_normalized_prop_values[i])
                props_index.append(3 + 4 * i)  # position of this property's numerical placeholder in the tokenized prompt (hard-coded for the current template)
props_str += f"{self.end_str}"
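            # assumed layout of the full prompt built on the next line:
            #   <property tokens + numeric placeholders><end> [<scaffold><end>] "<->>" <molecule start>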
source = props_str + scaffold_str + "<->>" + self.molecule_start_str
sources.append(source)
props_list.append(props)
            non_normalized_props_list.append(non_normalized_props)
props_index_list.append(props_index)
prop_names_list.append(prop_names)
temperature_list.append(temperature)
# Tokenize
tokenized_sources_with_prompt = self.tokenizer(
sources,
max_length=self.source_max_len,
truncation=True,
add_special_tokens=False,
)
        # Build the model inputs for causal LM (no labels are needed at inference time)
input_ids = []
for tokenized_source in tokenized_sources_with_prompt['input_ids']:
input_ids.append(torch.tensor(tokenized_source))
# Apply padding
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
data_dict = {
'input_ids': input_ids,
'attention_mask':input_ids.ne(self.tokenizer.pad_token_id),
'properties': props_list,
'non_normalized_properties': non_normalized_props_list,
'property_names': prop_names_list,
'properties_index': props_index_list,
'temperature': temperature_list,
}
return data_dict
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
non_special_tokens = None,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(non_special_tokens)
num_old_tokens = model.get_input_embeddings().weight.shape[0]
num_new_tokens = len(tokenizer) - num_old_tokens
if num_new_tokens == 0:
return
model.resize_token_embeddings(len(tokenizer))
    if num_new_tokens > 0:
        # initialize the newly added token embeddings to the mean of the pre-existing embeddings
        input_embeddings_data = model.get_input_embeddings().weight.data
input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
class MolecularGenerationModel():
def __init__(self):
model_id = "ChemFM/molecular_cond_generation_guacamol"
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
padding_side="right",
use_fast=True,
trust_remote_code=True,
token = os.environ.get("TOKEN")
)
# load model
config = AutoConfig.from_pretrained(
model_id,
device_map=device_map,
trust_remote_code=True,
token = os.environ.get("TOKEN")
)
self.model = LlamaForCausalLMWithNumericalEmbedding.from_pretrained(
model_id,
config=config,
device_map=device_map,
trust_remote_code=True,
token = os.environ.get("TOKEN")
)
        # the fine-tuned tokenizer may differ in size from the pretrained tokenizer, and we also need to add the PAD token
special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
smart_tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=self.tokenizer,
model=self.model
)
self.model.config.pad_token_id = self.tokenizer.pad_token_id
self.model.eval()
string_template_path = hf_hub_download(model_id, filename="string_template.json", token = os.environ.get("TOKEN"))
string_template = json.load(open(string_template_path, 'r'))
molecule_start_str = string_template['MOLECULE_START_STRING']
scaffold_start_str = string_template['SCAFFOLD_MOLECULE_START_STRING']
property_start_str = string_template['PROPERTY_START_STRING']
property_inner_sep = string_template['PROPERTY_INNER_SEP']
property_inter_sep = string_template['PROPERTY_INTER_SEP']
end_str = string_template['END_STRING']
self.data_collator = DataCollatorForCausalLMEval(
tokenizer=self.tokenizer,
source_max_len=512,
target_max_len=512,
molecule_target_aug_prob=1.0,
scaffold_aug_prob=0.0,
molecule_start_str=molecule_start_str,
scaffold_start_str=scaffold_start_str,
property_start_str=property_start_str,
property_inner_sep=property_inner_sep,
property_inter_sep=property_inter_sep,
end_str=end_str,
ignore_index=-100,
has_scaffold=False
)
#@spaces.GPU(duration=20)
def generate(self, loader):
#self.model.to("cuda")
#self.model.eval()
df = []
pbar = tqdm(loader, desc=f"Evaluating...", leave=False)
for it, batch in enumerate(pbar):
sub_df = dict()
batch_size = batch['input_ids'].shape[0]
assert batch_size == 1, "The batch size should be 1"
temperature = batch['temperature'][0]
property_names = batch['property_names'][0]
non_normalized_properties = batch['non_normalized_properties'][0]
num_generations = 1
del batch['temperature']
del batch['property_names']
del batch['non_normalized_properties']
batch['input_ids'] = batch['input_ids'].to(self.model.device)
#batch = {k: v.to(self.model.device) for k, v in batch.items()}
input_length = batch['input_ids'].shape[1]
steps = 1024 - input_length
print(self.model.device, "model_device")
with torch.set_grad_enabled(False):
early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
for k in range(steps):
logits = self.model(**batch)['logits']
logits = logits[:, -1, :] / temperature
probs = F.softmax(logits, dim=-1)
ix = torch.multinomial(probs, num_samples=num_generations)
ix[early_stop_flags] = self.tokenizer.eos_token_id
batch['input_ids'] = torch.cat([batch['input_ids'], ix], dim=-1)
early_stop_flags |= (ix.squeeze() == self.tokenizer.eos_token_id)
if torch.all(early_stop_flags):
break
generations = self.tokenizer.batch_decode(batch['input_ids'][:, input_length:], skip_special_tokens=True)
generations = map(lambda x: x.replace(" ", ""), generations)
predictions = []
            for generation in generations:
                try:
                    # canonicalize the generated SMILES; an invalid generation makes MolFromSmiles
                    # return None, so MolToSmiles raises and an empty string is recorded instead
                    predictions.append(Chem.MolToSmiles(Chem.MolFromSmiles(generation)))
                except Exception:
                    predictions.append("")
sub_df['SMILES'] = predictions[0]
sub_df['property_names'] = property_names
sub_df['property'] = batch['properties'][0]
sub_df['non_normalized_properties'] = non_normalized_properties
df.append(sub_df)
df = pd.DataFrame(df)
return df
def predict_single_smiles(self, input_dict: Dict):
        # convert the keys to lower case
input_dict = {key.lower(): value for key, value in input_dict.items()}
properties = [key.lower() for key in input_dict.keys()]
property_means = [means[prop] for prop in properties]
property_stds = [stds[prop] for prop in properties]
sample_point = [input_dict[prop] for prop in properties]
non_normalized_sample_point = np.array(sample_point).reshape(-1)
sample_point = (np.array(sample_point) - np.array(property_means)) / np.array(property_stds)
sub_df = {
"property_name": properties,
"property_value": sample_point.tolist(),
"temperature": 1.0,
"non_normalized_property_value": non_normalized_sample_point.tolist()
}
test_dataset = [sub_df] * 10
test_dataset = pd.DataFrame(test_dataset)
test_dataset = Dataset.from_pandas(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=self.data_collator)
df = self.generate(test_loader)
new_df = phrase_df(df)
# delete the condition columns
new_df = new_df.drop(columns=[col for col in new_df.columns if "condition" in col])
        # drop rows where generation failed to produce a valid SMILES
        new_df = new_df[new_df["SMILES"] != ""]
        # round the measured values to 2 decimal places
        new_df = new_df.round(2)
return new_df
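

# Minimal usage sketch, assuming the ChemFM checkpoint is reachable (with an HF token in the
# TOKEN environment variable if the repo is gated) and the companion metric_calculator module
# is importable; the property targets below are arbitrary example values.
if __name__ == "__main__":
    model = MolecularGenerationModel()
    # request molecules conditioned on example (non-normalized) property targets
    results = model.predict_single_smiles({"qed": 0.6, "logp": 3.0})
    print(results)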