# Spaces: Running
# ================================
# ✅ Cache-Safe Multimodal App
# ================================
import os
import shutil  # NOTE(review): appears unused in this file; kept in case other code relies on it

# ====== Force all cache dirs to /tmp (writable in most environments) ======
# These must be set before the ML libraries below are imported, so that they
# pick up the redirected cache locations.
CACHE_BASE = "/tmp/cache"

# Every cache location this app redirects, keyed by environment variable.
_CACHE_DIRS = {
    "HF_HOME": f"{CACHE_BASE}/hf_home",
    "TRANSFORMERS_CACHE": f"{CACHE_BASE}/transformers",
    "SENTENCE_TRANSFORMERS_HOME": f"{CACHE_BASE}/sentence_transformers",
    "HF_DATASETS_CACHE": f"{CACHE_BASE}/hf_datasets",
    "TORCH_HOME": f"{CACHE_BASE}/torch",
    "STREAMLIT_CACHE_DIR": f"{CACHE_BASE}/streamlit_cache",
    "STREAMLIT_STATIC_DIR": f"{CACHE_BASE}/streamlit_static",
}
os.environ.update(_CACHE_DIRS)
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"

# Create exactly the directories configured above, before any imports.
# (Previously this scanned *all* env vars by string prefix, which could also
# pick up unrelated variables or sibling paths such as "/tmp/cache2".)
for _dir in (*_CACHE_DIRS.values(), os.environ["STREAMLIT_CONFIG_DIR"]):
    os.makedirs(_dir, exist_ok=True)
# ====== Imports ======
# Keep these below the cache-directory setup above so the libraries see the
# redirected cache environment variables when they are imported.
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI, OpenAIError
import comet_llm
from opik import track
# ========== 🔑 API Key ==========
# The OpenAI key is stashed on the OpenAI class; later code tests
# `OpenAI.api_key` to decide whether GPT explanations are available.
OpenAI.api_key = os.getenv("OPENAI_API_KEY")

# Re-export the Opik credentials only when they are actually set.
# Assigning None into os.environ raises TypeError, which previously crashed
# the whole app whenever either variable was missing. When a value is set,
# re-assigning it is a no-op.
for _opik_var in ("OPIK_API_KEY", "OPIK_WORKSPACE"):
    _opik_value = os.getenv(_opik_var)
    if _opik_value is not None:
        os.environ[_opik_var] = _opik_value
# ========== 📥 Load Models ==========
@st.cache_resource
def load_models():
    """Load the CLIP image model/processor and the MiniLM sentence encoder.

    Streamlit re-executes this script on every widget interaction.
    ``st.cache_resource`` keeps one shared copy of the models instead of
    reloading them from disk on each rerun.

    Returns:
        tuple: ``(clip_model, clip_processor, text_model)``.
    """
    clip_name = "openai/clip-vit-base-patch32"
    transformers_cache = os.environ["TRANSFORMERS_CACHE"]

    _clip_model = CLIPModel.from_pretrained(clip_name, cache_dir=transformers_cache)
    _clip_processor = CLIPProcessor.from_pretrained(clip_name, cache_dir=transformers_cache)
    _text_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"],
    )
    return _clip_model, _clip_processor, _text_model


clip_model, clip_processor, text_model = load_models()
# ========== 📥 Load Dataset ==========
@st.cache_resource
def load_medical_data():
    """Load the ``univanxx/3mdbench`` dataset, preferring its "train" split.

    Falls back to the first split reported by the Hub when there is no
    "train" split. Cached with ``st.cache_resource`` so the split lookup and
    the load are not repeated on every Streamlit rerun.

    Returns:
        The loaded dataset split.

    Raises:
        ValueError: if the dataset reports no splits at all.
    """
    dataset_id = "univanxx/3mdbench"
    available_splits = get_dataset_split_names(dataset_id)
    if not available_splits:
        raise ValueError(f"Dataset {dataset_id!r} has no splits to load")
    split_to_use = "train" if "train" in available_splits else available_splits[0]
    return load_dataset(
        dataset_id,
        split=split_to_use,
        cache_dir=os.environ["HF_DATASETS_CACHE"],
    )
