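# NOTE: the following lines are page-scrape residue ("Spaces:", "Paused",
# "Paused"); they are kept as a comment so the module parses.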
from typing import Dict, Iterator, List

from AdviceGenerator import AdviceGenerator
from IllnessClassifier import IllnessClassifier
from SIDetector import SIDetector
class UofTearsBot:
    """Chat bot that screens each user message before generating a reply.

    Every message is first run through a suicidality detector. If it is
    flagged with confidence at or above ``threshold``, the bot answers with a
    fixed crisis message. Otherwise an illness classifier labels the message,
    and the label, the message and recent history are handed to an
    ``AdviceGenerator``, whose output is streamed back chunk by chunk.

    NOTE(review): ``self.history`` itself grows without bound; only the slice
    forwarded to the generator is capped by ``max_history_msgs``. The crisis
    flag and history live on the instance, so one bot object should not be
    shared between concurrent conversations.
    """

    def __init__(self, llm, threshold: float = 0.86, max_history_msgs: int = 50):
        """Build the detector, classifier and advice generator.

        Args:
            llm: Model handle passed straight to ``AdviceGenerator``.
            threshold: Minimum detector confidence required to treat a
                message as a suicidal crisis.
            max_history_msgs: How many of the most recent history messages
                (user and assistant entries alike) are forwarded to the
                generator on each turn.
        """
        self.suicidality_detector = SIDetector()
        self.illness_classifier = IllnessClassifier()
        self.chatbot = AdviceGenerator(llm)
        self.history: List[Dict[str, str]] = []
        self.FLAG: bool = False  # suicidal crisis flag, set by safety_check()
        self.threshold = threshold
        self.max_history_msgs = max_history_msgs

    def safety_check(self, user_text: str):
        """Screen ``user_text`` and return a label for it.

        Side effect: sets ``self.FLAG`` to whether the message counts as a
        crisis.

        Returns:
            ``"suicidal"`` when the detector flags the message with
            confidence >= ``self.threshold``. Otherwise, the first element of
            the illness classifier's ``forward`` result (presumably a disorder
            label; confirm against ``IllnessClassifier``).
        """
        # The detector's forward() is unpacked as a 2-tuple; presumably
        # (is_suicidal, confidence) -- confirm against SIDetector.
        suicidal, confidence = self.suicidality_detector.forward(user_text)
        self.FLAG = bool(suicidal and confidence >= self.threshold)
        if self.FLAG:
            return "suicidal"
        disorder, _confidence = self.illness_classifier.forward(user_text)
        return disorder

    def userCrisis(self) -> str:
        """Return the fixed message shown when a crisis is detected."""
        return (
            "I'm really sorry you're feeling this way. Your safety matters.\n"
            "If you are in immediate danger, please call your local emergency number now.\n"
            "US & Canada: dial **988** (24/7)\n"
            "International: call your local emergency services.\n"
            "If you can, reach out to someone you trust and let them know what’s going on.\n"
            "We can also talk through coping strategies together."
        )

    def _prune_history(self) -> List[Dict[str, str]]:
        """Return a copy of the last ``max_history_msgs`` history messages.

        ``self.history`` itself is left untouched. A non-positive limit
        yields an empty list. A bare ``history[-0:]`` slice would instead
        return the whole history.
        """
        if self.max_history_msgs <= 0:
            return []
        return self.history[-self.max_history_msgs:]

    def converse(self, user_text: str) -> Iterator[str]:
        """Handle one user turn, yielding the reply as it is produced.

        This is a generator. In the crisis case it yields the single crisis
        message. Otherwise it yields each chunk produced by
        ``AdviceGenerator.generate_advice``. Both the user message and the
        assistant reply are appended to ``self.history``.

        If iteration stops early (the consumer closes the generator) or the
        advice stream raises, whatever was streamed so far is still recorded
        as the assistant turn. This keeps user and assistant messages
        alternating in ``history``.
        """
        disorder = self.safety_check(user_text)
        self.history.append({"role": "user", "content": user_text})

        if self.FLAG:
            crisis_msg = self.userCrisis()
            # Record before yielding so the turn is saved even if the
            # consumer never resumes the generator.
            self.history.append({"role": "assistant", "content": crisis_msg})
            yield crisis_msg
            return

        chunks: List[str] = []
        try:
            # NOTE(review): the pruned history already ends with the current
            # user message, which is also passed separately as user_text --
            # confirm AdviceGenerator does not duplicate it in the prompt.
            for delta in self.chatbot.generate_advice(
                disorder=disorder,
                user_text=user_text,
                history=self._prune_history(),
            ):
                chunks.append(delta)
                yield delta  # stream to the caller as soon as a chunk arrives
        finally:
            # Runs on normal completion, on close() (GeneratorExit) and on
            # errors, so the assistant turn is always recorded.
            self.history.append({"role": "assistant", "content": "".join(chunks)})