# ================================
# Cache-Safe Multimodal App
# ================================
import os
# ====== Force all cache dirs to /tmp (writable in most environments) ======
CACHE_BASE = "/tmp/cache"
os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_BASE}/transformers"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{CACHE_BASE}/sentence_transformers"
os.environ["HF_DATASETS_CACHE"] = f"{CACHE_BASE}/hf_datasets"
os.environ["TORCH_HOME"] = f"{CACHE_BASE}/torch"
os.environ["STREAMLIT_CACHE_DIR"] = f"{CACHE_BASE}/streamlit_cache"
os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
# Create the directories before any library imports read them
os.makedirs(os.environ["STREAMLIT_CONFIG_DIR"], exist_ok=True)
for path in os.environ.values():
    if path.startswith(CACHE_BASE):
        os.makedirs(path, exist_ok=True)
# ====== Imports ======
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI
import comet_llm
from opik import track
# ========== 🔑 API Keys ==========
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY", "")
os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE", "")
# ========== 📥 Load Models ==========
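# st.cache_resource keeps the loaded models in memory across Streamlit reruns,
# so the weights are downloaded and initialized only once per server process.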
@st.cache_resource(show_spinner=False)
def load_models():
    _clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _text_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
    )
    return _clip_model, _clip_processor, _text_model
clip_model, clip_processor, text_model = load_models()
# ========== 📥 Load Dataset ==========
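# 3mdbench cases pair an image with complaint text ("general_complaint", "complaints");
# both modalities are used for retrieval further below.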
@st.cache_resource(show_spinner=False)
def load_medical_data():
    available_splits = get_dataset_split_names("univanxx/3mdbench")
    split_to_use = "train" if "train" in available_splits else available_splits[0]
    dataset = load_dataset(
        "univanxx/3mdbench",
        split=split_to_use,
        cache_dir=os.environ["HF_DATASETS_CACHE"]
    )
    return dataset
# Cache dataset image embeddings (takes time, so cached)
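# Each image feature is L2-normalized, so the matrix product with a normalized query
# vector in the upload branch below is exactly the cosine similarity.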
@st.cache_data(show_spinner=True)
def embed_dataset_images(_dataset):
    features = []
    for item in _dataset:
        # Load image from URL/path or raw bytes - adapt this if needed
        img = item["image"]
        inputs_img = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs_img)
            feat /= feat.norm(p=2, dim=-1, keepdim=True)
        features.append(feat.cpu())
    return torch.cat(features, dim=0)
data = load_medical_data()
dataset_image_features = embed_dataset_images(data)
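# dataset_image_features holds one normalized CLIP image vector per case (512-dim for ViT-B/32)
# and is reused for every image search below.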
client = OpenAI(api_key=OPENAI_API_KEY)
# Temporary debug display
#st.write("Dataset columns:", data.features.keys())
# Fallback text column: prefer "text" if it exists, otherwise use the first column
# (kept for debugging; see the commented-out embed_texts call below)
text_field = "text" if "text" in data.features else list(data.features.keys())[0]
@st.cache_data(show_spinner=False)
def prepare_combined_texts(_dataset):
    combined = []
    for gc, c in zip(_dataset["general_complaint"], _dataset["complaints"]):
        gc_str = gc if gc else ""
        c_str = c if c else ""
        combined.append(f"General complaint: {gc_str}. Additional details: {c_str}")
    return combined
combined_texts = prepare_combined_texts(data)
# Then use dynamic access:
#text_embeddings = embed_texts(data[text_field])
# ========== 🧠 Embedding Functions ==========
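# Parameters prefixed with an underscore are excluded from Streamlit's cache hashing,
# which avoids re-hashing large or unhashable arguments on every rerun.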
@st.cache_data(show_spinner=False)
def embed_dataset_texts(_texts):
    return text_model.encode(_texts, convert_to_tensor=True)
def embed_query_text(_query):
    return text_model.encode([_query], convert_to_tensor=True)[0]
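# @track logs each call (inputs, outputs, duration) as a trace to Opik for LLM observability.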
@track
def get_chat_completion_openai(_client, _prompt: str):
    return _client.chat.completions.create(
        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
        messages=[{"role": "user", "content": _prompt}],
        temperature=0.5,
        max_tokens=425
    )
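# Retrieval: dataset texts are embedded once (cached); the user's query is embedded per
# request and matched against them by cosine similarity to pick the closest case.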
@track
def get_similar_prompt(_query):
    text_embeddings = embed_dataset_texts(combined_texts)  # cached
    query_embedding = embed_query_text(_query)  # recalculated each time
    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
    top_result = torch.topk(cos_scores, k=1)
    _idx = top_result.indices[0].item()
    return data[_idx]
# Pick which text column to use
TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
# ========== 🧑‍⚕️ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")
query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
# Add author info in the sidebar
with st.sidebar:
    st.markdown("## 👤👤 Authors")
    st.markdown("**Vasan Iyer**")
    st.markdown("**Eric J Giacomucci**")
    st.markdown("[GitHub](https://github.com/Vaiy108)")
    st.markdown("[LinkedIn](https://linkedin.com/in/vasan-iyer)")
if st.button("Submit") and query:
    with st.spinner("Searching medical cases..."):
        # Compute similarity
        selected = get_similar_prompt(query)
        # Show Image
        st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)
        # Show Text
        st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}")
        # GPT Explanation
        if OPENAI_API_KEY:
            prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}"
            explanation = get_chat_completion_openai(client, prompt)
            explanation = explanation.choices[0].message.content
            st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
        else:
            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
if uploaded_files:
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")
        # Only the first uploaded image is used for the search
        uploaded_file = uploaded_files[0]
        st.write(f"Uploading file {uploaded_file.name}")
        query_image = Image.open(uploaded_file).convert("RGB")
        st.image(query_image, caption="Your uploaded image", use_container_width=True)
        # Embed uploaded image
        inputs = clip_processor(images=query_image, return_tensors="pt")
        with torch.no_grad():
            query_feat = clip_model.get_image_features(**inputs)
        query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
        # Compute cosine similarity against the pre-computed, normalized dataset embeddings
        similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]
        top_k = 3
        top_results = torch.topk(similarities, k=top_k)
        st.write(f"Top {top_k} similar medical cases:")
        for rank, idx in enumerate(top_results.indices):
            score = top_results.values[rank].item()
            similar_img = data[int(idx)]['image']
            st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
            st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
st.caption("This chatbot is for educational purposes only and does not provide medical advice.")