File size: 3,174 Bytes
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
688df3c
 
 
 
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb79ce7
7dd9869
9c86aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dd9869
9c86aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dd9869
9c86aa3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import argparse
import selfies as sf
from tqdm import tqdm
from transformers import T5EncoderModel
from transformers import set_seed
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion import dist_util, logger
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
)
from src.scripts.mydatasets import Lang2molDataset_submission
import streamlit as st
import os


@st.cache_resource
def get_encoder():
    """Load the pretrained text encoder and return it in evaluation mode.

    The weights are fetched by name through ``T5EncoderModel.from_pretrained``.
    The function is wrapped in ``st.cache_resource``.
    """
    checkpoint_name = "QizhiPei/biot5-base-text2mol"
    text_encoder = T5EncoderModel.from_pretrained(checkpoint_name)
    # ``eval()`` hands back the module itself, so it can be returned directly.
    return text_encoder.eval()


@st.cache_resource
def get_tokenizer():
    """Create the project ``Tokenizer`` using its default constructor arguments.

    The function is wrapped in ``st.cache_resource``.
    """
    text_tokenizer = Tokenizer()
    return text_tokenizer


@st.cache_resource
def get_model():
    """Build the ``TransformerNetModel`` and load its checkpoint on the CPU.

    The network is constructed with fixed hyperparameters, populated from
    ``checkpoints/PLAIN_ema_0.9999_360000.pt`` and put into evaluation mode
    before being returned. The function is wrapped in ``st.cache_resource``.
    """
    hyperparameters = {
        "in_channels": 32,
        "model_channels": 128,
        "dropout": 0.1,
        "vocab_size": 35073,
        "hidden_size": 1024,
        "num_attention_heads": 16,
        "num_hidden_layers": 12,
    }
    network = TransformerNetModel(**hyperparameters)

    # The network is built before the checkpoint is read, as in the original
    # statement order.
    checkpoint_path = os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt")
    weights = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    network.load_state_dict(weights)

    network.eval()
    return network


@st.cache_resource
def get_diffusion():
    """Create the ``SpacedDiffusion`` object used for sampling.

    The beta schedule is the named ``"sqrt"`` schedule over 2000 steps, and
    every 10th step index (0, 10, ..., 1990) is passed as ``use_timesteps``.
    The function is wrapped in ``st.cache_resource``.
    """
    total_steps = 2000
    stride = 10
    kept_steps = list(range(0, total_steps, stride))
    beta_schedule = gd.get_named_beta_schedule("sqrt", total_steps)

    return SpacedDiffusion(
        use_timesteps=kept_steps,
        betas=beta_schedule,
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )


# Shared resources for the page; each getter is decorated with
# ``st.cache_resource`` where it is defined.
tokenizer = get_tokenizer()
encoder = get_encoder()
model = get_model()
diffusion = get_diffusion()

# --- Page layout -----------------------------------------------------------
st.title("Lang2mol-Diff")
text_input = st.text_area("Enter molecule description")
button = st.button("Submit")

if button:
    if not text_input.strip():
        # A blank description would still run the whole sampling loop, so
        # ask for input instead of spending the time on it.
        st.warning("Please enter a molecule description first.")
    else:
        # This is inference only: disabling autograd keeps the encoder
        # forward pass and the logit projection from building a gradient
        # graph that is never used.
        with st.spinner("Please wait..."), torch.no_grad():
            # Tokenize the description, padded/truncated to 256 tokens.
            output = tokenizer(
                text_input,
                max_length=256,
                truncation=True,
                padding="max_length",
                add_special_tokens=True,
                return_tensors="pt",
                return_attention_mask=True,
            )
            # Encode the caption; the hidden states and the attention mask
            # are handed to the sampler together as ``caption``.
            caption_state = encoder(
                input_ids=output["input_ids"],
                attention_mask=output["attention_mask"],
            ).last_hidden_state
            caption_mask = output["attention_mask"]

            # NOTE(review): (1, 256, 32) is presumably the sample shape
            # (batch, sequence length, channels) -- 256 matches max_length
            # above and 32 matches in_channels in get_model; confirm against
            # p_sample_loop.
            outputs = diffusion.p_sample_loop(
                model,
                (1, 256, 32),
                clip_denoised=False,
                denoised_fn=None,
                model_kwargs={},
                top_p=1.0,
                progress=True,
                caption=(caption_state, caption_mask),
            )

            # Project the sample to vocabulary logits and keep the single
            # best-scoring token id per position.
            logits = model.get_logits(torch.tensor(outputs))
            cands = torch.topk(logits, k=1, dim=-1)
            outputs = cands.indices
            outputs = outputs.squeeze(-1)

            # Decode ids to text, strip special tokens and tabs, then pass
            # the cleaned string through the SELFIES decoder.
            outputs = tokenizer.decode(outputs)
            result = sf.decoder(
                outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
            ).replace("\t", "")

            st.write(result)