ndhieunguyen commited on
Commit
bb79ce7
·
1 Parent(s): 688df3c

feat: spinner

Browse files
Files changed (1) hide show
  1. app.py +35 -33
app.py CHANGED
@@ -72,39 +72,41 @@ diffusion = get_diffusion()
72
 
73
  sample_fn = diffusion.ddim_sample_loop
74
 
 
75
  text_input = st.text_area("Enter molecule description")
76
- output = tokenizer(
77
- text_input,
78
- max_length=256,
79
- truncation=True,
80
- padding="max_length",
81
- add_special_tokens=True,
82
- return_tensors="pt",
83
- return_attention_mask=True,
84
- )
85
- caption_state = encoder(
86
- input_ids=output["input_ids"],
87
- attention_mask=output["attention_mask"],
88
- ).last_hidden_state
89
- caption_mask = output["attention_mask"]
 
90
 
91
- outputs = sample_fn(
92
- model,
93
- (1, 256, 32),
94
- clip_denoised=False,
95
- denoised_fn=None,
96
- model_kwargs={},
97
- top_p=1.0,
98
- progress=True,
99
- caption=(caption_state, caption_mask),
100
- )
101
- logits = model.get_logits(torch.tensor(outputs))
102
- cands = torch.topk(logits, k=1, dim=-1)
103
- outputs = cands.indices
104
- outputs = outputs.squeeze(-1)
105
- outputs = tokenizer.decode(outputs)
106
- result = sf.decoder(
107
- outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
108
- ).replace("\t", "")
109
 
110
- st.write(result)
 
72
 
73
  sample_fn = diffusion.ddim_sample_loop
74
 
75
+ st.title("Lang2mol-Diff")
76
  text_input = st.text_area("Enter molecule description")
77
+ with st.spinner("Please wait ..."):
78
+ output = tokenizer(
79
+ text_input,
80
+ max_length=256,
81
+ truncation=True,
82
+ padding="max_length",
83
+ add_special_tokens=True,
84
+ return_tensors="pt",
85
+ return_attention_mask=True,
86
+ )
87
+ caption_state = encoder(
88
+ input_ids=output["input_ids"],
89
+ attention_mask=output["attention_mask"],
90
+ ).last_hidden_state
91
+ caption_mask = output["attention_mask"]
92
 
93
+ outputs = sample_fn(
94
+ model,
95
+ (1, 256, 32),
96
+ clip_denoised=False,
97
+ denoised_fn=None,
98
+ model_kwargs={},
99
+ top_p=1.0,
100
+ progress=True,
101
+ caption=(caption_state, caption_mask),
102
+ )
103
+ logits = model.get_logits(torch.tensor(outputs))
104
+ cands = torch.topk(logits, k=1, dim=-1)
105
+ outputs = cands.indices
106
+ outputs = outputs.squeeze(-1)
107
+ outputs = tokenizer.decode(outputs)
108
+ result = sf.decoder(
109
+ outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
110
+ ).replace("\t", "")
111
 
112
+ st.write(result)