bobpopboom committed
Commit 81ab351 · verified · 1 Parent(s): 00c98e3

Update app.py

Files changed (1): app.py (+26 -3)
app.py CHANGED
@@ -39,6 +39,28 @@ def generate_text(prompt, max_new_tokens=128):
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
     return generated_text
 
+def generate_text_streaming(prompt, max_new_tokens=128):
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        for i in range(max_new_tokens):
+            output = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=1,    # Generate only 1 new token at a time
+                do_sample=False,     # Or True for sampling
+                eos_token_id=tokenizer.eos_token_id,
+                return_dict=True,    # Return a dictionary
+                output_scores=True   # Return the scores
+            )
+
+            generated_token = tokenizer.decode(output.logits[0][-1].argmax(), skip_special_tokens=True)  # Decode the last token only
+            yield generated_token  # Yield the last token
+
+            input_ids = torch.cat([input_ids, output.sequences[:, -1:]], dim=-1)  # Append the new token to the input
+
+            if output.sequences[0][-1] == tokenizer.eos_token_id:  # Check if the end-of-sequence token was generated
+                break  # Break the loop
+
 def respond(message, history, system_message, max_tokens):
     prompt = f"{system_message}\n"
     for user_msg, bot_msg in history:
@@ -46,8 +68,9 @@ def respond(message, history, system_message, max_tokens):
     prompt += f"User: {message}\nAssistant:"
 
     try:
-        bot_response = generate_text(prompt, max_tokens)  # Use the new function
-        yield bot_response
+        for token in generate_text_streaming(prompt, max_tokens):  # Iterate over the generator
+            yield token  # Yield each token individually
+
     except Exception as e:
         print(f"Error during generation: {e}")
         yield "An error occurred."
@@ -59,7 +82,7 @@ demo = gr.ChatInterface(
             value="You are a friendly and helpful mental health chatbot.",
             label="System message",
         ),
-        gr.Slider(minimum=1, maximum=128, value=128, step=10, label="Max new tokens"),
+        gr.Slider(minimum=1, maximum=128, value=32, step=10, label="Max new tokens"),
     ],
 )
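As committed, generate_text_streaming will most likely fail at runtime: return_dict=True is an argument for the model's forward pass, not for generate(), so generate() still returns a plain tensor of token ids and neither output.logits nor output.sequences exists on it. Below is a minimal corrected sketch of the same loop, assuming the module-level model and tokenizer from app.py; the key change is return_dict_in_generate=True, which makes generate() return an object with a .sequences field to read the new token from.

import torch

def generate_text_streaming(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=1,              # one new token per step
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,  # generate() flag; return_dict=True is forward-only
            )
            new_token = output.sequences[:, -1:]  # id of the single new token, shape (1, 1)
            yield tokenizer.decode(new_token[0], skip_special_tokens=True)

            input_ids = torch.cat([input_ids, new_token], dim=-1)  # grow the prompt
            if new_token.item() == tokenizer.eos_token_id:         # stop at end of sequence
                break

Two caveats apply even to this corrected sketch: re-running generate() on the whole growing prompt recomputes every earlier position at each step (transformers' built-in TextIteratorStreamer, passed to a single generate() call on a background thread, streams without that cost), and Gradio's ChatInterface renders each yielded value as the entire reply so far, so respond should accumulate tokens (partial += token; yield partial) rather than yielding each bare token.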