Spaces:
Running
Running
Commit
·
23a7a4b
1
Parent(s):
22c5f0f
feat: add device
Browse files
app.py
CHANGED
@@ -1,26 +1,18 @@
|
|
1 |
import torch
|
2 |
-
import argparse
|
3 |
import selfies as sf
|
4 |
-
from tqdm import tqdm
|
5 |
from transformers import T5EncoderModel
|
6 |
-
from transformers import set_seed
|
7 |
from src.scripts.mytokenizers import Tokenizer
|
8 |
from src.improved_diffusion import gaussian_diffusion as gd
|
9 |
-
from src.improved_diffusion import dist_util, logger
|
10 |
from src.improved_diffusion.respace import SpacedDiffusion
|
11 |
from src.improved_diffusion.transformer_model import TransformerNetModel
|
12 |
-
from src.improved_diffusion.script_util import (
|
13 |
-
model_and_diffusion_defaults,
|
14 |
-
add_dict_to_argparser,
|
15 |
-
)
|
16 |
-
from src.scripts.mydatasets import Lang2molDataset_submission
|
17 |
import streamlit as st
|
18 |
import os
|
19 |
|
20 |
|
21 |
@st.cache_resource
|
22 |
-
def get_encoder():
|
23 |
model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
|
|
|
24 |
model.eval()
|
25 |
return model
|
26 |
|
@@ -31,7 +23,7 @@ def get_tokenizer():
|
|
31 |
|
32 |
|
33 |
@st.cache_resource
|
34 |
-
def get_model():
|
35 |
model = TransformerNetModel(
|
36 |
in_channels=32,
|
37 |
model_channels=128,
|
@@ -44,9 +36,10 @@ def get_model():
|
|
44 |
model.load_state_dict(
|
45 |
torch.load(
|
46 |
os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
|
47 |
-
map_location=torch.device(
|
48 |
)
|
49 |
)
|
|
|
50 |
model.eval()
|
51 |
return model
|
52 |
|
@@ -65,9 +58,11 @@ def get_diffusion():
|
|
65 |
)
|
66 |
|
67 |
|
|
|
|
|
68 |
tokenizer = get_tokenizer()
|
69 |
-
encoder = get_encoder()
|
70 |
-
model = get_model()
|
71 |
diffusion = get_diffusion()
|
72 |
|
73 |
st.title("Lang2mol-Diff")
|
@@ -85,8 +80,8 @@ if button:
|
|
85 |
return_attention_mask=True,
|
86 |
)
|
87 |
caption_state = encoder(
|
88 |
-
input_ids=output["input_ids"],
|
89 |
-
attention_mask=output["attention_mask"],
|
90 |
).last_hidden_state
|
91 |
caption_mask = output["attention_mask"]
|
92 |
|
@@ -98,7 +93,7 @@ if button:
|
|
98 |
model_kwargs={},
|
99 |
top_p=1.0,
|
100 |
progress=True,
|
101 |
-
caption=(caption_state, caption_mask),
|
102 |
)
|
103 |
logits = model.get_logits(torch.tensor(outputs))
|
104 |
cands = torch.topk(logits, k=1, dim=-1)
|
|
|
1 |
import torch
|
|
|
2 |
import selfies as sf
|
|
|
3 |
from transformers import T5EncoderModel
|
|
|
4 |
from src.scripts.mytokenizers import Tokenizer
|
5 |
from src.improved_diffusion import gaussian_diffusion as gd
|
|
|
6 |
from src.improved_diffusion.respace import SpacedDiffusion
|
7 |
from src.improved_diffusion.transformer_model import TransformerNetModel
|
|
|
|
|
|
|
|
|
|
|
8 |
import streamlit as st
|
9 |
import os
|
10 |
|
11 |
|
12 |
@st.cache_resource
|
13 |
+
def get_encoder(device):
|
14 |
model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
|
15 |
+
model.to(device)
|
16 |
model.eval()
|
17 |
return model
|
18 |
|
|
|
23 |
|
24 |
|
25 |
@st.cache_resource
|
26 |
+
def get_model(device):
|
27 |
model = TransformerNetModel(
|
28 |
in_channels=32,
|
29 |
model_channels=128,
|
|
|
36 |
model.load_state_dict(
|
37 |
torch.load(
|
38 |
os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
|
39 |
+
map_location=torch.device(device),
|
40 |
)
|
41 |
)
|
42 |
+
model.to(device)
|
43 |
model.eval()
|
44 |
return model
|
45 |
|
|
|
58 |
)
|
59 |
|
60 |
|
61 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
62 |
+
|
63 |
tokenizer = get_tokenizer()
|
64 |
+
encoder = get_encoder(device)
|
65 |
+
model = get_model(device)
|
66 |
diffusion = get_diffusion()
|
67 |
|
68 |
st.title("Lang2mol-Diff")
|
|
|
80 |
return_attention_mask=True,
|
81 |
)
|
82 |
caption_state = encoder(
|
83 |
+
input_ids=output["input_ids"].to(device),
|
84 |
+
attention_mask=output["attention_mask"].to(device),
|
85 |
).last_hidden_state
|
86 |
caption_mask = output["attention_mask"]
|
87 |
|
|
|
93 |
model_kwargs={},
|
94 |
top_p=1.0,
|
95 |
progress=True,
|
96 |
+
caption=(caption_state.to(device), caption_mask.to(device)),
|
97 |
)
|
98 |
logits = model.get_logits(torch.tensor(outputs))
|
99 |
cands = torch.topk(logits, k=1, dim=-1)
|