jocko commited on
Commit
584fdfe
Β·
1 Parent(s): 0fcedb1

add comet on all other operations

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +75 -61
src/streamlit_app.py CHANGED
@@ -1,5 +1,5 @@
1
  # ================================
2
- # βœ… Cache-Safe Multimodal App
3
  # ================================
4
 
5
  import os
@@ -27,15 +27,16 @@ from transformers import CLIPProcessor, CLIPModel
27
  from datasets import load_dataset, get_dataset_split_names
28
  from PIL import Image
29
  import openai
30
- import comet_llm
31
- from opik import track
32
-
33
-
34
 
35
  # ========== πŸ”‘ API Key ==========
36
  openai.api_key = os.getenv("OPENAI_API_KEY")
37
  os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
38
  os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
 
 
 
39
  # ========== πŸ“₯ Load Models ==========
40
  @st.cache_resource(show_spinner=False)
41
  def load_models():
@@ -57,9 +58,6 @@ clip_model, clip_processor, text_model = load_models()
57
 
58
  # ========== πŸ“₯ Load Dataset ==========
59
  @st.cache_resource(show_spinner=False)
60
-
61
-
62
-
63
  def load_medical_data():
64
  available_splits = get_dataset_split_names("univanxx/3mdbench")
65
  split_to_use = "train" if "train" in available_splits else available_splits[0]
@@ -71,67 +69,83 @@ def load_medical_data():
71
  return dataset
72
 
73
  data = load_medical_data()
74
-
75
- from openai import OpenAI
76
- client = OpenAI(api_key=openai.api_key)
77
- # Temporary debug display
78
- #st.write("Dataset columns:", data.features.keys())
79
-
80
- # After seeing the real column name, let's say it's "text" instead of "description":
81
- text_field = "text" if "text" in data.features else list(data.features.keys())[0]
82
-
83
- # Then use dynamic access:
84
- #text_embeddings = embed_texts(data[text_field])
85
 
86
  # ========== 🧠 Embedding Function ==========
87
- @st.cache_data(show_spinner=False)
88
- def embed_texts(_texts):
89
- return text_model.encode(_texts, convert_to_tensor=True)
90
-
91
- # Pick which text column to use
92
- TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
93
-
94
- # ========== πŸ§‘β€βš•οΈ App UI ==========
95
- st.title("🩺 Multimodal Medical Chatbot")
96
-
97
- query = st.text_input("Enter your medical question or symptom description:")
98
-
99
  @track
100
- def get_chat_completion_openai(client, prompt: str):
101
- return client.chat.completions.create(
102
- model="gpt-4o", # or "gpt-4" if you need the older GPT-4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  messages=[{"role": "user", "content": prompt}],
104
  temperature=0.5,
105
  max_tokens=150
106
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
 
108
 
109
  if query:
110
- with st.spinner("Searching medical cases..."):
111
- text_embeddings = embed_texts(data[TEXT_COLUMN])
112
- query_embedding = embed_texts([query])[0]
113
-
114
- # Compute similarity
115
- cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
116
- top_result = torch.topk(cos_scores, k=1)
117
- idx = top_result.indices[0].item()
118
- selected = data[idx]
119
-
120
- # Show Image
121
- st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)
122
-
123
- # Show Text
124
- st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")
125
-
126
- # GPT Explanation
127
- if openai.api_key:
128
- prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"
129
-
130
- explanation = get_chat_completion_openai(client, prompt)
131
- explanation = explanation.choices[0].message.content
132
-
133
- st.markdown(f"### πŸ€– Explanation by GPT:\n{explanation}")
134
- else:
135
- st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
136
 
137
- st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
 
1
  # ================================
2
+ # βœ… Cache-Safe Multimodal App with Full Opik Tracking
3
  # ================================
4
 
5
  import os
 
27
  from datasets import load_dataset, get_dataset_split_names
28
  from PIL import Image
29
  import openai
30
+ from opik import track, log_event
31
+ from openai import OpenAI
 
 
32
 
33
  # ========== πŸ”‘ API Key ==========
34
  openai.api_key = os.getenv("OPENAI_API_KEY")
35
  os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
