jocko committed on
Commit
a82bbd1
·
1 Parent(s): dafea1a
Files changed (1)
  1. src/streamlit_app.py +59 -84
src/streamlit_app.py CHANGED
@@ -1,15 +1,9 @@
 # ================================
-# ✅ Cache-Safe Multimodal App with Full Opik Tracking
+# ✅ Cache-Safe Multimodal App
 # ================================
 
 import os
 
-# ---- Disable Comet auto-patching (MUST be set BEFORE importing openai/comet_llm/comet_ml) ----
-# Disable all Comet auto-logging / monkey-patching
-os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
-# Optionally: only disable LLM auto-logging
-os.environ["COMET_DISABLE_AUTO_LOGGING_LLM"] = "1"
-
 # ====== Force all cache dirs to /tmp (writable in most environments) ======
 CACHE_BASE = "/tmp/cache"
 os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
@@ -22,29 +16,26 @@ os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
 
 # Create the directories before imports
 for path in os.environ.values():
-    if isinstance(path, str) and path.startswith(CACHE_BASE):
+    if path.startswith(CACHE_BASE):
         os.makedirs(path, exist_ok=True)
 
-# ====== Now safe to import libraries ======
+# ====== Imports ======
 import streamlit as st
 import torch
 from sentence_transformers import SentenceTransformer, util
 from transformers import CLIPProcessor, CLIPModel
 from datasets import load_dataset, get_dataset_split_names
 from PIL import Image
-
 import openai
-from openai import OpenAI  # OK to import after openai is present
-from opik import track, log_event
+import comet_llm
+from opik import track
+
 
 
 # ========== 🔑 API Key ==========
 openai.api_key = os.getenv("OPENAI_API_KEY")
 os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
 os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
-
-client = OpenAI(api_key=openai.api_key)
-
 # ========== 📥 Load Models ==========
 @st.cache_resource(show_spinner=False)
 def load_models():
@@ -77,83 +68,67 @@ def load_medical_data():
     return dataset
 
 data = load_medical_data()
-TEXT_COLUMN = "complaints" if "complaints" in data.features else list(data.features.keys())[0]
+
+from openai import OpenAI
+client = OpenAI(api_key=openai.api_key)
+# Temporary debug display
+#st.write("Dataset columns:", data.features.keys())
+
+# After seeing the real column name, let's say it's "text" instead of "description":
+text_field = "text" if "text" in data.features else list(data.features.keys())[0]
+
+# Then use dynamic access:
+#text_embeddings = embed_texts(data[text_field])
 
 # ========== 🧠 Embedding Function ==========
+@st.cache_data(show_spinner=False)
+def embed_texts(_texts):
+    return text_model.encode(_texts, convert_to_tensor=True)
+
+# Pick which text column to use
+TEXT_COLUMN = "complaints"  # or "general_complaint", depending on your needs
+
+# ========== 🧑‍⚕️ App UI ==========
+st.title("🩺 Multimodal Medical Chatbot")
+
+query = st.text_input("Enter your medical question or symptom description:")
+
 @track
-def embed_texts_tracked(texts, model_name="all-MiniLM-L6-v2"):
-    embeddings = text_model.encode(texts, convert_to_tensor=True)
-    log_event("embedding_generated", {
-        "model": model_name,
-        "num_texts": len(texts),
-        "embedding_shape": list(embeddings.shape)
-    })
-    return embeddings
-
-# ========== 🔍 Case Selection ==========
-@track
-def select_top_case(query_embedding, text_embeddings, k=1):
-    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
-    top_result = torch.topk(cos_scores, k=k)
-    idx = top_result.indices[0].item()
-    score = float(top_result.values[0].item())
-    log_event("case_selected", {
-        "case_index": idx,
-        "similarity_score": score
-    })
-    return idx, score
-
-# ========== 🖼️ Display Case ==========
-@track
-def display_case(case):
-    st.image(case['image'], caption="Most relevant medical image", use_container_width=True)
-    st.markdown(f"**Case Description:** {case[TEXT_COLUMN]}")
-    log_event("case_displayed", {
-        "case_id": case.get("id", None),
-        "description_preview": case[TEXT_COLUMN][:100] + "..."
-    })
-    return case
-
-# ========== 🤖 GPT Completion ==========
-@track
-def get_chat_completion_openai(client, prompt: str, case_id=None):
-    response = client.chat.completions.create(
-        model="gpt-4o",
+def get_chat_completion_openai(client, prompt: str):
+    return client.chat.completions.create(
+        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
         messages=[{"role": "user", "content": prompt}],
         temperature=0.5,
         max_tokens=150
     )
-    answer = response.choices[0].message.content
-    log_event("gpt_response", {
-        "case_id": case_id,
-        "prompt_length": len(prompt),
-        "response_length": len(answer)
-    })
-    return answer
-
-# ========== 🔄 Full Query Processing ==========
-@track
-def process_query(query):
-    text_embeddings = embed_texts_tracked(data[TEXT_COLUMN])
-    query_embedding = embed_texts_tracked([query])[0]
-    idx, score = select_top_case(query_embedding, text_embeddings)
-    case = display_case(data[idx])
-    explanation = get_chat_completion_openai(client, f"Explain this case in plain English: {case[TEXT_COLUMN]}", case_id=idx)
-    return {
-        "query": query,
-        "case_id": idx,
-        "similarity_score": score,
-        "gpt_explanation": explanation
-    }
-
-# ========== 🖥️ Streamlit UI ==========
-st.title("🩺 Multimodal Medical Chatbot")
 
-query = st.text_input("Enter your medical question or symptom description:")
 
 if query:
-    with st.spinner("Processing your query..."):
-        session_data = process_query(query)
-        st.markdown(f"### 🤖 Explanation by GPT:\n{session_data['gpt_explanation']}")
+    with st.spinner("Searching medical cases..."):
+        text_embeddings = embed_texts(data[TEXT_COLUMN])
+        query_embedding = embed_texts([query])[0]
+
+        # Compute similarity
+        cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
+        top_result = torch.topk(cos_scores, k=1)
+        idx = top_result.indices[0].item()
+        selected = data[idx]
+
+        # Show Image
+        st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)
+
+        # Show Text
+        st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")
+
+        # GPT Explanation
+        if openai.api_key:
+            prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"
+
+            explanation = get_chat_completion_openai(client, prompt)
+            explanation = explanation.choices[0].message.content
+
+            st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
+        else:
+            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
 
-st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
+st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
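
A note on the setup code this commit keeps: `os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")` raises `TypeError: str expected, not NoneType` whenever the variable is unset, because `os.getenv` returns `None` and environment values must be strings. A minimal guard, assuming the Opik settings are optional rather than required:

```python
import os

# Only export the Opik settings that are actually present; assigning None
# into os.environ raises TypeError because values must be strings.
for var in ("OPIK_API_KEY", "OPIK_WORKSPACE"):
    value = os.getenv(var)
    if value is not None:
        os.environ[var] = value
```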
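The new `embed_texts` helper is worth a second look. `st.cache_data` excludes parameters whose names start with an underscore from the cache key, so with `_texts` as the only parameter every call shares one cache entry: once `embed_texts(data[TEXT_COLUMN])` has run, `embed_texts([query])` can return the cached corpus embeddings instead of the query embedding. A sketch of one way around this, passing the input as a hashable tuple so it participates in the cache key (a hypothetical rewrite, not part of this commit):

```python
@st.cache_data(show_spinner=False)
def embed_texts(texts: tuple):
    # A tuple of strings is hashable, so different inputs map to
    # different cache entries instead of colliding on a single one.
    return text_model.encode(list(texts), convert_to_tensor=True)

text_embeddings = embed_texts(tuple(data[TEXT_COLUMN]))
query_embedding = embed_texts((query,))[0]
```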
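Finally, the commit swaps the dynamic column fallback for a hard-coded `TEXT_COLUMN = "complaints"`, while the `text_field` computed just above is left unused. If the dataset schema is uncertain, restoring the removed fallback line keeps the app from crashing on a missing column:

```python
# Mirrors the line this commit removed: prefer "complaints" when present,
# otherwise fall back to the dataset's first column.
TEXT_COLUMN = "complaints" if "complaints" in data.features else list(data.features.keys())[0]
```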