42Cummer committed
Commit b48d8e8 · verified · 1 parent: e556488

Streaming Replies

Files changed (3):
  1. AdviceGenerator.py +13 -14
  2. UofTearsBot.py +15 -17
  3. app.py +17 -10
AdviceGenerator.py CHANGED
@@ -34,7 +34,7 @@ class AdviceGenerator(object):
         max_tokens: int = 600,  # give enough headroom
         temperature: float = 0.6,
         top_p: float = 0.9,
-    ) -> Dict[str, str]:
+    ):
 
         msgs = [self.role]
 
@@ -53,16 +53,15 @@ class AdviceGenerator(object):
                 "Follow the system instructions strictly. Do NOT ask vague questions first."
             ),
         })
-
-        try:
-            resp = self.llm.create_chat_completion(
-                messages=msgs,
-                temperature=temperature,
-                top_p=top_p,
-                max_tokens=max_tokens,
-                stream=False,
-            )
-            text = resp["choices"][0]["message"]["content"].strip()
-            return {"text": text}
-        except Exception as e:
-            return {"text": f"I'm here to listen. Could you tell me more about how \"{user_text}\" is affecting you?"}
+        stream = self.llm.create_chat_completion(
+            messages=msgs,
+            temperature=temperature,
+            top_p=top_p,
+            max_tokens=max_tokens,
+            stream=True,
+        )
+        for chunk in stream:
+            if "choices" in chunk:
+                delta = chunk["choices"][0]["delta"].get("content", "")
+                if delta:
+                    yield delta
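With `stream=True`, llama-cpp-python's `create_chat_completion` yields incremental chunk dicts instead of one final response: each chunk's `choices[0]["delta"]` carries at most a `content` fragment (the first delta typically holds only `role`), which is why the loop guards with `.get("content", "")`. A minimal standalone sketch of the same pattern; the model path, chat format, and prompt here are illustrative assumptions, not part of this commit:

# Sketch of the streaming pattern generate_advice() now uses.
# NOTE: model_path, chat_format, and the prompt are illustrative assumptions
# (the file name and chat format mirror app.py's MODEL_FILE and CHAT_FORMAT).
from llama_cpp import Llama

llm = Llama(
    model_path="Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
    chat_format="mistral-instruct",
)

stream = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in five words."}],
    max_tokens=32,
    stream=True,  # return an iterator of chunks instead of a single dict
)

for chunk in stream:
    delta = chunk["choices"][0]["delta"].get("content", "")
    if delta:
        print(delta, end="", flush=True)  # emit fragments as they arrive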
 
UofTearsBot.py CHANGED
@@ -4,10 +4,10 @@ from IllnessClassifier import IllnessClassifier
 from typing import List, Dict
 
 class UofTearsBot(object):
-    def __init__(self, llm, threshold: float = 0.86, max_history_msgs: int = 50):
+    def __init__(self, threshold: float = 0.86, max_history_msgs: int = 50):
         self.suicidality_detector = SIDetector()
         self.illness_classifier = IllnessClassifier()
-        self.chatbot = AdviceGenerator(llm)
+        self.chatbot = AdviceGenerator()
         self.history: List[Dict[str, str]] = []
         self.FLAG = False  # suicidal crisis flag
         self.threshold = threshold
@@ -37,24 +37,22 @@ class UofTearsBot(object):
 
     def converse(self, user_text: str) -> str:
         disorder = self.safety_check(user_text)
-
-        # store user text into history
+        # store user input
         self.history.append({"role": "user", "content": user_text})
-
+        # crisis branch
         if self.FLAG:
-            # crisis flow: respond with fixed crisis message only
             crisis_msg = self.userCrisis()
             self.history.append({"role": "assistant", "content": crisis_msg})
-            return crisis_msg
-
-        # normal advice generation
-        pruned_history = self._prune_history()
-        advice = self.chatbot.generate_advice(
+            yield crisis_msg
+            return
+        # normal branch: stream advice tokens
+        reply_so_far = ""
+        for delta in self.chatbot.generate_advice(
             disorder=disorder,
             user_text=user_text,
-            history=pruned_history,
-        )['text']
-
-        # add bot response to history
-        self.history.append({"role": "assistant", "content": advice})
-        return advice
+            history=self._prune_history(),
+        ):
+            reply_so_far += delta
+            yield delta  # stream to FastAPI as soon as a token arrives
+        # once stream is done, save full reply
+        self.history.append({"role": "assistant", "content": reply_so_far})
app.py CHANGED
@@ -5,7 +5,7 @@ import dotenv
 
 import torch
 from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import JSONResponse, HTMLResponse
+from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
 from pydantic import BaseModel
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download, login
@@ -21,9 +21,9 @@ from transformers import (
 
 from UofTearsBot import UofTearsBot
 
-MODEL_REPO="bartowski/Mistral-7B-Instruct-v0.3-GGUF"
-MODEL_FILE="Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
-CHAT_FORMAT="mistral-instruct"
+MODEL_REPO = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
+MODEL_FILE = "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
+CHAT_FORMAT = "mistral-instruct"
 
 dotenv.load_dotenv()
 login(token=os.getenv("HF_TOKEN"))
@@ -51,18 +51,25 @@ class ChatRequest(BaseModel):
     user_id: str
     user_text: str
 
+
 @app.post("/chat")
 async def chat(request: ChatRequest):
     try:
         if request.user_id not in chatbots:
             chatbots[request.user_id] = UofTearsBot(llm)
         current_bot = chatbots[request.user_id]
-        print("[INFO] Model is generating response...", flush=True)
-        response = current_bot.converse(request.user_text)
-        return JSONResponse(content={"response": response, "history": current_bot.history})
+
+        def token_generator():
+            print("[INFO] Model is streaming response...", flush=True)
+            for token in current_bot.converse(request.user_text):
+                yield token
+            print("[INFO] Model finished streaming ✅", flush=True)
+
+        return StreamingResponse(token_generator(), media_type="text/plain")
+
     except Exception as e:
         import traceback
-        traceback.print_exc()  # logs full stack trace to HF Logs
+        traceback.print_exc()  # logs to HF logs
         return JSONResponse(
             status_code=500,
             content={"error": str(e)}
@@ -72,7 +79,7 @@ async def chat(request: ChatRequest):
 @app.get("/", response_class=HTMLResponse)
 async def home():
     return "<h1>App is running 🚀</h1>"
-
+
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)  # huggingface port
+    uvicorn.run(app, host="0.0.0.0", port=7860)  # huggingface port
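`/chat` now returns a `text/plain` stream of tokens rather than a JSON body, so the old `response`/`history` payload is gone and clients must read incrementally. (Note that the unchanged `UofTearsBot(llm)` call no longer matches the new `__init__(self, threshold, max_history_msgs)` signature, so `llm` would bind to `threshold`; a follow-up would need to drop that argument.) A minimal client sketch using `requests`; the host, port, and payload values are illustrative:

# Sketch: reading the streamed /chat response as it arrives.
# The host/port are illustrative assumptions; the payload fields
# match the ChatRequest model (user_id, user_text).
import requests

with requests.post(
    "http://localhost:7860/chat",
    json={"user_id": "demo-user", "user_text": "I feel stressed about finals"},
    stream=True,  # don't buffer the whole body before returning control
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)  # render tokens as they stream in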