Spaces:
Running
Running
Commit
·
23a7a4b
1
Parent(s):
22c5f0f
feat: add device
Browse files
app.py
CHANGED
|
@@ -1,26 +1,18 @@
|
|
| 1 |
import torch
|
| 2 |
-
import argparse
|
| 3 |
import selfies as sf
|
| 4 |
-
from tqdm import tqdm
|
| 5 |
from transformers import T5EncoderModel
|
| 6 |
-
from transformers import set_seed
|
| 7 |
from src.scripts.mytokenizers import Tokenizer
|
| 8 |
from src.improved_diffusion import gaussian_diffusion as gd
|
| 9 |
-
from src.improved_diffusion import dist_util, logger
|
| 10 |
from src.improved_diffusion.respace import SpacedDiffusion
|
| 11 |
from src.improved_diffusion.transformer_model import TransformerNetModel
|
| 12 |
-
from src.improved_diffusion.script_util import (
|
| 13 |
-
model_and_diffusion_defaults,
|
| 14 |
-
add_dict_to_argparser,
|
| 15 |
-
)
|
| 16 |
-
from src.scripts.mydatasets import Lang2molDataset_submission
|
| 17 |
import streamlit as st
|
| 18 |
import os
|
| 19 |
|
| 20 |
|
| 21 |
@st.cache_resource
|
| 22 |
-
def get_encoder():
|
| 23 |
model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
|
|
|
|
| 24 |
model.eval()
|
| 25 |
return model
|
| 26 |
|
|
@@ -31,7 +23,7 @@ def get_tokenizer():
|
|
| 31 |
|
| 32 |
|
| 33 |
@st.cache_resource
|
| 34 |
-
def get_model():
|
| 35 |
model = TransformerNetModel(
|
| 36 |
in_channels=32,
|
| 37 |
model_channels=128,
|
|
@@ -44,9 +36,10 @@ def get_model():
|
|
| 44 |
model.load_state_dict(
|
| 45 |
torch.load(
|
| 46 |
os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
|
| 47 |
-
map_location=torch.device(
|
| 48 |
)
|
| 49 |
)
|
|
|
|
| 50 |
model.eval()
|
| 51 |
return model
|
| 52 |
|
|
@@ -65,9 +58,11 @@ def get_diffusion():
|
|
| 65 |
)
|
| 66 |
|
| 67 |
|
|
|
|
|
|
|
| 68 |
tokenizer = get_tokenizer()
|
| 69 |
-
encoder = get_encoder()
|
| 70 |
-
model = get_model()
|
| 71 |
diffusion = get_diffusion()
|
| 72 |
|
| 73 |
st.title("Lang2mol-Diff")
|
|
@@ -85,8 +80,8 @@ if button:
|
|
| 85 |
return_attention_mask=True,
|
| 86 |
)
|
| 87 |
caption_state = encoder(
|
| 88 |
-
input_ids=output["input_ids"],
|
| 89 |
-
attention_mask=output["attention_mask"],
|
| 90 |
).last_hidden_state
|
| 91 |
caption_mask = output["attention_mask"]
|
| 92 |
|
|
@@ -98,7 +93,7 @@ if button:
|
|
| 98 |
model_kwargs={},
|
| 99 |
top_p=1.0,
|
| 100 |
progress=True,
|
| 101 |
-
caption=(caption_state, caption_mask),
|
| 102 |
)
|
| 103 |
logits = model.get_logits(torch.tensor(outputs))
|
| 104 |
cands = torch.topk(logits, k=1, dim=-1)
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
import selfies as sf
|
|
|
|
| 3 |
from transformers import T5EncoderModel
|
|
|
|
| 4 |
from src.scripts.mytokenizers import Tokenizer
|
| 5 |
from src.improved_diffusion import gaussian_diffusion as gd
|
|
|
|
| 6 |
from src.improved_diffusion.respace import SpacedDiffusion
|
| 7 |
from src.improved_diffusion.transformer_model import TransformerNetModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import streamlit as st
|
| 9 |
import os
|
| 10 |
|
| 11 |
|
| 12 |
@st.cache_resource
|
| 13 |
+
def get_encoder(device):
|
| 14 |
model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
|
| 15 |
+
model.to(device)
|
| 16 |
model.eval()
|
| 17 |
return model
|
| 18 |
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@st.cache_resource
|
| 26 |
+
def get_model(device):
|
| 27 |
model = TransformerNetModel(
|
| 28 |
in_channels=32,
|
| 29 |
model_channels=128,
|
|
|
|
| 36 |
model.load_state_dict(
|
| 37 |
torch.load(
|
| 38 |
os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
|
| 39 |
+
map_location=torch.device(device),
|
| 40 |
)
|
| 41 |
)
|
| 42 |
+
model.to(device)
|
| 43 |
model.eval()
|
| 44 |
return model
|
| 45 |
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
|
| 61 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 62 |
+
|
| 63 |
tokenizer = get_tokenizer()
|
| 64 |
+
encoder = get_encoder(device)
|
| 65 |
+
model = get_model(device)
|
| 66 |
diffusion = get_diffusion()
|
| 67 |
|
| 68 |
st.title("Lang2mol-Diff")
|
|
|
|
| 80 |
return_attention_mask=True,
|
| 81 |
)
|
| 82 |
caption_state = encoder(
|
| 83 |
+
input_ids=output["input_ids"].to(device),
|
| 84 |
+
attention_mask=output["attention_mask"].to(device),
|
| 85 |
).last_hidden_state
|
| 86 |
caption_mask = output["attention_mask"]
|
| 87 |
|
|
|
|
| 93 |
model_kwargs={},
|
| 94 |
top_p=1.0,
|
| 95 |
progress=True,
|
| 96 |
+
caption=(caption_state.to(device), caption_mask.to(device)),
|
| 97 |
)
|
| 98 |
logits = model.get_logits(torch.tensor(outputs))
|
| 99 |
cands = torch.topk(logits, k=1, dim=-1)
|