ndhieunguyen committed on
Commit 9c86aa3 · 1 Parent(s): bb79ce7

feat: progress bar

app.py CHANGED
@@ -70,43 +70,43 @@ encoder = get_encoder()
 model = get_model()
 diffusion = get_diffusion()
 
-sample_fn = diffusion.ddim_sample_loop
-
 st.title("Lang2mol-Diff")
 text_input = st.text_area("Enter molecule description")
-with st.spinner("Please wait ..."):
-    output = tokenizer(
-        text_input,
-        max_length=256,
-        truncation=True,
-        padding="max_length",
-        add_special_tokens=True,
-        return_tensors="pt",
-        return_attention_mask=True,
-    )
-    caption_state = encoder(
-        input_ids=output["input_ids"],
-        attention_mask=output["attention_mask"],
-    ).last_hidden_state
-    caption_mask = output["attention_mask"]
-
-    outputs = sample_fn(
-        model,
-        (1, 256, 32),
-        clip_denoised=False,
-        denoised_fn=None,
-        model_kwargs={},
-        top_p=1.0,
-        progress=True,
-        caption=(caption_state, caption_mask),
-    )
-    logits = model.get_logits(torch.tensor(outputs))
-    cands = torch.topk(logits, k=1, dim=-1)
-    outputs = cands.indices
-    outputs = outputs.squeeze(-1)
-    outputs = tokenizer.decode(outputs)
-    result = sf.decoder(
-        outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
-    ).replace("\t", "")
-
-    st.write(result)
+button = st.button("Submit")
+if button:
+    with st.spinner("Please wait..."):
+        output = tokenizer(
+            text_input,
+            max_length=256,
+            truncation=True,
+            padding="max_length",
+            add_special_tokens=True,
+            return_tensors="pt",
+            return_attention_mask=True,
+        )
+        caption_state = encoder(
+            input_ids=output["input_ids"],
+            attention_mask=output["attention_mask"],
+        ).last_hidden_state
+        caption_mask = output["attention_mask"]
+
+        outputs = diffusion.p_sample_loop(
+            model,
+            (1, 256, 32),
+            clip_denoised=False,
+            denoised_fn=None,
+            model_kwargs={},
+            top_p=1.0,
+            progress=True,
+            caption=(caption_state, caption_mask),
+        )
+        logits = model.get_logits(torch.tensor(outputs))
+        cands = torch.topk(logits, k=1, dim=-1)
+        outputs = cands.indices
+        outputs = outputs.squeeze(-1)
+        outputs = tokenizer.decode(outputs)
+        result = sf.decoder(
+            outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
+        ).replace("\t", "")
+
+        st.write(result)
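
For context on the decoding tail of the new app.py above: the sampled latents are mapped to logits, reduced to token ids with a top-1 pick, decoded to a SELFIES string, and finally converted to SMILES. Below is a minimal, self-contained sketch of that last conversion, assuming sf is the selfies package imported in app.py; the input string in the example is hypothetical, not repo output.

import selfies as sf

def selfies_to_smiles(decoded: str) -> str:
    # Strip generation artifacts (padding / EOS markers and tabs) before decoding.
    cleaned = decoded.replace("<pad>", "").replace("</s>", "").replace("\t", "")
    # selfies.decoder converts a SELFIES string into a SMILES string.
    return sf.decoder(cleaned).replace("\t", "")

# Hypothetical decoded output -> "CCO" (ethanol).
print(selfies_to_smiles("[C][C][O]</s><pad><pad>"))
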
src/improved_diffusion/gaussian_diffusion.py CHANGED
@@ -9,7 +9,7 @@ import enum
 import math
 import torch
 import numpy as np
-
+import streamlit as st
 from .nn import mean_flat
 from .losses import normal_kl, discretized_gaussian_log_likelihood
 
@@ -667,16 +667,19 @@ class GaussianDiffusion:
         # print(indices[-10:])
         if progress:
             # Lazy import so that we don't depend on tqdm.
-            from tqdm.auto import tqdm
+            # from tqdm.auto import tqdm
 
-            indices = tqdm(indices)
+            # indices = tqdm(indices)
+            pass
         if caption is not None:
             print("Text Guiding Generation ......")
             caption = (
                 caption[0].to(img.device),
                 caption[1].to(img.device),
             )  # (caption_state, caption_mask)
-        for i in indices:
+        my_bar = st.progress(0, text="Processing")
+        for pro, i in enumerate(indices):
             t = torch.tensor([i] * shape[0], device=device)
             with torch.no_grad():
                 out = self.p_sample(
@@ -691,6 +694,8 @@
                 )
                 yield out
                 img = out["sample"]
+                my_bar.progress(pro + 1, text="Processing")
+        my_bar.empty()
 
     def p_sample_loop_langevin_progressive(
         self,
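
A note on the progress-bar wiring above: st.progress expects an int in [0, 100] or a float in [0.0, 1.0], while pro + 1 is the raw timestep counter, which exceeds 100 for typical diffusion schedules. Below is a minimal, self-contained sketch of a normalized variant; the 2000-step count and the sleep stand-in are assumptions for illustration, not values taken from this repo.

import time
import streamlit as st

indices = list(range(2000))[::-1]  # assumed num_timesteps; the repo's value may differ
my_bar = st.progress(0.0, text="Processing")
for step, t in enumerate(indices):
    time.sleep(0.001)  # stand-in for the p_sample call in the real loop
    # Report progress as a fraction so the value always stays within [0.0, 1.0].
    my_bar.progress((step + 1) / len(indices), text=f"Processing {step + 1}/{len(indices)}")
my_bar.empty()
st.write("Sampling finished")
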