# ================================
#   ✅ Cache-Safe Multimodal App
# ================================
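#
# Streamlit app: multimodal medical case retrieval over the univanxx/3mdbench dataset.
# Text queries are matched via SentenceTransformer embeddings, uploaded images via CLIP
# image embeddings, and the selected case is explained in plain English by the OpenAI API.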

import os

# ====== Force all cache dirs to /tmp (writable in most environments) ======
CACHE_BASE = "/tmp/cache"
os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_BASE}/transformers"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{CACHE_BASE}/sentence_transformers"
os.environ["HF_DATASETS_CACHE"] = f"{CACHE_BASE}/hf_datasets"
os.environ["TORCH_HOME"] = f"{CACHE_BASE}/torch"
os.environ["STREAMLIT_CACHE_DIR"] = f"{CACHE_BASE}/streamlit_cache"
os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"

# Create the Streamlit config directory before imports
os.makedirs(os.environ["STREAMLIT_CONFIG_DIR"], exist_ok=True)

# Create the cache directories before imports
for path in os.environ.values():
    if path.startswith(CACHE_BASE):
        os.makedirs(path, exist_ok=True)

# ====== Imports ======
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI
import comet_llm
from opik import track

# ========== 🔑 API Keys ==========
openai_api_key = os.getenv("OPENAI_API_KEY")
# Only export the Opik settings if they are present; os.environ rejects None values.
if os.getenv("OPIK_API_KEY"):
    os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
if os.getenv("OPIK_WORKSPACE"):
    os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
# ========== 📥 Load Models ==========
@st.cache_resource(show_spinner=False)
def load_models():
    _clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _text_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
    )
    return _clip_model, _clip_processor, _text_model

clip_model, clip_processor, text_model = load_models()

# ========== 📥 Load Dataset ==========
@st.cache_resource(show_spinner=False)
def load_medical_data():
    available_splits = get_dataset_split_names("univanxx/3mdbench")
    split_to_use = "train" if "train" in available_splits else available_splits[0]
    dataset = load_dataset(
        "univanxx/3mdbench",
        split=split_to_use,
        cache_dir=os.environ["HF_DATASETS_CACHE"]
    )
    return dataset

# Cache dataset image embeddings (takes time, so cached)
@st.cache_data(show_spinner=True)
def embed_dataset_images(_dataset):
    features = []
    for item in _dataset:
        # Load image from URL/path or raw bytes - adapt this if needed
        img = item["image"]
        inputs_img = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs_img)
        feat /= feat.norm(p=2, dim=-1, keepdim=True)  # L2-normalize so dot products below equal cosine similarity
        features.append(feat.cpu())
    return torch.cat(features, dim=0)

data = load_medical_data()
dataset_image_features = embed_dataset_images(data)
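# The dataset and its image embeddings are computed once and reused across Streamlit
# reruns via the cache decorators above.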

client = OpenAI(api_key=openai_api_key)
# Debug helper: uncomment to inspect the dataset schema.
# st.write("Dataset columns:", data.features.keys())

# Fallback text column; the UI below selects TEXT_COLUMN explicitly instead.
text_field = "text" if "text" in data.features else list(data.features.keys())[0]


@st.cache_data(show_spinner=False)
def prepare_combined_texts(_dataset):
    combined = []
    for gc, c in zip(_dataset["general_complaint"], _dataset["complaints"]):
        gc_str = gc if gc else ""
        c_str = c if c else ""
        combined.append(f"General complaint: {gc_str}. Additional details: {c_str}")
    return combined

combined_texts = prepare_combined_texts(data)

# Then use dynamic access:
#text_embeddings = embed_texts(data[text_field])

# ========== 🧠 Embedding Function ==========
@st.cache_data(show_spinner=False)
def embed_dataset_texts(_texts):
    return text_model.encode(_texts, convert_to_tensor=True)

def embed_query_text(_query):
    return text_model.encode([_query], convert_to_tensor=True)[0]

# Opik's @track decorator logs each decorated call as a trace in the configured workspace.
@track
def get_chat_completion_openai(_client, _prompt: str):
    return _client.chat.completions.create(
        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
        messages=[{"role": "user", "content": _prompt}],
        temperature=0.5,
        max_tokens=425
    )

@track
def get_similar_prompt(_query):
    text_embeddings = embed_dataset_texts(combined_texts)  # cached
    query_embedding = embed_query_text(_query)  # recalculated each time

    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
    top_result = torch.topk(cos_scores, k=1)
    _idx = top_result.indices[0].item()
    return data[_idx]


# Pick which text column to use
TEXT_COLUMN = "complaints"  # or "general_complaint", depending on your needs

# ========== 🧑‍⚕️ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")

query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)

# Add author info in the sidebar
with st.sidebar:
    st.markdown("## 👤👤Authors")
    st.markdown("**Vasan Iyer**")
    st.markdown("**Eric J Giacomucci**")
    st.markdown("[GitHub](https://github.com/Vaiy108)")
    st.markdown("[LinkedIn](https://linkedin.com/in/vasan-iyer)")

if st.button("Submit") and query:
    with st.spinner("Searching medical cases..."):


        # Compute similarity
        selected = get_similar_prompt(query)

        # Show Image
        st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)

        # Show Text
        st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")

        # GPT Explanation
        if openai_api_key:
            prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"

            explanation = get_chat_completion_openai(client, prompt)
            explanation = explanation.choices[0].message.content

            st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
        else:
            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")



# With accept_multiple_files=True, file_uploader returns a list (empty when nothing is uploaded)
if uploaded_files:
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")

        if len(uploaded_files) > 0:
            # Only the first uploaded image is used for the similarity search
            uploaded_file = uploaded_files[0]
            st.write(f"Processing file: {uploaded_file.name}")
            query_image = Image.open(uploaded_file).convert("RGB")
            st.image(query_image, caption="Your uploaded image", use_container_width=True)

            # Embed uploaded image
            inputs = clip_processor(images=query_image, return_tensors="pt")
            with torch.no_grad():
                query_feat = clip_model.get_image_features(**inputs)
            query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)

            # Compute cosine similarity
            similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]

            top_k = 3
            top_results = torch.topk(similarities, k=top_k)

            st.write(f"Top {top_k} similar medical cases:")

            for rank, idx in enumerate(top_results.indices):
                score = top_results.values[rank].item()
                similar_img = data[int(idx)]['image']
                st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
                st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")

st.caption("This chatbot is for educational purposes only and does not provide medical advice.")