36
  os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
37
+
38
+ client = OpenAI(api_key=openai.api_key)
39
+
40
  # ========== πŸ“₯ Load Models ==========
41
  @st.cache_resource(show_spinner=False)
42
  def load_models():
 
58
 
59
  # ========== πŸ“₯ Load Dataset ==========
60
  @st.cache_resource(show_spinner=False)
 
 
 
61
  def load_medical_data():
62
  available_splits = get_dataset_split_names("univanxx/3mdbench")
63
  split_to_use = "train" if "train" in available_splits else available_splits[0]
 
69
  return dataset
70
 
71
  data = load_medical_data()
72
+ TEXT_COLUMN = "complaints" if "complaints" in data.features else list(data.features.keys())[0]
 
 
 
 
 
 
 
 
 
 
73
 
74
  # ========== 🧠 Embedding Function ==========
 
 
 
 
 
 
 
 
 
 
 
 
75
  @track
76
+ def embed_texts_tracked(texts, model_name="all-MiniLM-L6-v2"):
77
+ embeddings = text_model.encode(texts, convert_to_tensor=True)
78
+ log_event("embedding_generated", {
79
+ "model": model_name,
80
+ "num_texts": len(texts),
81
+ "embedding_shape": list(embeddings.shape)
82
+ })
83
+ return embeddings
84
+
85
+ # ========== πŸ” Case Selection ==========
86
+ @track
87
+ def select_top_case(query_embedding, text_embeddings, k=1):
88
+ cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
89
+ top_result = torch.topk(cos_scores, k=k)
90
+ idx = top_result.indices[0].item()
91
+ score = float(top_result.values[0].item())
92
+ log_event("case_selected", {
93
+ "case_index": idx,
94
+ "similarity_score": score
95
+ })
96
+ return idx, score
97
+
98
+ # ========== πŸ–ΌοΈ Display Case ==========
99
+ @track
100
+ def display_case(case):
101
+ st.image(case['image'], caption="Most relevant medical image", use_container_width=True)
102
+ st.markdown(f"**Case Description:** {case[TEXT_COLUMN]}")
103
+ log_event("case_displayed", {
104
+ "case_id": case.get("id", None),
105
+ "description_preview": case[TEXT_COLUMN][:100] + "..."
106
+ })
107
+ return case
108
+
109
+ # ========== πŸ€– GPT Completion ==========
110
+ @track
111
+ def get_chat_completion_openai(client, prompt: str, case_id=None):
112
+ response = client.chat.completions.create(
113
+ model="gpt-4o",
114
  messages=[{"role": "user", "content": prompt}],
115
  temperature=0.5,
116
  max_tokens=150
117
  )
118
+ answer = response.choices[0].message.content
119
+ log_event("gpt_response", {
120
+ "case_id": case_id,
121
+ "prompt_length": len(prompt),
122
+ "response_length": len(answer)
123
+ })
124
+ return answer
125
+
126
+ # ========== πŸ”„ Full Query Processing ==========
127
+ @track
128
+ def process_query(query):
129
+ text_embeddings = embed_texts_tracked(data[TEXT_COLUMN])
130
+ query_embedding = embed_texts_tracked([query])[0]
131
+ idx, score = select_top_case(query_embedding, text_embeddings)
132
+ case = display_case(data[idx])
133
+ explanation = get_chat_completion_openai(client, f"Explain this case in plain English: {case[TEXT_COLUMN]}", case_id=idx)
134
+ return {
135
+ "query": query,
136
+ "case_id": idx,
137
+ "similarity_score": score,
138
+ "gpt_explanation": explanation
139
+ }
140
+
141
+ # ========== πŸ–₯️ Streamlit UI ==========
142
+ st.title("🩺 Multimodal Medical Chatbot")
143
 
144
+ query = st.text_input("Enter your medical question or symptom description:")
145
 
146
  if query:
147
+ with st.spinner("Processing your query..."):
148
+ session_data = process_query(query)
149
+ st.markdown(f"### πŸ€– Explanation by GPT:\n{session_data['gpt_explanation']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ st.caption("This chatbot is for educational purposes only and does not provide medical advice.")