# myyim's picture
# Upload 3 files
# d03234f verified
### import packages
import torch
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import streamlit as st
from PIL import Image
import os
### read the Hugging Face access token from the HF_TOKEN secret / environment variable
token = os.getenv('HF_TOKEN')  # None when the variable is not set
### choose a paligemma model
# See https://huggingface.co/collections/google/paligemma-2-release-67500e1e1dbfdd4dee27ba48
model_id = "google/paligemma2-3b-pt-896"
@st.cache_resource
def model_setup(model_id):
    """
    Load a PaliGemma checkpoint and its processor once per server process.

    Decorated with @st.cache_resource so Streamlit reruns reuse the same
    objects instead of reloading the weights.

    Args:
        model_id: Hugging Face repo id of one of the paligemma models
    Return:
        model: PaliGemmaForConditionalGeneration in eval mode, bfloat16 weights,
            placed with device_map="auto"
        processor: the matching PaliGemmaProcessor
    """
    # token may be None; it is read from HF_TOKEN at module level
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=token,
    )
    model.eval()  # inference only: disable dropout etc.
    processor = PaliGemmaProcessor.from_pretrained(
        model_id,
        token=token,
    )
    return model, processor
def run_model(prompt, image):
    """
    Generate the model's answer for one prompt/image pair (greedy decoding).

    Relies on the module-level `model` and `processor` globals.

    Args:
        prompt: task prompt or free-form text typed by the user
        image: the user's uploaded image
    Returns:
        decoded text of the newly generated tokens only
    """
    batch = processor(text=prompt, images=image, return_tensors="pt")
    # match the bfloat16 weights, then move to wherever the model lives
    # (presumably BatchFeature.to only casts floating-point tensors, so
    # input_ids stay integer -- confirm against the transformers version in use)
    batch = batch.to(torch.bfloat16)
    batch = batch.to(model.device)
    prompt_len = batch["input_ids"].shape[-1]
    with torch.inference_mode():
        output_ids = model.generate(**batch, max_new_tokens=1000, do_sample=False)
        # generate() returns prompt + continuation; keep only the continuation
        answer_ids = output_ids[0][prompt_len:]
    return processor.decode(answer_ids, skip_special_tokens=True)
def initialize():
    """
    Reset the chat history kept in Streamlit's session state to an empty list.
    """
    st.session_state["messages"] = list()
### load model
model, processor = model_setup(model_id)

# Radio label that switches the app into free-form chat mode.
PROMPT_LABEL = 'Enter your prompt'

# Prompt actually sent to the model for each preset task. PaliGemma checkpoints
# are trained on lowercase task prefixes (e.g. "caption en", "ocr") rather than
# the capitalised labels shown in the UI, so the label is used for display only.
# NOTE(review): "segment" is normally followed by an object name
# ("segment <object>"); TODO let the user supply one.
TASK_PROMPTS = {
    'Caption': 'caption en',
    'OCR': 'ocr',
    'Segment': 'segment',
}


def _chat_turn(label, prompt, image):
    """Show/record *label* as the user turn, run the VLM on *prompt*, show/record the reply."""
    # display user message in chat message container and add it to history
    with st.chat_message("user"):
        st.markdown(label)
    st.session_state.messages.append({"role": "user", "content": label})
    # run the VLM
    response = run_model(prompt, image)
    # display assistant response in chat message container and add it to history
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})


# Make sure the chat history exists even before the uploader's on_change
# callback has ever fired (e.g. on the first run of a new session).
if "messages" not in st.session_state:
    initialize()

### upload a file
uploaded_file = st.file_uploader("Choose an image", on_change=initialize)
if uploaded_file:
    st.image(uploaded_file)
    image = Image.open(uploaded_file).convert("RGB")
    # tasks: Caption by default. Accept user prompt only when selected
    task = st.radio(
        "Task",
        (*TASK_PROMPTS, PROMPT_LABEL),
        horizontal=True)
    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if task == PROMPT_LABEL:
        # forget the last preset so re-selecting it later runs it again
        st.session_state.last_task = None
        if prompt := st.chat_input("Type here!", key="user_prompt"):
            _chat_turn(prompt, prompt, image)
    elif not st.session_state.messages or st.session_state.get("last_task") != task:
        # Streamlit reruns the whole script on every interaction. Only run a
        # preset task when it was just selected or the history was reset by a
        # new upload; otherwise an incidental rerun would re-run the model and
        # duplicate the exchange (the previous answer is already shown above).
        _chat_turn(task, TASK_PROMPTS[task], image)
        st.session_state.last_task = task