# ================================
#   βœ… Cache-Safe Multimodal App
# ================================

import os

# ====== Force all cache dirs to /tmp (writable in most environments) ======
CACHE_BASE = "/tmp/cache"
os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_BASE}/transformers"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{CACHE_BASE}/sentence_transformers"
os.environ["HF_DATASETS_CACHE"] = f"{CACHE_BASE}/hf_datasets"
os.environ["TORCH_HOME"] = f"{CACHE_BASE}/torch"
os.environ["STREAMLIT_CACHE_DIR"] = f"{CACHE_BASE}/streamlit_cache"
os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"

# Create the Streamlit config directory before importing streamlit
os.makedirs(os.environ["STREAMLIT_CONFIG_DIR"], exist_ok=True)

# Create the cache directories before the library imports
for path in os.environ.values():
    if path.startswith(CACHE_BASE):
        os.makedirs(path, exist_ok=True)

# ====== Imports ======
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI
import comet_llm
from opik import track

# ========== πŸ”‘ API Keys ==========
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Only set the Opik variables when present; assigning None to os.environ raises a TypeError
if os.getenv("OPIK_API_KEY"):
    os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
if os.getenv("OPIK_WORKSPACE"):
    os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
# ========== πŸ“₯ Load Models ==========
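# st.cache_resource keeps a single copy of the models in memory, so they are not
# reloaded on every Streamlit rerun.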
@st.cache_resource(show_spinner=False)
def load_models():
    _clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _text_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
    )
    return _clip_model, _clip_processor, _text_model

clip_model, clip_processor, text_model = load_models()

# ========== πŸ“₯ Load Dataset ==========
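# The dataset is downloaded once into HF_DATASETS_CACHE; the split check below
# falls back to the first available split if "train" is not present.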
@st.cache_resource(show_spinner=False)
def load_medical_data():
    available_splits = get_dataset_split_names("univanxx/3mdbench")
    split_to_use = "train" if "train" in available_splits else available_splits[0]
    dataset = load_dataset(
        "univanxx/3mdbench",
        split=split_to_use,
        cache_dir=os.environ["HF_DATASETS_CACHE"]
    )
    return dataset

# Cache dataset image embeddings (takes time, so cached)
@st.cache_data(show_spinner=True)
def embed_dataset_images(_dataset):
    features = []
    for item in _dataset:
        # item["image"] is expected to already be a PIL image (adapt here if your dataset stores paths or raw bytes)
        img = item["image"]
        inputs_img = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs_img)
        feat /= feat.norm(p=2, dim=-1, keepdim=True)
        features.append(feat.cpu())
    return torch.cat(features, dim=0)

data = load_medical_data()
dataset_image_features = embed_dataset_images(data)

client = OpenAI(api_key=OPENAI_API_KEY)
# Temporary debug display
#st.write("Dataset columns:", data.features.keys())

# Pick a text field dynamically: use "text" if present, otherwise fall back to the first column
text_field = "text" if "text" in data.features else list(data.features.keys())[0]


@st.cache_data(show_spinner=False)
def prepare_combined_texts(_dataset):
    combined = []
    for gc, c in zip(_dataset["general_complaint"], _dataset["complaints"]):
        gc_str = gc if gc else ""
        c_str = c if c else ""
        combined.append(f"General complaint: {gc_str}. Additional details: {c_str}")
    return combined

combined_texts = prepare_combined_texts(data)
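# Each combined text pairs the short general complaint with the longer free-text
# complaints, so both fields contribute to the embedding used for retrieval.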

# Then use dynamic access:
#text_embeddings = embed_texts(data[text_field])

# ========== 🧠 Embedding Function ==========
@st.cache_data(show_spinner=False)
def embed_dataset_texts(_texts):
    return text_model.encode(_texts, convert_to_tensor=True)

def embed_query_text(_query):
    return text_model.encode([_query], convert_to_tensor=True)[0]
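
# Note: arguments prefixed with "_" are excluded from Streamlit's cache key, so the
# fixed corpus is embedded once per session while each query is embedded on demand.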

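# The Opik @track decorator records each decorated call's inputs and outputs as a
# trace, so prompts and retrieved cases can be inspected in the Opik dashboard.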
@track
def get_chat_completion_openai(_client, _prompt: str):
    return _client.chat.completions.create(
        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
        messages=[{"role": "user", "content": _prompt}],
        temperature=0.5,
        max_tokens=425
    )

@track
def get_similar_prompt(_query):
    text_embeddings = embed_dataset_texts(combined_texts)  # cached
    query_embedding = embed_query_text(_query)  # recalculated each time

    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
    top_result = torch.topk(cos_scores, k=1)
    _idx = top_result.indices[0].item()
    return data[_idx]


# Pick which text column to use
TEXT_COLUMN = "complaints"  # or "general_complaint", depending on your needs

# ========== πŸ§‘β€βš•οΈ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")

query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)

# Add author info in the sidebar
with st.sidebar:
    st.markdown("## πŸ‘€πŸ‘€Authors")
    st.markdown("**Vasan Iyer**")
    st.markdown("**Eric J Giacomucci**")
    st.markdown("[GitHub](https://github.com/Vaiy108)")
    st.markdown("[LinkedIn](https://linkedin.com/in/vasan-iyer)")

if st.button("Submit") and query:
    with st.spinner("Searching medical cases..."):


        # Compute similarity
        selected = get_similar_prompt(query)

        # Show Image
        st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)

        # Show Text
        st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")

        # GPT Explanation
        if OPENAI_API_KEY:
            prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"

            explanation = get_chat_completion_openai(client, prompt)
            explanation = explanation.choices[0].message.content

            st.markdown(f"### πŸ€– Explanation by GPT:\n{explanation}")
        else:
            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")



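# ========== πŸ–ΌοΈ Image Similarity Search ==========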
# With accept_multiple_files=True, st.file_uploader returns a list (empty when nothing is uploaded)
if uploaded_files:
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")

        if len(uploaded_files) > 0:
            print(uploaded_files)
            uploaded_file = uploaded_files[0]
            st.write(f'uploading file {uploaded_file.name}')
            query_image = Image.open(uploaded_file).convert("RGB")
            st.image(query_image, caption="Your uploaded image", use_container_width=True)

            # Embed uploaded image
            inputs = clip_processor(images=query_image, return_tensors="pt")
            with torch.no_grad():
                query_feat = clip_model.get_image_features(**inputs)
            query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
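            # Normalizing the query feature to unit length (matching the dataset
            # features) makes the dot product below equal to cosine similarity.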

            # Compute cosine similarity
            similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]

            top_k = 3
            top_results = torch.topk(similarities, k=top_k)

            st.write(f"Top {top_k} similar medical cases:")

            for rank, idx in enumerate(top_results.indices):
                score = top_results.values[rank].item()
                similar_img = data[int(idx)]['image']
                st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
                st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")

st.caption("This chatbot is for educational purposes only and does not provide medical advice.")