from typing import Dict, Iterator, List

from AdviceGenerator import AdviceGenerator
from SIDetector import SIDetector
from IllnessClassifier import IllnessClassifier


class UofTearsBot:
    def __init__(self, llm, threshold: float = 0.86, max_history_msgs: int = 50):
        self.suicidality_detector = SIDetector()
        self.illness_classifier = IllnessClassifier()
        self.chatbot = AdviceGenerator(llm)
        self.history: List[Dict[str, str]] = []
        self.FLAG = False  # suicidal crisis flag
        self.threshold = threshold  # minimum detector confidence that triggers the crisis branch
        self.max_history_msgs = max_history_msgs

    def safety_check(self, user_text: str):
        """Return "suicidal" if a crisis is detected, otherwise the predicted disorder label."""
        suicidal, confidence = self.suicidality_detector.forward(user_text)
        self.FLAG = suicidal and confidence >= self.threshold
        if self.FLAG:
            return "suicidal"
        disorder, _conf = self.illness_classifier.forward(user_text)  # confidence is unused here
        return disorder

    def userCrisis(self) -> str:
        return (
            "I'm really sorry you're feeling this way. Your safety matters.\n"
            "If you are in immediate danger, please call your local emergency number now.\n"
            "US & Canada: dial **988** (24/7)\n"
            "International: call your local emergency services.\n"
            "If you can, reach out to someone you trust and let them know what's going on.\n"
            "We can also talk through coping strategies together."
        )

    def _prune_history(self) -> List[Dict[str, str]]:
        """Return the last `max_history_msgs` messages (user and assistant); self.history is not modified."""
        return self.history[-self.max_history_msgs :]

    def converse(self, user_text: str) -> Iterator[str]:
        """Stream the reply to `user_text` chunk by chunk.

        This is a generator: nothing runs, including the safety check,
        until the caller starts iterating over it.
        """
        disorder = self.safety_check(user_text)

        # store user input
        self.history.append({"role": "user", "content": user_text})

        # crisis branch: send the fixed crisis message and skip the LLM entirely
        if self.FLAG:
            crisis_msg = self.userCrisis()
            self.history.append({"role": "assistant", "content": crisis_msg})
            yield crisis_msg
            return

        # normal branch: stream advice tokens
        # (the pruned history already ends with the current user message)
        reply_so_far = ""
        for delta in self.chatbot.generate_advice(
            disorder=disorder,
            user_text=user_text,
            history=self._prune_history(),
        ):
            reply_so_far += delta
            yield delta  # stream to FastAPI as soon as a token arrives

        # once the stream is done, save the full reply
        self.history.append({"role": "assistant", "content": reply_so_far})