import streamlit as st
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq


@st.cache_resource
def load_model():
    # Cache the processor and model so they are loaded only once per session.
    model_name = "google/paligemma2-3b-mix-224"
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForVision2Seq.from_pretrained(model_name)
    return processor, model


processor, model = load_model()

st.title("🖼️ Image Q&A using PaliGemma")

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    question = st.text_input("Ask a question about the image:")
    if question:
        inputs = processor(text=question, images=image, return_tensors="pt")
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=50)

        # generate() returns the prompt tokens followed by the answer,
        # so slice off the prompt before decoding.
        input_len = inputs["input_ids"].shape[-1]
        answer = processor.batch_decode(output[:, input_len:], skip_special_tokens=True)[0]
        st.success(f"Answer: {answer}")