jocko commited on
Commit
e146235
Β·
1 Parent(s): fd7833c

initial commit

Browse files
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ torch
5
+ transformers
6
+ sentence-transformers
7
+ datasets
8
+ openai
src/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ openai>=1.2.0
3
+ sentence-transformers
4
+ torch
src/runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.10.12
src/streamlit_app.py CHANGED
@@ -1,40 +1,145 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+
3
+ # βœ… Set all relevant cache directories to a writable location
4
+ os.environ["HF_HOME"] = "/tmp/cache"
5
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/cache/transformers"
6
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/cache/sentence_transformers"
7
+ os.environ["HF_DATASETS_CACHE"] = "/tmp/cache/hf_datasets"
8
+ os.environ["TORCH_HOME"] = "/tmp/cache/torch"
9
+
10
+ # βœ… Create the directories if they don't exist
11
+ for path in [
12
+ "/tmp/cache",
13
+ "/tmp/cache/transformers",
14
+ "/tmp/cache/sentence_transformers",
15
+ "/tmp/cache/hf_datasets",
16
+ "/tmp/cache/torch"
17
+ ]:
18
+ os.makedirs(path, exist_ok=True)
19
+ import json
20
+ import torch
21
+ import openai
22
+ import os
23
+ from sentence_transformers import SentenceTransformer, util
24
  import streamlit as st
25
+ from pathlib import Path
26
+
27
+ # === CONFIG ===
28
+ # Set the API key
29
+ client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
30
+ #openai.api_key = os.getenv("OPENAI_API_KEY")
31
+ # REMEDI_PATH = "ReMeDi-base.json"
32
+ BASE_DIR = Path(__file__).parent
33
+ REMEDI_PATH = BASE_DIR / "ReMeDi-base.json"
34
+
35
+ # Check if file exists
36
+ if not REMEDI_PATH.exists():
37
+ raise FileNotFoundError(f"❌ File not found: {REMEDI_PATH}")
38
+
39
+ # Load the file
40
+ with open(REMEDI_PATH, "r", encoding="utf-8") as f:
41
+ data = json.load(f)
42
+
43
+ # === LOAD MODEL ===
44
+ @st.cache_resource
45
+ def load_model():
46
+ return SentenceTransformer("all-MiniLM-L6-v2")
47
+ #return model
48
+
49
+ @st.cache_resource
50
+ def load_data():
51
+ with open(REMEDI_PATH, "r", encoding="utf-8") as f:
52
+ data = json.load(f)
53
+ dialogue_pairs = []
54
+ for conversation in data:
55
+ turns = conversation["information"]
56
+ for i in range(len(turns)-1):
57
+ if turns[i]["role"] == "patient" and turns[i+1]["role"] == "doctor":
58
+ dialogue_pairs.append({
59
+ "patient": turns[i]["sentence"],
60
+ "doctor": turns[i+1]["sentence"]
61
+ })
62
+ return dialogue_pairs
63
+
64
+ @st.cache_data
65
+ def build_embeddings(dialogue_pairs, _model):
66
+ patient_sentences = [pair["patient"] for pair in dialogue_pairs]
67
+ embeddings = _model.encode(patient_sentences, convert_to_tensor=True)
68
+ return embeddings
69
+
70
+ # === TRANSLATE USING GPT ===
71
+ def translate_to_english(chinese_text):
72
+ prompt = f"Translate the following Chinese medical response to English:\n\n{chinese_text}"
73
+ try:
74
+ response = client.chat.completions.create(
75
+ model="gpt-4",
76
+ messages=[{"role": "user", "content": prompt}],
77
+ temperature=0.2
78
+ )
79
+ return response.choices[0].message.content
80
+
81
+ #return response.choices[0].message["content"].strip()
82
+ except Exception as e:
83
+ return f"Translation failed: {str(e)}"
84
+
85
+ def gpt_direct_response(user_input):
86
+ prompt = f"You are a knowledgeable and compassionate medical assistant. Answer the following patient question clearly and concisely:\n\n{user_input}"
87
+ try:
88
+ response = client.chat.completions.create(
89
+ model="gpt-4", # or "gpt-3.5-turbo" to save credits
90
+ messages=[{"role": "user", "content": prompt}],
91
+ temperature=0.5
92
+ )
93
+ return response.choices[0].message.content
94
+ except Exception as e:
95
+ return f"GPT response failed: {str(e)}"
96
+
97
+
98
+ # === CHATBOT FUNCTION ===
99
+ def chatbot_response(user_input, _model, dialogue_pairs, patient_embeddings, top_k=1):
100
+ user_embedding = _model.encode(user_input, convert_to_tensor=True)
101
+ similarities = util.cos_sim(user_embedding, patient_embeddings)[0]
102
+ top_idx = torch.topk(similarities, k=top_k).indices[0].item()
103
+
104
+ match = dialogue_pairs[top_idx]
105
+ translated = translate_to_english(match["doctor"])
106
+
107
+ return {
108
+ "matched_question": match["patient"],
109
+ "original_response": match["doctor"],
110
+ "translated_response": translated
111
+ }
112
+
113
+ # === MAIN APP ===
114
+ st.set_page_config(page_title="Dr_Q_bot", layout="centered")
115
+ st.title("🩺 Dr_Q_bot - Medical Chatbot")
116
+ st.write("Ask about a symptom and get an example doctor response (translated from Chinese).")
117
+
118
+ # Load resources
119
+ model = load_model()
120
+ dialogue_pairs = load_data()
121
+ patient_embeddings = build_embeddings(dialogue_pairs, model)
122
+
123
+ # Chat UI
124
+ user_input = st.text_input("Describe your symptom:")
125
+
126
+ if st.button("Submit") and user_input:
127
+ with st.spinner("Thinking..."):
128
+ result = chatbot_response(user_input, model, dialogue_pairs, patient_embeddings)
129
+ gpt_response = gpt_direct_response(user_input)
130
+
131
+ st.markdown("### πŸ§‘β€βš•οΈ Closest Patient Question")
132
+ st.write(result["matched_question"])
133
+
134
+ st.markdown("### πŸ‡¨πŸ‡³ Original Doctor Response (Chinese)")
135
+ st.write(result["original_response"])
136
+
137
+ st.markdown("### 🌐 Translated Doctor Response (English)")
138
+ st.success(result["translated_response"])
139
+
140
+ st.markdown("### πŸ’¬ GPT Doctor Response (AI-generated)")
141
+ st.info(gpt_response)
142
+
143
 
144
+ st.markdown("---")
145
+ st.warning("This chatbot uses real dialogue data for research and educational use only. Not a substitute for professional medical advice.")