HyperX-Sen committed
Commit 5c37b73 · verified · 1 Parent(s): ea29aa7

Update app.py

Files changed (1)
  1. app.py +9 -11
app.py CHANGED
@@ -61,21 +61,19 @@ def chat_response(user_input, top_p, top_k, temperature, max_length):
     input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
 
+    full_response = ""
+
     with torch.no_grad():
-        stream = model.generate(
+        for token in model.generate(
             **inputs,
             max_length=max_length,
             do_sample=True,
             top_p=top_p,
             top_k=top_k,
-            temperature=temperature,
-            streamer=True
-        )
-
-    full_response = ""
-    for token in stream:
-        full_response += tokenizer.decode(token, skip_special_tokens=True)
-        yield extract_response(full_response)
+            temperature=temperature
+        ):
+            full_response += tokenizer.decode(token, skip_special_tokens=True)
+            yield gr.Textbox.update(value=extract_response(full_response))
 
 # 🔹 Gradio UI
 with gr.Blocks() as demo:
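A note on the hunk above: transformers' model.generate() has no streamer=True flag (the streamer argument expects a streamer object), and its return value is the finished output tensor, so iterating over it yields one complete sequence per batch element rather than tokens as they are produced. Neither side of this hunk therefore streams token by token. Below is a minimal sketch of the usual streaming pattern with TextIteratorStreamer and a worker thread; tokenizer, model, and extract_response are assumed to be the objects defined earlier in app.py, and the messages construction is only a placeholder since it is not shown in this diff.

from threading import Thread

from transformers import TextIteratorStreamer


def chat_response(user_input, top_p, top_k, temperature, max_length):
    # Placeholder prompt setup; app.py's real message construction is not shown in this diff.
    messages = [{"role": "user", "content": user_input}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # The streamer yields decoded text chunks as soon as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        streamer=streamer,
    )

    # generate() blocks until it finishes, so run it in a background thread
    # while this generator drains the streamer and yields partial text.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for chunk in streamer:
        full_response += chunk
        yield extract_response(full_response)
    thread.join()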
@@ -96,7 +94,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         submit_button = gr.Button("Generate Response")
 
-    submit_button.click(chat_response, inputs=[user_input, top_p, top_k, temperature, max_length], outputs=[chatbot], stream=True)
+    submit_button.click(chat_response, inputs=[user_input, top_p, top_k, temperature, max_length], outputs=[chatbot])
 
 # 🔹 Launch the Gradio app
-    demo.launch()
+demo.launch()
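On the event wiring above: Button.click() has no stream= keyword, which is presumably why it was dropped; Gradio streams on its own whenever the handler is a generator, pushing each yielded value to the output component. Note also that gr.Textbox.update(value=...) is the Gradio 3.x idiom; it was removed in Gradio 4, where a generator simply yields the new value (or gr.update(value=...)). A rough sketch of the wiring under those assumptions, reusing the chat_response generator sketched above; only the variable names come from this diff, while the component types, labels, and ranges are illustrative.

import gradio as gr

with gr.Blocks() as demo:
    user_input = gr.Textbox(label="Your message")
    with gr.Row():
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="top_p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="top_k")
        temperature = gr.Slider(0.1, 2.0, value=0.7, label="temperature")
        max_length = gr.Slider(16, 2048, value=512, step=16, label="max_length")
    chatbot = gr.Textbox(label="Response")  # the real output component type is not visible in this diff
    submit_button = gr.Button("Generate Response")

    # chat_response yields partial strings, so each yield updates the chatbot output;
    # no extra streaming flag is needed on the click event.
    submit_button.click(
        chat_response,
        inputs=[user_input, top_p, top_k, temperature, max_length],
        outputs=[chatbot],
    )

demo.launch()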