# xray-analysis / app.py
# Last commit by skjaini: 2851f3a "Update app.py" (verified)
import streamlit as st
from PIL import Image
from transformers import pipeline
import io
import torch # Import PyTorch
# --- Configuration ---
# Model identifier handed to transformers.pipeline() as ``model=``.
MODEL_NAME: str = "microsoft/maira-2"
# --- Model Loading (using pipeline) ---
@st.cache_resource  # Cache the pipeline for performance
def _create_pipeline():
    """Build the VQA pipeline; cached by ``st.cache_resource``.

    Exceptions are deliberately NOT caught here: if loading fails the error
    propagates out of the cached function instead of a ``None`` result being
    stored, so a later rerun can retry the load.
    """
    # Use the first CUDA device if one is available, otherwise the CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    # NOTE(review): microsoft/maira-2 may not be loadable through the generic
    # "visual-question-answering" task (it may need its own processor and
    # trust_remote_code) -- confirm against the model card.
    return pipeline("visual-question-answering", model=MODEL_NAME, device=device)


def load_pipeline():
    """Load the VQA pipeline, reporting failures in the UI.

    Returns:
        The (cached) pipeline object on success, or ``None`` after showing the
        error with ``st.error`` if loading raised.
    """
    try:
        return _create_pipeline()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading pipeline: {e}")
        return None
# --- Image Preprocessing (Keep as bytes) ---
# Modes Pillow's JPEG writer can encode directly; any other mode (RGBA, LA,
# palette "P", integer/float grayscale, ...) is converted to RGB first.
_JPEG_SAVABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def prepare_image(image):
    """Encode a PIL image as JPEG bytes.

    Images already in a JPEG-writable mode are encoded as-is; images in any
    other mode (e.g. RGBA, LA or palette "P") are converted to RGB first,
    because JPEG cannot store them and saving would otherwise raise.

    Args:
        image: A PIL ``Image`` object.

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    if image.mode not in _JPEG_SAVABLE_MODES:
        image = image.convert("RGB")
    with io.BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        return buffer.getvalue()
# --- Streamlit App ---
def _show_findings(vqa_pipeline, image):
    """Run *vqa_pipeline* on *image* and render the best-scoring answer.

    Any exception raised while running the pipeline or reading its output is
    shown in the UI with ``st.error`` instead of propagating.
    """
    try:
        # Pass the PIL image itself: the transformers VQA pipeline accepts a
        # PIL image, URL, local path or base64 string -- not raw JPEG bytes.
        results = vqa_pipeline(
            image=image,
            question="Analyze this chest X-ray image and provide detailed findings. Include any abnormalities, their locations, and potential diagnoses. Be as specific as possible.",
        )
        # Defensive: also accept a bare dict in case a single answer is
        # returned unwrapped.
        if isinstance(results, dict):
            results = [results]
        if not results:
            st.warning("The model returned no results.")
            return
        if not isinstance(results, list):
            st.warning("Unexpected result format.")
            return
        best_answer = max(results, key=lambda x: x.get('score', 0))
        if 'answer' in best_answer:
            st.subheader("Findings:")
            st.write(best_answer['answer'])
        else:
            st.warning("Could not find 'answer' in results.")
    except Exception as e:  # UI boundary: show any analysis failure to the user
        st.error(f"An error occurred during analysis: {e}")


def main():
    """Render the app: load the pipeline, accept an upload, show findings."""
    st.title("Chest X-ray Analysis with Maira-2 (Transformers Pipeline)")
    st.write("Upload a chest X-ray image. This app uses the Maira-2 model via the Transformers library.")
    vqa_pipeline = load_pipeline()
    if vqa_pipeline is None:
        st.warning("Pipeline not loaded. Predictions will not be available.")
        return
    uploaded_file = st.file_uploader("Choose a chest X-ray image (JPG, PNG)", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        st.write("Please upload an image.")
    else:
        try:
            image = Image.open(uploaded_file)
            image.load()  # decode now so a corrupt upload fails inside this try
        except OSError as e:  # PIL's UnidentifiedImageError is an OSError
            st.error(f"Could not read the uploaded file as an image: {e}")
        else:
            st.image(image, caption="Uploaded Image", use_column_width=True)
            with st.spinner("Analyzing image with Maira-2..."):
                _show_findings(vqa_pipeline, image)
    st.write("---")
    st.write("Disclaimer: For informational purposes only. Not medical advice.")


# Entry point: build the UI when this file is executed as the main script.
if __name__ == "__main__":
    main()