# Cache dataset image embeddings (takes time, so cached)
@st.cache_resource
def embed_dataset_images(_dataset):
    """Return L2-normalised CLIP image embeddings for every row of *_dataset*.

    Each row's ``"image"`` field is embedded separately and normalised to unit
    length, so a dot product with another normalised CLIP feature is a cosine
    similarity. The per-row features are moved to CPU and stacked along dim 0.

    Because the parameter name starts with an underscore, Streamlit does not
    hash it. The cache key is therefore constant and the first result is
    reused for the rest of the server process. Only pass the app's single
    dataset.

    Returns:
        torch.Tensor: one row per dataset item, on CPU.
    """
    features = []
    for item in _dataset:
        # Assumes "image" is something CLIPProcessor accepts (e.g. a PIL image
        # decoded by `datasets`). TODO confirm for this dataset.
        img = item["image"]
        inputs_img = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs_img)
        feat /= feat.norm(p=2, dim=-1, keepdim=True)
        features.append(feat.cpu())
    return torch.cat(features, dim=0)
data = load_medical_data()
dataset_image_features = embed_dataset_images(data)

# Only build the OpenAI client when a key is configured. OpenAI() raises when
# no key is available (neither the argument nor OPENAI_API_KEY). That used to
# abort the whole app at startup and made the "API key not found" warning in
# the UI unreachable. The UI checks `OpenAI.api_key` before using `client`.
client = OpenAI(api_key=OpenAI.api_key) if OpenAI.api_key else None

# Free-text column name: "text" if the dataset has one, otherwise its first
# column.
text_field = "text" if "text" in data.features else next(iter(data.features))
def prepare_combined_texts(_dataset):
    """Build one searchable string per case from its two complaint columns.

    Rows are paired positionally across the ``"general_complaint"`` and
    ``"complaints"`` columns. Falsy values (None, "") become empty strings so
    every row still yields a string of the same shape.

    Returns:
        list[str]: one ``"General complaint: ... Additional details: ..."``
        string per row.
    """
    return [
        f"General complaint: {general or ''}. Additional details: {details or ''}"
        for general, details in zip(_dataset["general_complaint"], _dataset["complaints"])
    ]
# Search corpus for text queries: one combined complaint string per dataset row,
# index-aligned with `data`.
combined_texts = prepare_combined_texts(data)
# ========== 🧠 Embedding Function ==========
@st.cache_resource
def embed_dataset_texts(_texts):
    """Encode the corpus of case texts with the sentence-transformer model.

    Without caching, the whole corpus was re-encoded on every search.
    Because the parameter name starts with an underscore, Streamlit does not
    hash it. The first result is therefore reused for every later call
    regardless of the argument, so only pass the app's single corpus
    (``combined_texts``).

    Returns:
        torch.Tensor: one embedding row per input text.
    """
    return text_model.encode(_texts, convert_to_tensor=True)


def embed_query_text(_query):
    """Encode a single query string and return its embedding vector."""
    return text_model.encode([_query], convert_to_tensor=True)[0]
def get_chat_completion_openai(_client, _prompt: str):
    """Send *_prompt* to GPT-4o as a single user message.

    Args:
        _client: an OpenAI client (anything exposing ``chat.completions.create``).
        _prompt: the user message text.

    Returns:
        The raw value returned by ``chat.completions.create``, i.e. the full
        completion response, not just the text.
    """
    request = {
        "model": "gpt-4o",  # or "gpt-4" if you need the older GPT-4
        "messages": [{"role": "user", "content": _prompt}],
        "temperature": 0.5,
        "max_tokens": 425,
    }
    completions = _client.chat.completions
    return completions.create(**request)
