# Hugging Face Space (status at capture time: Sleeping)
import streamlit as st | |
from PIL import Image | |
from transformers import pipeline | |
import io | |
import torch # Import PyTorch | |
# --- Configuration ---
# Hugging Face Hub repository id of the model this app loads.
MODEL_NAME: str = "microsoft/maira-2"
# --- Model Loading (using pipeline) ---
@st.cache_resource
def _build_pipeline():
    """Construct the VQA pipeline, memoised across Streamlit reruns.

    Streamlit re-executes the whole script on every user interaction, so
    without ``st.cache_resource`` the model would be reloaded each time.
    If this function raises, nothing is cached and the next rerun retries.

    Returns:
        The transformers ``visual-question-answering`` pipeline for MODEL_NAME.

    Raises:
        Any exception raised by ``transformers.pipeline`` while loading.
    """
    # transformers device convention: 0 = first CUDA GPU, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1
    # NOTE(review): microsoft/maira-2 is distributed with custom modelling code
    # on the Hub; confirm the "visual-question-answering" task supports it (it
    # may instead need AutoModelForCausalLM/AutoProcessor with
    # trust_remote_code, which should be an explicit, reviewed decision).
    return pipeline("visual-question-answering", model=MODEL_NAME, device=device)


def load_pipeline():
    """Return the cached VQA pipeline, or None if it could not be loaded.

    Load errors are shown in the UI with ``st.error`` instead of being raised,
    so callers only need to check for None. Because the error handling sits
    outside the cached builder, a failed load is not cached and is retried on
    the next rerun.
    """
    try:
        return _build_pipeline()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading pipeline: {e}")
        return None
# --- Image Preprocessing (Keep as bytes) ---
def prepare_image(image):
    """Encode a PIL image as JPEG bytes.

    Pillow's JPEG encoder cannot write modes such as RGBA, LA, palette ("P")
    or 16-bit grayscale ("I;16"); saving those raises OSError. Any image whose
    mode is not already "RGB" or "L" is therefore converted to RGB first.

    Args:
        image: A ``PIL.Image.Image`` (any mode).

    Returns:
        bytes: The JPEG-encoded image data.
    """
    # Only RGB and 8-bit grayscale are passed through unchanged; every other
    # mode (not just RGBA) is normalised so the JPEG save below cannot fail.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    with io.BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        return buffer.getvalue()
# --- Streamlit App ---
# Fixed prompt sent to the model for every uploaded image.
_ANALYSIS_QUESTION = (
    "Analyze this chest X-ray image and provide detailed findings. "
    "Include any abnormalities, their locations, and potential diagnoses. "
    "Be as specific as possible."
)


def _show_findings(results):
    """Render the highest-scoring answer from a VQA pipeline result.

    Args:
        results: Pipeline output, expected to be a list of dicts each holding
            an ``"answer"`` and optionally a ``"score"``; a single dict is
            also accepted.
    """
    if isinstance(results, dict):
        results = [results]
    if not results:
        st.warning("The model returned no results.")
        return
    if not isinstance(results, list):
        st.warning("Unexpected result format.")
        return
    candidates = [r for r in results if isinstance(r, dict)]
    if not candidates:
        st.warning("Unexpected result format.")
        return
    # Generative models may omit (or null) "score"; treat that as 0.
    best_answer = max(candidates, key=lambda r: r.get("score") or 0)
    if "answer" in best_answer:
        st.subheader("Findings:")
        st.write(best_answer["answer"])
    else:
        st.warning("Could not find 'answer' in results.")


def _analyze_upload(vqa_pipeline, uploaded_file):
    """Show the uploaded image, run the pipeline on it and render the findings."""
    try:
        image = Image.open(uploaded_file)
        image.load()  # decode now so a corrupt upload fails here, not in st.image
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        return
    # NOTE(review): newer Streamlit versions deprecate use_column_width in
    # favour of use_container_width; switch once the minimum version allows.
    st.image(image, caption="Uploaded Image", use_column_width=True)
    with st.spinner("Analyzing image with Maira-2..."):
        try:
            # The transformers VQA pipeline accepts a PIL image, URL, local path
            # or base64 string -- not raw bytes -- so pass a PIL image directly.
            rgb_image = image if image.mode == "RGB" else image.convert("RGB")
            results = vqa_pipeline(image=rgb_image, question=_ANALYSIS_QUESTION)
        except Exception as e:  # UI boundary: report any inference failure
            st.error(f"An error occurred during analysis: {e}")
            return
    _show_findings(results)


def main():
    """Render the Maira-2 chest X-ray analysis page.

    Loads the VQA pipeline, lets the user upload a JPG/PNG image, asks the
    model a fixed analysis question and shows the highest-scoring answer.
    Returns early (without the disclaimer) if the pipeline failed to load.
    """
    st.title("Chest X-ray Analysis with Maira-2 (Transformers Pipeline)")
    st.write("Upload a chest X-ray image. This app uses the Maira-2 model via the Transformers library.")
    vqa_pipeline = load_pipeline()
    if vqa_pipeline is None:
        st.warning("Pipeline not loaded. Predictions will not be available.")
        return
    uploaded_file = st.file_uploader("Choose a chest X-ray image (JPG, PNG)", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        _analyze_upload(vqa_pipeline, uploaded_file)
    else:
        st.write("Please upload an image.")
    st.write("---")
    st.write("Disclaimer: For informational purposes only. Not medical advice.")


if __name__ == "__main__":
    main()