# Dr_Q_bot_multimodal / src / streamlit_app.py
# vi108's picture
# Update src/streamlit_app.py
# ec40339 verified
# ================================
# ✅ Cache-Safe Multimodal App
# ================================
import shutil, os
# ====== Force all cache dirs to /tmp (writable in most environments) ======
# This section runs before the third-party imports further down so that the
# variables are already exported and the directories already exist when those
# libraries are imported.
CACHE_BASE = "/tmp/cache"

# Environment variable -> sub-directory of CACHE_BASE it is pointed at.
# Listing the variables explicitly (instead of scanning every value in
# os.environ for a CACHE_BASE prefix) guarantees that only the directories
# configured here are created, never one named by an unrelated variable.
_CACHE_SUBDIRS = {
    "HF_HOME": "hf_home",
    "TRANSFORMERS_CACHE": "transformers",
    "SENTENCE_TRANSFORMERS_HOME": "sentence_transformers",
    "HF_DATASETS_CACHE": "hf_datasets",
    "TORCH_HOME": "torch",
    "STREAMLIT_CACHE_DIR": "streamlit_cache",
    "STREAMLIT_STATIC_DIR": "streamlit_static",
}
for _env_name, _subdir in _CACHE_SUBDIRS.items():
    _cache_path = f"{CACHE_BASE}/{_subdir}"
    os.environ[_env_name] = _cache_path
    os.makedirs(_cache_path, exist_ok=True)

# The Streamlit config directory lives outside CACHE_BASE, so it is handled
# on its own.
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
os.makedirs(os.environ["STREAMLIT_CONFIG_DIR"], exist_ok=True)
# ====== Imports ======
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI
import comet_llm
from opik import track
# ========== 🔑 API Key ==========
# The key is stored on the OpenAI class because later code in this file reads
# `OpenAI.api_key`; it is None when OPENAI_API_KEY is not set.
OpenAI.api_key = os.getenv("OPENAI_API_KEY")
# OPIK_API_KEY / OPIK_WORKSPACE are left exactly as the environment provides
# them. Re-assigning `os.environ[name] = os.getenv(name)` changes nothing when
# the variable is set and raises TypeError (os.environ only accepts str
# values) when it is missing, so a missing variable is only reported here.
for _opik_var in ("OPIK_API_KEY", "OPIK_WORKSPACE"):
    if not os.getenv(_opik_var):
        print(f"Warning: {_opik_var} is not set; Opik tracking may not be configured.")
