jocko committed on
Commit 0a89a37 · 1 Parent(s): 4a259f2

copy updates of multimodal

Files changed (2)
  1. README.md +3 -4
  2. src/streamlit_app.py +103 -140
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Dr Q Bot Multimodal
+title: Dr Q
 emoji: 🚀
 colorFrom: red
 colorTo: red
@@ -8,12 +8,11 @@ app_port: 8501
 tags:
 - streamlit
 pinned: false
-short_description: multimodal
+short_description: Multimodal medical chatbot
 ---
-
 # Welcome to Streamlit!
 
 Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
 
 If you have any questions, check out our [documentation](https://docs.streamlit.io) and [community
 forums](https://discuss.streamlit.io).
src/streamlit_app.py CHANGED
@@ -1,164 +1,127 @@
+# ================================
+# ✅ Cache-Safe Multimodal App
+# ================================
+
 import os
 
-# ✅ Set all relevant cache directories to a writable location
-os.environ["HF_HOME"] = "/tmp/cache"
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/cache/transformers"
-os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/cache/sentence_transformers"
-os.environ["HF_DATASETS_CACHE"] = "/tmp/cache/hf_datasets"
-os.environ["TORCH_HOME"] = "/tmp/cache/torch"
-
-# ✅ Create the directories if they don't exist
-for path in [
-    "/tmp/cache",
-    "/tmp/cache/transformers",
-    "/tmp/cache/sentence_transformers",
-    "/tmp/cache/hf_datasets",
-    "/tmp/cache/torch"
-]:
-    os.makedirs(path, exist_ok=True)
-import json
+# ====== Force all cache dirs to /tmp (writable in most environments) ======
+CACHE_BASE = "/tmp/cache"
+os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
+os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_BASE}/transformers"
+os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{CACHE_BASE}/sentence_transformers"
+os.environ["HF_DATASETS_CACHE"] = f"{CACHE_BASE}/hf_datasets"
+os.environ["TORCH_HOME"] = f"{CACHE_BASE}/torch"
+os.environ["STREAMLIT_CACHE_DIR"] = f"{CACHE_BASE}/streamlit_cache"
+os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
+
+# Create the cache directories before any library imports read the env vars
+for path in os.environ.values():
+    if path.startswith(CACHE_BASE):
+        os.makedirs(path, exist_ok=True)
+
+# ====== Imports ======
+import streamlit as st
 import torch
-import openai
-import os
 from sentence_transformers import SentenceTransformer, util
-import streamlit as st
-from pathlib import Path
-
-# === CONFIG ===
-# Set the API key
-client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-# openai.api_key = os.getenv("OPENAI_API_KEY")
-# REMEDI_PATH = "ReMeDi-base.json"
-BASE_DIR = Path(__file__).parent
-REMEDI_PATH = BASE_DIR / "ReMeDi-base.json"
-
-# Check if file exists
-if not REMEDI_PATH.exists():
-    raise FileNotFoundError(f"❌ File not found: {REMEDI_PATH}")
-
-# Load the file
-with open(REMEDI_PATH, "r", encoding="utf-8") as f:
-    data = json.load(f)
-
-
-# === LOAD MODEL ===
-@st.cache_resource
-def load_model():
-    return SentenceTransformer("all-MiniLM-L6-v2")
-    # return model
-
-
-@st.cache_resource
-def load_data():
-    with open(REMEDI_PATH, "r", encoding="utf-8") as f:
-        data = json.load(f)
-    dialogue_pairs = []
-    for conversation in data:
-        turns = conversation["information"]
-        for i in range(len(turns) - 1):
-            if turns[i]["role"] == "patient" and turns[i + 1]["role"] == "doctor":
-                dialogue_pairs.append({
-                    "patient": turns[i]["sentence"],
-                    "doctor": turns[i + 1]["sentence"]
-                })
-    return dialogue_pairs
-
-
-@st.cache_data
-def build_embeddings(dialogue_pairs, _model):
-    patient_sentences = [pair["patient"] for pair in dialogue_pairs]
-    embeddings = _model.encode(patient_sentences, convert_to_tensor=True)
-    return embeddings
-
-
-# === TRANSLATE USING GPT ===
-def translate_to_english(chinese_text):
-    prompt = f"Translate the following Chinese medical response to English:\n\n{chinese_text}"
-    try:
-        response = client.chat.completions.create(
-            model="gpt-4",
-            messages=[{"role": "user", "content": prompt}],
-            temperature=0.2
-        )
-        return response.choices[0].message.content
-        # return response.choices[0].message["content"].strip()
-    except Exception as e:
-        return f"Translation failed: {str(e)}"
-
-
-def gpt_direct_response(user_input):
-    prompt = f"You are a knowledgeable and compassionate medical assistant. Answer the following patient question clearly and concisely:\n\n{user_input}"
-    try:
-        response = client.chat.completions.create(
-            model="gpt-4",  # or "gpt-3.5-turbo" to save credits
-            messages=[{"role": "user", "content": prompt}],
-            temperature=0.5
-        )
-        return response.choices[0].message.content
-    except Exception as e:
-        return f"GPT response failed: {str(e)}"
-
-
-# === CHATBOT FUNCTION ===
-def chatbot_response(user_input, _model, dialogue_pairs, patient_embeddings, top_k=1):
-    user_embedding = _model.encode(user_input, convert_to_tensor=True)
-    similarities = util.cos_sim(user_embedding, patient_embeddings)[0]
-    top_score, top_idx = torch.topk(similarities, k=1)
-    top_score = top_score.item()
-    top_idx = torch.topk(similarities, k=top_k).indices[0].item()
-
-    match = dialogue_pairs[top_idx]
-    translated = translate_to_english(match["doctor"])
-
-    return {
-        "matched_question": match["patient"],
-        "original_response": match["doctor"],
-        "translated_response": translated
-        # "similarity_score": top_score
-    }
-
-
-# === MAIN APP ===
-st.set_page_config(page_title="Dr_Q_bot", layout="centered")
-st.title("🩺 Dr_Q_bot - Medical Chatbot")
-st.write("Ask about a symptom and get an example doctor response (translated from Chinese).")
-
-# Load resources
-model = load_model()
-dialogue_pairs = load_data()
-patient_embeddings = build_embeddings(dialogue_pairs, model)
-
-# Chat UI
-user_input = st.text_input("Describe your symptom:")
-
-if st.button("Submit") and user_input:
-    with st.spinner("Thinking..."):
-        result = chatbot_response(user_input, model, dialogue_pairs, patient_embeddings)
-        gpt_response = gpt_direct_response(user_input)
-
-    st.markdown("## ✅ GPT-4 Doctor's Response")
-    st.success(gpt_response)
-
-    # if torch.max(similarities).item() < 0.4:
-
-    st.markdown("## 🔍 Example Historical Dialogue")
-    st.markdown("### 🧑‍⚕️ Closest Patient Question")
-    st.write(result["matched_question"])
-
-    st.markdown("### 🇨🇳 Original Doctor Response (Chinese)")
-    st.write(result["original_response"])
-
-    st.markdown("### 🌐 Translated Doctor Response (English)")
-    st.success(result["translated_response"])
-    # else:
-    #     st.warning("No close match found in dataset. Using GPT response only.")
-    #     st.markdown("### 💬 GPT Doctor Response (AI-generated)")
-    #     st.info(gpt_response)
-
-# Skip dataset result
-
-st.markdown("---")
-st.warning(
-    "This chatbot uses real dialogue data for research and educational use only. Not a substitute for professional medical advice.")
+from transformers import CLIPProcessor, CLIPModel
+from datasets import load_dataset, get_dataset_split_names
+from PIL import Image
+import openai
+
+# ========== 🔑 API Key ==========
+openai.api_key = os.getenv("OPENAI_API_KEY")
+
+# ========== 📥 Load Models ==========
+@st.cache_resource(show_spinner=False)
+def load_models():
+    clip_model = CLIPModel.from_pretrained(
+        "openai/clip-vit-base-patch32",
+        cache_dir=os.environ["TRANSFORMERS_CACHE"]
+    )
+    clip_processor = CLIPProcessor.from_pretrained(
+        "openai/clip-vit-base-patch32",
+        cache_dir=os.environ["TRANSFORMERS_CACHE"]
+    )
+    text_model = SentenceTransformer(
+        "all-MiniLM-L6-v2",
+        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
+    )
+    return clip_model, clip_processor, text_model
+
+clip_model, clip_processor, text_model = load_models()
+
+# ========== 📥 Load Dataset ==========
+@st.cache_resource(show_spinner=False)
+def load_medical_data():
+    available_splits = get_dataset_split_names("univanxx/3mdbench")
+    split_to_use = "train" if "train" in available_splits else available_splits[0]
+    dataset = load_dataset(
+        "univanxx/3mdbench",
+        split=split_to_use,
+        cache_dir=os.environ["HF_DATASETS_CACHE"]
+    )
+    return dataset
+
+data = load_medical_data()
+
+# Temporary debug display
+# st.write("Dataset columns:", data.features.keys())
+
+# Fall back to a "text" column if present, otherwise the first available column:
+text_field = "text" if "text" in data.features else list(data.features.keys())[0]
+
+# Dynamic access would then look like:
+# text_embeddings = embed_texts(data[text_field])
+
+# ========== 🧠 Embedding Function ==========
+@st.cache_data(show_spinner=False)
+def embed_texts(texts):
+    # No leading underscore on `texts`: Streamlit must hash the argument so
+    # the corpus and each query get their own cache entries.
+    return text_model.encode(texts, convert_to_tensor=True)
+
+# Pick which text column to use
+TEXT_COLUMN = "complaints"  # or "general_complaint", depending on your needs
+
+# ========== 🧑‍⚕️ App UI ==========
+st.title("🩺 Multimodal Medical Chatbot")
+
+query = st.text_input("Enter your medical question or symptom description:")
+
+if query:
+    with st.spinner("Searching medical cases..."):
+        text_embeddings = embed_texts(data[TEXT_COLUMN])
+        query_embedding = embed_texts([query])[0]
+
+        # Compute similarity
+        cos_scores = util.cos_sim(query_embedding, text_embeddings)[0]
+        top_result = torch.topk(cos_scores, k=1)
+        idx = top_result.indices[0].item()
+        selected = data[idx]
+
+        # Show image
+        st.image(selected["image"], caption="Most relevant medical image", use_container_width=True)
+
+        # Show text
+        st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")
+
+        # GPT explanation
+        if openai.api_key:
+            prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"
+            from openai import OpenAI
+            client = OpenAI(api_key=openai.api_key)
+
+            response = client.chat.completions.create(
+                model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0.5,
+                max_tokens=150
+            )
+            explanation = response.choices[0].message.content
+
+            st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
+        else:
+            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
+
+st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
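
For a quick sanity check of the new retrieval logic outside Streamlit, the core steps can be run as a standalone script. This is a minimal sketch, not the committed app: it assumes `univanxx/3mdbench` exposes a `train` split with a `complaints` column (as the diff suggests) and that network access is available; the query string is invented for illustration.

```python
# Standalone sketch of the app's retrieval step: embed the corpus once,
# embed the query, then pick the most similar case by cosine similarity.
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util

dataset = load_dataset("univanxx/3mdbench", split="train")
model = SentenceTransformer("all-MiniLM-L6-v2")

# Corpus embeddings are computed once; in the app this is what st.cache_data amortizes.
corpus_embeddings = model.encode(dataset["complaints"], convert_to_tensor=True)
query_embedding = model.encode("persistent dry cough and mild fever", convert_to_tensor=True)

scores = util.cos_sim(query_embedding, corpus_embeddings)[0]  # shape: (num_cases,)
best = torch.topk(scores, k=1)
print(f"score={best.values[0].item():.3f}")
print(dataset[best.indices[0].item()]["complaints"])
```

Embedding the whole corpus per query is the dominant cost; caching it (as the app does) turns each query into a single encode plus one matrix-vector similarity.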