42Cummer commited on
Commit
22d76f2
·
verified ·
1 Parent(s): 26c92d2

Uploaded files from Cursor

Browse files
Files changed (5) hide show
  1. AdviceGenerator.py +68 -0
  2. IllnessClassifier.py +19 -0
  3. SIDetector.py +19 -0
  4. UofTearsBot.py +60 -0
  5. app.py +68 -0
AdviceGenerator.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForSequenceClassification,
5
+ AutoModelForCausalLM,
6
+ pipeline
7
+ )
8
+
9
+ class AdviceGenerator(object):
10
+ def __init__(self, llm):
11
+ self.llm = llm
12
+ self.role = {
13
+ "role": "system",
14
+ "content": (
15
+ "You are a supportive assistant (not a mental health professional). "
16
+ "Be concrete and tailor every response to the user's situation. "
17
+ "Requirements:\n"
18
+ "1) Begin with ONE empathetic sentence that mentions a key detail from the user's text (name, event, constraint).\n"
19
+ "2) Then give 3–5 numbered, practical tips. Each tip must reference the user's situation (use names/keywords when present).\n"
20
+ "3) If the user's text involves talking to someone (crush, friend, teacher, parent, boss), include a short **Script** block "
21
+ " with two options (in-person and text), customized with any names from the user's text.\n"
22
+ "4) Add a **Try now (2 min)** micro-step.\n"
23
+ "5) End with ONE targeted follow-up question that references the user's situation.\n"
24
+ "Avoid platitudes and generic advice; avoid clinical instructions."
25
+ ),
26
+ }
27
+
28
+ def generate_advice(
29
+ self,
30
+ disorder: str,
31
+ user_text: str,
32
+ history: List[Dict[str, str]] = None,
33
+ max_history_msgs: int = 50,
34
+ max_tokens: int = 600, # give enough headroom
35
+ temperature: float = 0.6,
36
+ top_p: float = 0.9,
37
+ ) -> Dict[str, str]:
38
+
39
+ msgs = [self.role]
40
+
41
+ # preserve rolling chat history if available
42
+ if history:
43
+ msgs.extend(history[-max_history_msgs:])
44
+
45
+ # always append the new user input
46
+ msgs.append({
47
+ "role": "user",
48
+ "content": (
49
+ "Use the exact situation below to personalize your advice. "
50
+ "Extract the main goal or barrier from the text and ground each tip in it.\n\n"
51
+ f"Detected context: {disorder}\n"
52
+ f"User text: {user_text}\n\n"
53
+ "Follow the system instructions strictly. Do NOT ask vague questions first."
54
+ ),
55
+ })
56
+
57
+ try:
58
+ resp = self.llm.create_chat_completion(
59
+ messages=msgs,
60
+ temperature=temperature,
61
+ top_p=top_p,
62
+ max_tokens=max_tokens,
63
+ stream=False,
64
+ )
65
+ text = resp["choices"][0]["message"]["content"].strip()
66
+ return {"text": text}
67
+ except Exception as e:
68
+ return {"text": f"I'm here to listen. Could you tell me more about how \"{user_text}\" is affecting you?"}
IllnessClassifier.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ( # pylint: disable=import-error
2
+ AutoTokenizer,
3
+ AutoModelForSequenceClassification,
4
+ AutoModelForCausalLM,
5
+ pipeline
6
+ )
7
+
8
+ import logging
9
+
10
+ class IllnessClassifier(object):
11
+ def __init__(self):
12
+ self.classifier = pipeline("text-classification", model="dsuram/distilbert-mentalhealth-classifier")
13
+
14
+ def forward(self, text: str):
15
+ output = self.classifier(text)[0]
16
+ disorder = output['label']
17
+ confidence = output['score']
18
+ logging.info(f"Disorder: {disorder}, Confidence: {confidence}")
19
+ return disorder, confidence
SIDetector.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ( # pylint: disable=import-error
2
+ AutoTokenizer,
3
+ AutoModelForSequenceClassification,
4
+ AutoModelForCausalLM,
5
+ pipeline
6
+ )
7
+
8
+ import logging
9
+
10
+ class SIDetector(object):
11
+ def __init__(self):
12
+ self.classifier = pipeline("sentiment-analysis", model="sentinet/suicidality")
13
+
14
+ def forward(self, text: str):
15
+ output = self.classifier(text)[0]
16
+ suicidal = True if output['label'] == 'LABEL_1' else False
17
+ confidence = output['score']
18
+ logging.info(f"Suicidal: {suicidal}, Confidence: {confidence}")
19
+ return suicidal, confidence
UofTearsBot.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from AdviceGenerator import AdviceGenerator
2
+ from SIDetector import SIDetector
3
+ from IllnessClassifier import IllnessClassifier
4
+ from typing import List, Dict
5
+
6
+ class UofTearsBot(object):
7
+ def __init__(self, threshold: float = 0.86, max_history_msgs: int = 50):
8
+ self.suicidality_detector = SIDetector()
9
+ self.illness_classifier = IllnessClassifier()
10
+ self.chatbot = AdviceGenerator()
11
+ self.history: List[Dict[str, str]] = []
12
+ self.FLAG = False # suicidal crisis flag
13
+ self.threshold = threshold
14
+ self.max_history_msgs = max_history_msgs
15
+
16
+ def safety_check(self, user_text: str):
17
+ suicidal, confidence = self.suicidality_detector.forward(user_text)
18
+ self.FLAG = (suicidal and confidence >= self.threshold)
19
+ if self.FLAG:
20
+ return "suicidal"
21
+ disorder, conf = self.illness_classifier.forward(user_text)
22
+ return disorder
23
+
24
+ def userCrisis(self) -> str:
25
+ return (
26
+ "I'm really sorry you're feeling this way. Your safety matters.\n"
27
+ "If you are in immediate danger, please call your local emergency number now.\n"
28
+ "US & Canada: dial **988** (24/7)\n"
29
+ "International: call your local emergency services.\n"
30
+ "If you can, reach out to someone you trust and let them know what’s going on.\n"
31
+ "We can also talk through coping strategies together."
32
+ )
33
+
34
+ def _prune_history(self) -> List[Dict[str, str]]:
35
+ """Keep only the last `max_history_msgs` turns."""
36
+ return self.history[-self.max_history_msgs :]
37
+
38
+ def converse(self, user_text: str) -> str:
39
+ disorder = self.safety_check(user_text)
40
+
41
+ # store user text into history
42
+ self.history.append({"role": "user", "content": user_text})
43
+
44
+ if self.FLAG:
45
+ # crisis flow: respond with fixed crisis message only
46
+ crisis_msg = self.userCrisis()
47
+ self.history.append({"role": "assistant", "content": crisis_msg})
48
+ return crisis_msg
49
+
50
+ # normal advice generation
51
+ pruned_history = self._prune_history()
52
+ advice = self.chatbot.generate_advice(
53
+ disorder=disorder,
54
+ user_text=user_text,
55
+ history=pruned_history,
56
+ )['text']
57
+
58
+ # add bot response to history
59
+ self.history.append({"role": "assistant", "content": advice})
60
+ return advice
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict
3
+ import logging
4
+ import dotenv
5
+
6
+ import torch
7
+ from fastapi import FastAPI, HTTPException, Request
8
+ from fastapi.responses import JSONResponse, HTMLResponse
9
+ from pydantic import BaseModel
10
+ from llama_cpp import Llama
11
+ from huggingface_hub import hf_hub_download, login
12
+
13
+ import uvicorn
14
+
15
+ from transformers import (
16
+ AutoTokenizer,
17
+ AutoModelForSequenceClassification,
18
+ AutoModelForCausalLM,
19
+ pipeline
20
+ )
21
+
22
+ from UofTearsBot import UofTearsBot
23
+
24
+ MODEL_REPO="bartowski/Mistral-7B-Instruct-v0.3-GGUF"
25
+ MODEL_FILE="Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
26
+ CHAT_FORMAT="mistral-instruct"
27
+
28
+ dotenv.load_dotenv()
29
+ login(token=os.getenv("HF_TOKEN"))
30
+
31
+ MODEL_PATH = hf_hub_download(
32
+ repo_id=MODEL_REPO,
33
+ filename=MODEL_FILE,
34
+ local_dir="/tmp/models",
35
+ local_dir_use_symlinks=False,
36
+ )
37
+
38
+ llm = Llama(
39
+ model_path=MODEL_PATH,
40
+ n_ctx=int(os.getenv("N_CTX", "4096")),
41
+ n_threads=os.cpu_count() or 4,
42
+ n_batch=int(os.getenv("N_BATCH", "256")),
43
+ chat_format=CHAT_FORMAT,
44
+ )
45
+
46
+ # Start the FastAPI app
47
+ app = FastAPI()
48
+ chatbots: Dict[str, UofTearsBot] = {}
49
+
50
+ class ChatRequest(BaseModel):
51
+ user_id: str
52
+ user_text: str
53
+
54
+ @app.post("/chat")
55
+ async def chat(request: ChatRequest):
56
+ if request.user_id not in chatbots:
57
+ chatbots[request.user_id] = UofTearsBot(llm)
58
+ current_bot = chatbots[request.user_id]
59
+ response = current_bot.converse(request.user_text)
60
+ return JSONResponse(content={"response": response, "history": current_bot.history})
61
+
62
+ @app.get("/", response_class=HTMLResponse)
63
+ async def home():
64
+ return "<h1>App is running 🚀</h1>"
65
+
66
+
67
+ if __name__ == "__main__":
68
+ uvicorn.run(app, host="0.0.0.0", port=7860) # huggingface port