# ========== 📥 Load Models ==========
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the CLIP model, its processor and the sentence-embedding model.

    The CLIP checkpoint files are cached under ``TRANSFORMERS_CACHE`` and the
    sentence-transformers model under ``SENTENCE_TRANSFORMERS_HOME`` (both
    read from the environment). The function is wrapped in
    ``st.cache_resource`` with the spinner disabled.

    Returns:
        tuple: ``(clip_model, clip_processor, text_model)`` in that order.
    """
    clip_checkpoint = "openai/clip-vit-base-patch32"
    transformers_cache = os.environ["TRANSFORMERS_CACHE"]

    vision_language_model = CLIPModel.from_pretrained(
        clip_checkpoint, cache_dir=transformers_cache
    )
    image_text_processor = CLIPProcessor.from_pretrained(
        clip_checkpoint, cache_dir=transformers_cache
    )
    sentence_encoder = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"],
    )
    return vision_language_model, image_text_processor, sentence_encoder


# Module-level handles used by the embedding helpers and the UI code below.
clip_model, clip_processor, text_model = load_models()
# ========== 📥 Load Dataset ==========
@st.cache_resource(show_spinner=False)
def load_medical_data():
    """Load one split of the ``univanxx/3mdbench`` dataset.

    The ``train`` split is used when the dataset reports one; otherwise the
    first split name returned by ``get_dataset_split_names`` is used. Files
    are cached under the directory named by ``HF_DATASETS_CACHE``.

    Returns:
        The object returned by ``load_dataset`` for the chosen split.
    """
    repo_id = "univanxx/3mdbench"
    split_names = get_dataset_split_names(repo_id)
    if "train" in split_names:
        chosen_split = "train"
    else:
        chosen_split = split_names[0]
    return load_dataset(
        repo_id,
        split=chosen_split,
        cache_dir=os.environ["HF_DATASETS_CACHE"],
    )
# Cache dataset image embeddings (takes time, so cached)
@st.cache_data(show_spinner=True)
def embed_dataset_images(_dataset):
    """Embed every dataset image with CLIP and stack the results.

    Each record's ``"image"`` entry is handed to the module-level
    ``clip_processor`` as-is, encoded with ``clip_model`` under
    ``torch.no_grad()``, divided by its L2 norm along the last dimension and
    moved to the CPU.

    Args:
        _dataset: Iterable of records exposing an ``"image"`` key.

    Returns:
        torch.Tensor: The per-image features concatenated along dim 0, in
        dataset order.
    """
    normalized_rows = []
    for record in _dataset:
        pixel_inputs = clip_processor(images=record["image"], return_tensors="pt")
        with torch.no_grad():
            embedding = clip_model.get_image_features(**pixel_inputs)
            # Unit-normalise in place so a later dot product is a cosine score.
            embedding /= embedding.norm(p=2, dim=-1, keepdim=True)
        normalized_rows.append(embedding.cpu())
    return torch.cat(normalized_rows, dim=0)
# Load the dataset and pre-compute its image embeddings once at start-up.
data = load_medical_data()
dataset_image_features = embed_dataset_images(data)
# Build the OpenAI client only when a key is available. Constructing it
# without any key raises, which would abort the script here and make the
# "OpenAI API key not found" warning in the query handler unreachable. The
# handler only uses `client` after checking `OpenAI.api_key`, so None is safe.
client = OpenAI(api_key=OpenAI.api_key) if OpenAI.api_key else None
# Fallback text column: "text" when the dataset has it, otherwise the first
# feature name. Uncomment the next line to inspect the real column names.
# st.write("Dataset columns:", data.features.keys())
text_field = "text" if "text" in data.features else list(data.features.keys())[0]
@st.cache_data(show_spinner=False)
def prepare_combined_texts(_dataset):
    """Build one searchable sentence per dataset row.

    Pairs the ``"general_complaint"`` and ``"complaints"`` columns row by
    row; a falsy value in either column is rendered as an empty string.

    Args:
        _dataset: Object whose ``"general_complaint"`` and ``"complaints"``
            entries are iterables of equal-position values.

    Returns:
        list[str]: One combined description per row, in column order.
    """
    general_column = _dataset["general_complaint"]
    detail_column = _dataset["complaints"]
    return [
        f"General complaint: {general or ''}. Additional details: {detail or ''}"
        for general, detail in zip(general_column, detail_column)
    ]


# Text corpus that the query embedding is compared against.
combined_texts = prepare_combined_texts(data)
# Then use dynamic access:
#text_embeddings = embed_texts(data[text_field])
# ========== 🧠 Embedding Function ==========
@st.cache_data(show_spinner=False)
def embed_dataset_texts(_texts):
    """Encode the given texts with the module-level ``text_model``.

    Args:
        _texts: The texts to encode, passed straight to ``text_model.encode``.

    Returns:
        The encoder output requested as a tensor (``convert_to_tensor=True``).
    """
    corpus_embeddings = text_model.encode(_texts, convert_to_tensor=True)
    return corpus_embeddings
def embed_query_text(_query):
    """Encode a single query with the module-level ``text_model``.

    The query is wrapped in a one-element list for the encoder and the first
    (only) row of the resulting tensor is returned.

    Args:
        _query: The query text to embed.

    Returns:
        The embedding of ``_query`` (row 0 of the encoder output).
    """
    single_item_batch = [_query]
    batch_embeddings = text_model.encode(single_item_batch, convert_to_tensor=True)
    return batch_embeddings[0]
@track
def get_chat_completion_openai(_client, _prompt: str):
    """Send ``_prompt`` as a single user message to the chat completions API.

    Args:
        _client: Object exposing ``chat.completions.create``.
        _prompt: Text sent as the content of the only (user) message.

    Returns:
        Whatever ``_client.chat.completions.create`` returns.
    """
    user_turn = {"role": "user", "content": _prompt}
    request_options = {
        "model": "gpt-4o",  # or "gpt-4" if you need the older GPT-4
        "messages": [user_turn],
        "temperature": 0.5,
        "max_tokens": 425,
    }
    return _client.chat.completions.create(**request_options)
@track
def get_similar_prompt(_query):
    """Return the dataset row whose combined text best matches ``_query``.

    The corpus embeddings come from ``embed_dataset_texts(combined_texts)``
    (a cached function), the query embedding is computed on every call, and
    the row with the highest cosine score is looked up in ``data``.

    Args:
        _query: The user's query text.

    Returns:
        ``data[i]`` for the top-scoring row index ``i``.
    """
    corpus_vectors = embed_dataset_texts(combined_texts)
    query_vector = embed_query_text(_query)
    similarity_row = util.pytorch_cos_sim(query_vector, corpus_vectors)[0]
    best_match = torch.topk(similarity_row, k=1)
    best_index = best_match.indices[0].item()
    return data[best_index]
# Dataset column shown as the case description ("general_complaint" is the
# other candidate, depending on your needs).
TEXT_COLUMN = "complaints"
# ========== 🧑‍⚕️ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")
query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader(
    "Upload an image to find similar medical cases:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)
# Author info shown in the sidebar, rendered top to bottom in this order.
_SIDEBAR_MARKDOWN = (
    "## 👤👤Authors",
    "**Vasan Iyer**",
    "**Eric J Giacomucci**",
    "[GitHub](https://github.com/Vaiy108)",
    "[LinkedIn](https://linkedin.com/in/vasan-iyer)",
)
with st.sidebar:
    for _sidebar_entry in _SIDEBAR_MARKDOWN:
        st.markdown(_sidebar_entry)
# Text search: runs only when "Submit" was clicked and a query was entered.
# The button is evaluated first so that it is always rendered.
submit_clicked = st.button("Submit")
if submit_clicked and query:
    with st.spinner("Searching medical cases..."):
        # Closest dataset row for the typed query.
        selected = get_similar_prompt(query)
        st.image(
            selected["image"],
            caption="Most relevant medical image",
            use_container_width=True,
        )
        case_text = selected[TEXT_COLUMN]
        st.markdown(f"**Case Description:** {case_text}")
        # Plain-English explanation, only when an OpenAI key is configured.
        if not OpenAI.api_key:
            st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
        else:
            prompt = f"Explain this case in plain English: {case_text}"
            completion = get_chat_completion_openai(client, prompt)
            explanation = completion.choices[0].message.content
            st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}")
# Image search: embed the first uploaded image with CLIP and show the dataset
# cases whose image embeddings are most similar to it.
if uploaded_files is not None:
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")
        if len(uploaded_files) > 0:
            # Only the first uploaded file is used as the query image.
            uploaded_file = uploaded_files[0]
            st.write(f'uploading file {uploaded_file.name}')
            query_image = Image.open(uploaded_file).convert("RGB")
            st.image(query_image, caption="Your uploaded image", use_container_width=True)
            # Embed uploaded image
            inputs = clip_processor(images=query_image, return_tensors="pt")
            with torch.no_grad():
                query_feat = clip_model.get_image_features(**inputs)
                query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
            # Both sides are L2-normalised, so the dot product is the cosine
            # similarity: one score per dataset image.
            similarities = (dataset_image_features @ query_feat.T).squeeze(1)
            # torch.topk raises when k exceeds the number of scores, so never
            # ask for more results than the dataset has images.
            top_k = min(3, similarities.shape[0])
            top_results = torch.topk(similarities, k=top_k)
            st.write(f"Top {top_k} similar medical cases:")
            hits = zip(top_results.values.tolist(), top_results.indices.tolist())
            for score, row_index in hits:
                # Fetch the row once and reuse it for both image and text.
                case = data[row_index]
                st.image(case["image"], caption=f"Similarity: {score:.3f}", use_container_width=True)
                st.markdown(f"**Case description:** {case[TEXT_COLUMN]}")
st.caption("This chatbot is for educational purposes only and does not provide medical advice.")