# ChemFM conditional molecular generation Space (Hugging Face Spaces, running on ZeroGPU)
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import os
from typing import Optional, Dict, Sequence
import transformers
from peft import PeftModel
import torch
from dataclasses import dataclass, field
from huggingface_hub import hf_hub_download
import json
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import spaces
from llama_customized_models import LlamaForCausalLMWithNumericalEmbedding
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from torch.utils.data.dataloader import DataLoader
from torch.nn import functional as F
import importlib
from rdkit import RDLogger, Chem

# Suppress RDKit INFO messages
RDLogger.DisableLog('rdApp.*')
DEFAULT_PAD_TOKEN = "[PAD]"
device_map = "cuda"

# Property means/stds used to z-score the conditioning values (see predict_single_smiles)
means = {"qed": 0.5559003125710424, "logp": 3.497542110420217, "sas": 2.889429694406497, "tpsa": 80.19717097706841}
stds = {"qed": 0.21339854620824716, "logp": 1.7923582437824368, "sas": 0.8081188219568571, "tpsa": 38.212259443049554}
def phrase_df(df):
    """Build a dataframe with the conditioned (requested) and measured property values for each generated SMILES."""
    metric_calculator = importlib.import_module("metric_calculator")
    new_df = []
    # iterate over the dataframe
    for i in range(len(df)):
        sub_df = dict()
        # get the SMILES
        smiles = df.iloc[i]['SMILES']
        # get the property names
        property_names = df.iloc[i]['property_names']
        # get the non-normalized properties
        non_normalized_properties = df.iloc[i]['non_normalized_properties']
        sub_df['SMILES'] = smiles

        # compute each measured property of the generated SMILES
        for j in range(len(property_names)):
            # get the property name
            property_name = property_names[j]
            # get the non-normalized (requested) property value
            non_normalized_property = non_normalized_properties[j]
            sub_df[f'{property_name}_condition'] = non_normalized_property
            if smiles == "":
                sub_df[f'{property_name}_measured'] = np.nan
            else:
                property_eval_func_name = f"compute_{property_name}"
                property_eval_func = getattr(metric_calculator, property_eval_func_name)
                sub_df[f'{property_name}_measured'] = property_eval_func(Chem.MolFromSmiles(smiles))
        new_df.append(sub_df)

    new_df = pd.DataFrame(new_df)
    return new_df
@dataclass
class DataCollatorForCausalLMEval(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    molecule_target_aug_prob: float
    molecule_start_str: str
    scaffold_aug_prob: float
    scaffold_start_str: str
    property_start_str: str
    property_inner_sep: str
    property_inter_sep: str
    end_str: str
    ignore_index: int
    has_scaffold: bool
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Extract elements
        prop_token_map = {
            'qed': '<qed>',
            'logp': '<logp>',
            'sas': '<SAS>',
            'tpsa': '<TPSA>'
        }
        sources = []
        props_list = []
        non_normalized_props_list = []
        prop_names_list = []
        props_index_list = []
        temperature_list = []
        scaffold_list = []

        for example in instances:
            prop_names = example['property_name']
            prop_values = example['property_value']
            non_normalized_prop_values = example['non_normalized_property_value']
            temperature = example['temperature']

            # build the conditioning prompt (property block, optional scaffold) for this example
            props_str = ""
            scaffold_str = ""
            props = []
            non_normalized_props = []
            props_index = []

            if self.has_scaffold:
                scaffold = example['scaffold_smiles'].strip()
                scaffold_str = f"{self.scaffold_start_str}{scaffold}{self.end_str}"

            props_str = f"{self.property_start_str}"
            for i, prop in enumerate(prop_names):
                prop = prop.lower()
                props_str += f"{prop_token_map[prop]}{self.property_inner_sep}{self.molecule_start_str}{self.property_inter_sep}"
                props.append(prop_values[i])
                non_normalized_props.append(non_normalized_prop_values[i])
                props_index.append(3 + 4 * i)  # hard-coded offset for the current prompt template
            props_str += f"{self.end_str}"

            source = props_str + scaffold_str + "<->>" + self.molecule_start_str

            sources.append(source)
            props_list.append(props)
            non_normalized_props_list.append(non_normalized_props)
            props_index_list.append(props_index)
            prop_names_list.append(prop_names)
            temperature_list.append(temperature)

        # Tokenize
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )

        # Build the input ids for causal LM generation
        input_ids = []
        for tokenized_source in tokenized_sources_with_prompt['input_ids']:
            input_ids.append(torch.tensor(tokenized_source))

        # Apply padding
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
            'properties': props_list,
            'non_normalized_properties': non_normalized_props_list,
            'property_names': prop_names_list,
            'properties_index': props_index_list,
            'temperature': temperature_list,
        }
        return data_dict
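# Schematic layout of the prompt built by the collator above (the concrete marker strings come
# from string_template.json, so the names below are placeholders rather than the actual tokens):
#   {property_start}<qed>{inner_sep}{molecule_start}{inter_sep}...{end}{scaffold?}<->>{molecule_start}
# props_index (3 + 4*i) appears to record the token position of each property's numeric
# placeholder, presumably where LlamaForCausalLMWithNumericalEmbedding injects the property
# value embedding; this is an inference from the model name, not stated in the original code.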
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens=None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens(non_special_tokens)
    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        # initialize the new token embeddings with the mean of the existing embeddings
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg

    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
class MolecularGenerationModel():
    def __init__(self):
        model_id = "ChemFM/molecular_cond_generation_guacamol"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            padding_side="right",
            use_fast=True,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        # load model
        config = AutoConfig.from_pretrained(
            model_id,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )
        self.model = LlamaForCausalLMWithNumericalEmbedding.from_pretrained(
            model_id,
            config=config,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        # The fine-tuned tokenizer may differ in size from the pretraining tokenizer,
        # and we also need to add the PAD token.
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.model
        )
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.eval()

        # load the prompt template strings shipped with the checkpoint
        string_template_path = hf_hub_download(model_id, filename="string_template.json", token=os.environ.get("TOKEN"))
        string_template = json.load(open(string_template_path, 'r'))
        molecule_start_str = string_template['MOLECULE_START_STRING']
        scaffold_start_str = string_template['SCAFFOLD_MOLECULE_START_STRING']
        property_start_str = string_template['PROPERTY_START_STRING']
        property_inner_sep = string_template['PROPERTY_INNER_SEP']
        property_inter_sep = string_template['PROPERTY_INTER_SEP']
        end_str = string_template['END_STRING']

        self.data_collator = DataCollatorForCausalLMEval(
            tokenizer=self.tokenizer,
            source_max_len=512,
            target_max_len=512,
            molecule_target_aug_prob=1.0,
            scaffold_aug_prob=0.0,
            molecule_start_str=molecule_start_str,
            scaffold_start_str=scaffold_start_str,
            property_start_str=property_start_str,
            property_inner_sep=property_inner_sep,
            property_inter_sep=property_inter_sep,
            end_str=end_str,
            ignore_index=-100,
            has_scaffold=False
        )
    def generate(self, loader):
        df = []
        pbar = tqdm(loader, desc="Evaluating...", leave=False)

        for it, batch in enumerate(pbar):
            sub_df = dict()
            batch_size = batch['input_ids'].shape[0]
            assert batch_size == 1, "The batch size should be 1"
            temperature = batch['temperature'][0]
            property_names = batch['property_names'][0]
            non_normalized_properties = batch['non_normalized_properties'][0]
            num_generations = 1

            del batch['temperature']
            del batch['property_names']
            del batch['non_normalized_properties']

            input_length = batch['input_ids'].shape[1]
            steps = 1024 - input_length

            with torch.set_grad_enabled(False):
                early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
                for k in range(steps):
                    # autoregressive sampling: re-run the forward pass on the growing sequence
                    logits = self.model(**batch)['logits']
                    logits = logits[:, -1, :] / temperature
                    probs = F.softmax(logits, dim=-1)
                    ix = torch.multinomial(probs, num_samples=num_generations)
                    ix[early_stop_flags] = self.tokenizer.eos_token_id
                    batch['input_ids'] = torch.cat([batch['input_ids'], ix], dim=-1)
                    early_stop_flags |= (ix.squeeze() == self.tokenizer.eos_token_id)

                    if torch.all(early_stop_flags):
                        break

            generations = self.tokenizer.batch_decode(batch['input_ids'][:, input_length:], skip_special_tokens=True)
            generations = map(lambda x: x.replace(" ", ""), generations)

            predictions = []
            for generation in generations:
                try:
                    # canonicalize; invalid SMILES fall back to an empty string
                    predictions.append(Chem.MolToSmiles(Chem.MolFromSmiles(generation)))
                except Exception:
                    predictions.append("")

            sub_df['SMILES'] = predictions[0]
            sub_df['property_names'] = property_names
            sub_df['property'] = batch['properties'][0]
            sub_df['non_normalized_properties'] = non_normalized_properties
            df.append(sub_df)

        df = pd.DataFrame(df)
        return df
    def predict_single_smiles(self, input_dict: Dict):
        # convert the keys to lower case
        input_dict = {key.lower(): value for key, value in input_dict.items()}
        properties = [key.lower() for key in input_dict.keys()]

        property_means = [means[prop] for prop in properties]
        property_stds = [stds[prop] for prop in properties]

        sample_point = [input_dict[prop] for prop in properties]
        non_normalized_sample_point = np.array(sample_point).reshape(-1)
        # z-score the requested property values with the dataset statistics
        sample_point = (np.array(sample_point) - np.array(property_means)) / np.array(property_stds)

        sub_df = {
            "property_name": properties,
            "property_value": sample_point.tolist(),
            "temperature": 1.0,
            "non_normalized_property_value": non_normalized_sample_point.tolist()
        }
        # generate 10 candidate molecules for the same property conditions
        test_dataset = [sub_df] * 10
        test_dataset = pd.DataFrame(test_dataset)
        test_dataset = Dataset.from_pandas(test_dataset)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=self.data_collator)

        df = self.generate(test_loader)
        new_df = phrase_df(df)

        # delete the condition columns
        new_df = new_df.drop(columns=[col for col in new_df.columns if "condition" in col])
        # drop rows whose generation was not a valid SMILES (stored as an empty string)
        new_df = new_df[new_df['SMILES'] != ""]
        # round the measured values to 2 decimal places
        new_df = new_df.round(2)

        return new_df
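

# Minimal usage sketch (not part of the original file): instantiate the model and request
# molecules with target QED and logP values. The property names must match the keys of
# `means`/`stds` above; the numeric targets here are arbitrary illustrative choices.
# Running this requires the ChemFM checkpoint (TOKEN environment variable, if the repo is
# gated) and a CUDA device, per the loading code above.
if __name__ == "__main__":
    generator = MolecularGenerationModel()
    results = generator.predict_single_smiles({"QED": 0.8, "LogP": 2.5})
    # `results` is a dataframe with the generated SMILES and their measured property values
    print(results.to_string(index=False))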