def get_similar_prompt(_query):
    """Return the dataset row whose combined complaint text best matches *_query*.

    Computes the cosine similarity between the query embedding and the
    embedding of every entry in ``combined_texts``. Returns the row of the
    global ``data`` dataset at the single highest-scoring index.
    """
    corpus_embeddings = embed_dataset_texts(combined_texts)
    scores = util.pytorch_cos_sim(embed_query_text(_query), corpus_embeddings)[0]
    best_index = torch.topk(scores, k=1).indices[0].item()
    return data[best_index]
# Dataset column shown as the case description.
TEXT_COLUMN = "complaints"  # or "general_complaint", depending on your needs

# ========== 🧑⚕️ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")
query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader(
    "Upload an image to find similar medical cases:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)

# Author credits, rendered one markdown line at a time in the sidebar.
_AUTHOR_LINES = (
    "## 👤👤Authors",
    "**Vasan Iyer**",
    "**Eric J Giacomucci**",
    "[GitHub](https://github.com/Vaiy108)",
    "[LinkedIn](https://linkedin.com/in/vasan-iyer)",
)
for _line in _AUTHOR_LINES:
    st.sidebar.markdown(_line)
# Text search: find the most similar case and optionally ask GPT to explain it.
submitted = st.button("Submit")
if submitted and not query:
    # Previously a click with an empty query silently did nothing.
    st.warning("Please enter a question or symptom description first.")
elif submitted:
    with st.spinner("Searching medical cases..."):
        # Compute similarity
        selected = get_similar_prompt(query)

    case_text = selected[TEXT_COLUMN]
    st.image(selected['image'], caption="Most relevant medical image", use_container_width=True)
    st.markdown(f"**Case Description:** {case_text}")

    # GPT explanation, only when an OpenAI key was configured at startup.
    if OpenAI.api_key:
        prompt = f"Explain this case in plain English: {case_text}"
        try:
            response = get_chat_completion_openai(client, prompt)
        except OpenAIError as err:
            # Network/auth/rate-limit failures should not abort the page after
            # the search results have already been shown.
            st.error(f"Could not get an explanation from OpenAI: {err}")
        else:
            explanation = response.choices[0].message.content
            if explanation:
                st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
            else:
                st.warning("OpenAI returned an empty explanation.")
    else:
        st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
# Image search. With accept_multiple_files=True, st.file_uploader returns a
# list that is empty when nothing has been uploaded; it is never None. The old
# `is not None` check was always true and showed "Number of files: 0" on every
# page load.
if uploaded_files:
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")
        # Only the first uploaded file is used as the query image.
        uploaded_file = uploaded_files[0]
        st.write(f'uploading file {uploaded_file.name}')
        try:
            query_image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # Pillow's UnidentifiedImageError and truncated-image errors are
            # both OSError subclasses; report them instead of crashing.
            st.error(f"Could not read {uploaded_file.name} as an image.")
        else:
            st.image(query_image, caption="Your uploaded image", use_container_width=True)

            # Embed the upload with the same CLIP model and L2 normalisation
            # used for the dataset features.
            inputs = clip_processor(images=query_image, return_tensors="pt")
            with torch.no_grad():
                query_feat = clip_model.get_image_features(**inputs)
            query_feat = query_feat / query_feat.norm(p=2, dim=-1, keepdim=True)
            # Dataset features were stored on CPU; keep both operands on the same device.
            query_feat = query_feat.cpu()

            # Dot product of unit vectors is cosine similarity.
            similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]

            # torch.topk raises if k exceeds the number of elements.
            top_k = min(3, similarities.numel())
            top_results = torch.topk(similarities, k=top_k)
            st.write(f"Top {top_k} similar medical cases:")
            for score, idx in zip(top_results.values.tolist(), top_results.indices.tolist()):
                case = data[idx]
                st.image(case['image'], caption=f"Similarity: {score:.3f}", use_container_width=True)
                st.markdown(f"**Case description:** {case[TEXT_COLUMN]}")

st.caption("This chatbot is for educational purposes only and does not provide medical advice.")