# ================================
# ✅ Cache-Safe Multimodal App
# ================================
import os
# ====== Force all cache dirs to /tmp (writable in most environments) ======
CACHE_BASE = "/tmp/cache"
os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_BASE}/transformers"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{CACHE_BASE}/sentence_transformers"
os.environ["HF_DATASETS_CACHE"] = f"{CACHE_BASE}/hf_datasets"
os.environ["TORCH_HOME"] = f"{CACHE_BASE}/torch"
os.environ["STREAMLIT_CACHE_DIR"] = f"{CACHE_BASE}/streamlit_cache"
os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
# Create the config directory before streamlit is imported
os.makedirs(os.environ["STREAMLIT_CONFIG_DIR"], exist_ok=True)
# Create the cache directories before the heavyweight imports below
for path in os.environ.values():
    if path.startswith(CACHE_BASE):
        os.makedirs(path, exist_ok=True)
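
# Optional sanity check (a minimal sketch, not required for the app to run):
# probe that the cache base is writable before the imports below start
# downloading models into it.
def _assert_writable(path: str) -> None:
    probe = os.path.join(path, ".write_test")
    with open(probe, "w") as fh:  # raises OSError early on a read-only filesystem
        fh.write("ok")
    os.remove(probe)

# Uncomment to fail fast when /tmp is not writable:
# _assert_writable(CACHE_BASE)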
# ====== Imports ======
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI
from opik import track
# ========== 🔑 API Keys ==========
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # None if the secret is not set
# Guard the Opik variables: assigning None into os.environ raises TypeError
# when the secrets are missing.
if os.getenv("OPIK_API_KEY"):
    os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
if os.getenv("OPIK_WORKSPACE"):
    os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
# ========== 📥 Load Models ==========
@st.cache_resource(show_spinner=False)
def load_models():
    _clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _text_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
    )
    return _clip_model, _clip_processor, _text_model
clip_model, clip_processor, text_model = load_models()
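
# Optional (sketch, assumes a CUDA runtime is available): CLIP inference is far
# faster on GPU. If you enable this, every tensor produced by clip_processor(...)
# below must also be moved to the same device before calling the model.
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# clip_model = clip_model.to(DEVICE)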
# ========== 📥 Load Dataset ==========
@st.cache_resource(show_spinner=False)
def load_medical_data():
    available_splits = get_dataset_split_names("univanxx/3mdbench")
    split_to_use = "train" if "train" in available_splits else available_splits[0]
    dataset = load_dataset(
        "univanxx/3mdbench",
        split=split_to_use,
        cache_dir=os.environ["HF_DATASETS_CACHE"]
    )
    return dataset
# Cache dataset image embeddings (takes time, so cached)
@st.cache_data(show_spinner=True)
def embed_dataset_images(_dataset):
    features = []
    for item in _dataset:
        # Each row stores a PIL image; adapt this if your copy of the dataset
        # holds URLs/paths or raw bytes instead.
        img = item["image"]
        inputs_img = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs_img)
        feat /= feat.norm(p=2, dim=-1, keepdim=True)  # L2-normalize for cosine similarity
        features.append(feat.cpu())
    return torch.cat(features, dim=0)
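
# A batched variant (sketch, not wired in): CLIPProcessor accepts a list of
# images, so embedding in batches is much faster than the per-item loop above.
# batch_size is an illustrative choice, not a tuned value.
@st.cache_data(show_spinner=True)
def embed_dataset_images_batched(_dataset, batch_size: int = 32):
    images = [item["image"] for item in _dataset]
    feats = []
    for i in range(0, len(images), batch_size):
        inputs = clip_processor(images=images[i:i + batch_size], return_tensors="pt")
        with torch.no_grad():
            f = clip_model.get_image_features(**inputs)
        feats.append((f / f.norm(p=2, dim=-1, keepdim=True)).cpu())
    return torch.cat(feats, dim=0)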
data = load_medical_data()
dataset_image_features = embed_dataset_images(data)
client = OpenAI(api_key=OPENAI_API_KEY)
# Inspect the dataset schema when debugging column names:
# st.write("Dataset columns:", data.features.keys())
# Fall back to the first column if the expected "text" field is absent.
text_field = "text" if "text" in data.features else list(data.features.keys())[0]
@st.cache_data(show_spinner=False)
def prepare_combined_texts(_dataset):
    combined = []
    for gc, c in zip(_dataset["general_complaint"], _dataset["complaints"]):
        gc_str = gc if gc else ""
        c_str = c if c else ""
        combined.append(f"General complaint: {gc_str}. Additional details: {c_str}")
    return combined
combined_texts = prepare_combined_texts(data)
# Dynamic alternative using the detected column:
# text_embeddings = embed_dataset_texts(data[text_field])
# ========== 🧠 Embedding Function ==========
@st.cache_data(show_spinner=False)
def embed_dataset_texts(_texts):
    return text_model.encode(_texts, convert_to_tensor=True)

def embed_query_text(_query):
    return text_model.encode([_query], convert_to_tensor=True)[0]
@track
def get_chat_completion_openai(_client, _prompt: str):
    return _client.chat.completions.create(
        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
        messages=[{"role": "user", "content": _prompt}],
        temperature=0.5,
        max_tokens=425
    )
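
# Hedged sketch: a light retry wrapper for transient API failures. The
# exception classes exist in openai>=1.0; the backoff schedule is illustrative.
def get_chat_completion_with_retry(_client, _prompt: str, retries: int = 2):
    import time
    from openai import APIConnectionError, RateLimitError
    for attempt in range(retries + 1):
        try:
            return get_chat_completion_openai(_client, _prompt)
        except (APIConnectionError, RateLimitError):
            if attempt == retries:
                raise
            time.sleep(2 ** attempt)  # simple exponential backoff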
@track
def get_similar_prompt(_query):
    text_embeddings = embed_dataset_texts(combined_texts)  # cached
    query_embedding = embed_query_text(_query)  # recalculated each time
    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
    top_result = torch.topk(cos_scores, k=1)
    _idx = top_result.indices[0].item()
    return data[_idx]
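
# Sketch of a top-k text variant (an assumption, mirroring the top-k image
# search below, in case several candidate cases are wanted); k is illustrative.
@track
def get_similar_prompts_topk(_query, k: int = 3):
    text_embeddings = embed_dataset_texts(combined_texts)
    query_embedding = embed_query_text(_query)
    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
    top = torch.topk(cos_scores, k=k)
    return [data[i.item()] for i in top.indices]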
# Pick which text column to use
TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
# ========== 🧑‍⚕️ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")
query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
# Add author info in the sidebar
with st.sidebar:
st.markdown("## 👤👤Authors")
st.markdown("**Vasan Iyer**")
st.markdown("**Eric J Giacomucci**")
st.markdown("[GitHub](https://github.com/Vaiy108)")
st.markdown("[LinkedIn](https://linkedin.com/in/vasan-iyer)")
if st.button("Submit") and query:
    with st.spinner("Searching medical cases..."):
        # Compute similarity
        selected = get_similar_prompt(query)
        # Show image
        st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)
        # Show text
        st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")
        # GPT explanation
        if OPENAI_API_KEY:
            prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"
            explanation = get_chat_completion_openai(client, prompt)
            explanation = explanation.choices[0].message.content
            st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
        else:
            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
if uploaded_files:  # st.file_uploader returns a (possibly empty) list here, never None
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")
        uploaded_file = uploaded_files[0]  # only the first upload is used
        st.write(f"Uploading file {uploaded_file.name}")
        query_image = Image.open(uploaded_file).convert("RGB")
        st.image(query_image, caption="Your uploaded image", use_container_width=True)
        # Embed the uploaded image
        inputs = clip_processor(images=query_image, return_tensors="pt")
        with torch.no_grad():
            query_feat = clip_model.get_image_features(**inputs)
        query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
        # Cosine similarity reduces to a dot product because both sides are L2-normalized
        similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]
        top_k = 3
        top_results = torch.topk(similarities, k=top_k)
        st.write(f"Top {top_k} similar medical cases:")
        for rank, idx in enumerate(top_results.indices):
            score = top_results.values[rank].item()
            similar_img = data[int(idx)]['image']
            st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
            st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
st.caption("This chatbot is for educational purposes only and does not provide medical advice.")