# UofTearsBotAPI / UofTearsBot.py
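"""Core conversation pipeline for UofTearsBot.

Each incoming user message is first screened by SIDetector for suicidal
ideation; if the crisis flag is raised, a fixed safety message is returned,
otherwise the message is routed through IllnessClassifier and AdviceGenerator,
which streams advice tokens back to the caller (e.g. a FastAPI endpoint).
"""
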
from typing import Dict, Iterator, List

from AdviceGenerator import AdviceGenerator
from IllnessClassifier import IllnessClassifier
from SIDetector import SIDetector


class UofTearsBot:
    def __init__(self, llm, threshold: float = 0.86, max_history_msgs: int = 50):
        self.suicidality_detector = SIDetector()
        self.illness_classifier = IllnessClassifier()
        self.chatbot = AdviceGenerator(llm)
        self.history: List[Dict[str, str]] = []
        self.FLAG = False  # suicidal crisis flag
        self.threshold = threshold  # minimum detector confidence that triggers the crisis flag
        self.max_history_msgs = max_history_msgs  # cap on history turns passed to the generator

    def safety_check(self, user_text: str) -> str:
        """Flag suicidal crises; otherwise return the classified disorder label."""
        suicidal, confidence = self.suicidality_detector.forward(user_text)
        self.FLAG = suicidal and confidence >= self.threshold
        if self.FLAG:
            return "suicidal"
        disorder, _ = self.illness_classifier.forward(user_text)
        return disorder

    def userCrisis(self) -> str:
        """Return the fixed crisis-support message used when the crisis flag is set."""
        return (
            "I'm really sorry you're feeling this way. Your safety matters.\n"
            "If you are in immediate danger, please call your local emergency number now.\n"
            "US & Canada: dial **988** (24/7)\n"
            "International: call your local emergency services.\n"
            "If you can, reach out to someone you trust and let them know what’s going on.\n"
            "We can also talk through coping strategies together."
        )

    def _prune_history(self) -> List[Dict[str, str]]:
        """Keep only the last `max_history_msgs` turns."""
        return self.history[-self.max_history_msgs :]

    def converse(self, user_text: str) -> Iterator[str]:
        """Handle one user message, yielding the reply as a stream of text chunks."""
        disorder = self.safety_check(user_text)
        # store user input
        self.history.append({"role": "user", "content": user_text})
        # crisis branch: short-circuit with the fixed safety message
        if self.FLAG:
            crisis_msg = self.userCrisis()
            self.history.append({"role": "assistant", "content": crisis_msg})
            yield crisis_msg
            return
        # normal branch: stream advice tokens
        reply_so_far = ""
        for delta in self.chatbot.generate_advice(
            disorder=disorder,
            user_text=user_text,
            history=self._prune_history(),
        ):
            reply_so_far += delta
            yield delta  # stream to FastAPI as soon as a token arrives
        # once the stream is done, save the full reply
        self.history.append({"role": "assistant", "content": reply_so_far})
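

# Minimal usage sketch (not part of the original module): it assumes
# `AdviceGenerator` accepts whatever LLM handle you construct here and that
# `SIDetector` / `IllnessClassifier` load with no arguments, as in __init__
# above. The loop prints tokens as they stream, mirroring what a FastAPI
# streaming endpoint would forward to the client.
if __name__ == "__main__":
    llm = None  # hypothetical placeholder; replace with a real LLM client/handle
    bot = UofTearsBot(llm)
    for chunk in bot.converse("I've been feeling really overwhelmed by school lately."):
        print(chunk, end="", flush=True)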