vi108 committed on
Commit ce6e551 · verified · 1 Parent(s): 015d1b7

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py (+72 -22)
src/streamlit_app.py CHANGED
@@ -27,10 +27,15 @@ from transformers import CLIPProcessor, CLIPModel
 from datasets import load_dataset, get_dataset_split_names
 from PIL import Image
 import openai
+import comet_llm
+from opik import track
+
+
 
 # ========== 🔑 API Key ==========
 openai.api_key = os.getenv("OPENAI_API_KEY")
-
+os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
+os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
 # ========== 📥 Load Models ==========
 @st.cache_resource(show_spinner=False)
 def load_models():
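Note on the new Opik wiring: `os.environ[key] = os.getenv(key)` raises `TypeError: str expected, not NoneType` whenever the variable is unset, so the app now crashes at startup without both secrets. A minimal defensive sketch (same keys; only the guard is new):

```python
import os

# Forward the Opik settings only when they are actually set;
# assigning None into os.environ raises TypeError.
for key in ("OPIK_API_KEY", "OPIK_WORKSPACE"):
    value = os.getenv(key)
    if value is not None:
        os.environ[key] = value
```

Also worth noting: `import comet_llm` is unused in this diff; only `opik.track` is referenced.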
@@ -52,9 +57,6 @@ clip_model, clip_processor, text_model = load_models()
 
 # ========== 📥 Load Dataset ==========
 @st.cache_resource(show_spinner=False)
-
-
-
 def load_medical_data():
     available_splits = get_dataset_split_names("univanxx/3mdbench")
     split_to_use = "train" if "train" in available_splits else available_splits[0]
@@ -67,6 +69,8 @@ def load_medical_data():
 
 data = load_medical_data()
 
+from openai import OpenAI
+client = OpenAI(api_key=openai.api_key)
 # Temporary debug display
 #st.write("Dataset columns:", data.features.keys())
 
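Creating the `OpenAI` client at module scope works, but Streamlit re-executes the whole script on every widget interaction, so the client is rebuilt on each run. One possible refactor, reusing the `st.cache_resource` pattern this file already applies to the models (the helper name `get_openai_client` is illustrative, not from the diff):

```python
import os

import streamlit as st
from openai import OpenAI

# Build the client once per process; Streamlit reruns the script
# on every interaction, and cache_resource keeps a single instance.
@st.cache_resource(show_spinner=False)
def get_openai_client() -> OpenAI:
    return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

client = get_openai_client()
```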
72
 
@@ -103,19 +107,49 @@ TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
 st.title("🩺 Multimodal Medical Chatbot")
 
 query = st.text_input("Enter your medical question or symptom description:")
+uploaded_file = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"])
+
+@track
+def get_chat_completion_openai(client, prompt: str):
+    return client.chat.completions.create(
+        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
+        messages=[{"role": "user", "content": prompt}],
+        temperature=0.5,
+        max_tokens=150
+    )
+
+@track
+def get_similar_prompt(query):
+    text_embeddings = embed_dataset_texts(combined_texts)  # cached
+    query_embedding = embed_query_text(query)  # recalculated each time
+
+    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
+    top_result = torch.topk(cos_scores, k=1)
+    idx = top_result.indices[0].item()
+    return data[idx]
+
+# Cache dataset image embeddings (takes time, so cached)
+@st.cache_data(show_spinner=True)
+def embed_dataset_images(_dataset):
+    features = []
+    for item in _dataset:
+        # Load image from URL/path or raw bytes - adapt this if needed
+        img = item["image"]
+        inputs = clip_processor(images=img, return_tensors="pt")
+        with torch.no_grad():
+            feat = clip_model.get_image_features(**inputs)
+        feat /= feat.norm(p=2, dim=-1, keepdim=True)
+        features.append(feat.cpu())
+    return torch.cat(features, dim=0)
+
+dataset_image_features = embed_dataset_images(data)
 
 if query:
     with st.spinner("Searching medical cases..."):
-        text_embeddings = embed_dataset_texts(combined_texts)  # cached
-        query_embedding = embed_query_text(query)  # recalculated each time
-        # text_embeddings = embed_dataset_texts(data[TEXT_COLUMN])
-        # query_embedding = embed_query_text([query])[0]
+
 
         # Compute similarity
-        cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
-        top_result = torch.topk(cos_scores, k=1)
-        idx = top_result.indices[0].item()
-        selected = data[idx]
+        selected = get_similar_prompt(query)
 
         # Show Image
         st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)
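Because `embed_dataset_images` L2-normalizes every feature vector, the matrix product used later (`dataset_image_features @ query_feat.T`) is exactly cosine similarity. A self-contained check, with random tensors standing in for the CLIP features (512-dim, as in CLIP's base models):

```python
import torch
import torch.nn.functional as F

# Stand-ins for CLIP embeddings: 10 dataset images, one query.
feats = torch.randn(10, 512)
query = torch.randn(1, 512)

# L2-normalize, mirroring feat /= feat.norm(p=2, dim=-1, keepdim=True).
feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
query = query / query.norm(p=2, dim=-1, keepdim=True)

dot = (feats @ query.T).squeeze(1)       # what the app computes
ref = F.cosine_similarity(feats, query)  # reference implementation
assert torch.allclose(dot, ref, atol=1e-6)
```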
@@ -126,20 +160,36 @@ if query:
     # GPT Explanation
     if openai.api_key:
         prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"
-        from openai import OpenAI
-        client = OpenAI(api_key=openai.api_key)
 
-        response = client.chat.completions.create(
-            model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
-            messages=[{"role": "user", "content": prompt}],
-            temperature=0.5,
-            max_tokens=150
-        )
-        explanation = response.choices[0].message.content
+        explanation = get_chat_completion_openai(client, prompt)
+        explanation = explanation.choices[0].message.content
 
         st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
     else:
         st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
 
-st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
+if uploaded_file is not None:
+    query_image = Image.open(uploaded_file).convert("RGB")
+    st.image(query_image, caption="Your uploaded image", use_container_width=True)
+
+    # Embed uploaded image
+    inputs = clip_processor(images=query_image, return_tensors="pt")
+    with torch.no_grad():
+        query_feat = clip_model.get_image_features(**inputs)
+    query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
+
+    # Compute cosine similarity
+    similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]
+
+    top_k = 3
+    top_results = torch.topk(similarities, k=top_k)
+
+    st.write(f"Top {top_k} similar medical cases:")
+
+    for rank, idx in enumerate(top_results.indices):
+        score = top_results.values[rank].item()
+        similar_img = data[int(idx)]['image']
+        st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
+        st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
 
+st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
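In the new upload branch, `torch.topk(similarities, k=3)` raises a `RuntimeError` if the dataset holds fewer than three rows. A small sketch of a clamped variant (the helper name `top_k_matches` is illustrative; the tensor names mirror the app):

```python
import torch

def top_k_matches(similarities: torch.Tensor, k: int = 3):
    # Clamp k to the number of candidates so topk cannot fail.
    k = min(k, similarities.shape[0])
    values, indices = torch.topk(similarities, k=k)
    return [(int(i), float(v)) for i, v in zip(indices, values)]
```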
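To try the updated app locally (assuming `OPENAI_API_KEY`, and optionally `OPIK_API_KEY` and `OPIK_WORKSPACE`, are exported in the shell), it can be launched with `streamlit run src/streamlit_app.py`.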