yasserrmd commited on
Commit
f301955
·
verified ·
1 Parent(s): e4530c3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from modeling_diffusion import DiffusionTextModel
4
+
5
+ # =====================
6
+ # Load Model from Hub
7
+ # =====================
8
+ model = DiffusionTextModel.from_pretrained("yasserrmd/diffusion-text-demo")
9
+ model.eval()
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model.to(device)
13
+
14
+ # Assume vocab, id_to_word, pad_id, mask_id already defined
15
+
16
+ # =====================
17
+ # Generation Function
18
+ # =====================
19
+ def generate_with_prompt(model, input_text, max_length=50, T=10):
20
+ # Ensure max_length does not exceed 99
21
+ max_length = min(max_length, 99)
22
+
23
+ model.eval()
24
+ input_tokens = input_text.split()
25
+ input_ids = [vocab.get(tok, mask_id) for tok in input_tokens]
26
+
27
+ seq = torch.full((1, max_length), mask_id, dtype=torch.long, device=device)
28
+ seq[0, :len(input_ids)] = torch.tensor(input_ids, device=device)
29
+
30
+ for step in range(T, 0, -1):
31
+ with torch.no_grad():
32
+ logits = model(seq, torch.tensor([step], device=device))
33
+ probs = torch.softmax(logits, dim=-1)
34
+ for pos in range(len(input_ids), max_length):
35
+ if seq[0, pos].item() == mask_id:
36
+ seq[0, pos] = torch.multinomial(probs[0, pos], 1)
37
+
38
+ ids = seq[0].tolist()
39
+ if pad_id in ids:
40
+ ids = ids[:ids.index(pad_id)]
41
+ return " ".join(id_to_word[i] for i in ids)
42
+
43
+ # =====================
44
+ # Gradio App
45
+ # =====================
46
+ def chat_fn(message, history, steps, max_len):
47
+ response = generate_with_prompt(model, message, max_length=max_len, T=steps)
48
+ history.append((message, response))
49
+ return "", history
50
+
51
+ with gr.Blocks() as demo:
52
+ gr.Markdown("## 🌀 DiffusionTextModel QA Chat Demo")
53
+ chatbot = gr.Chatbot()
54
+ msg = gr.Textbox(placeholder="Type your question or prompt here...")
55
+ steps = gr.Slider(1, 50, value=10, step=1, label="Diffusion Steps (T)")
56
+ max_len = gr.Slider(10, 99, value=50, step=1, label="Max Token Length (≤ 99)")
57
+ clear = gr.Button("Clear")
58
+
59
+ msg.submit(chat_fn, [msg, chatbot, steps, max_len], [msg, chatbot])
60
+ clear.click(lambda: None, None, chatbot, queue=False)
61
+
62
+ demo.launch()