### import packages
import torch
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import streamlit as st
from PIL import Image
import os
### read the access token (stored as the HF_TOKEN secret) from the environment
token = os.environ.get('HF_TOKEN')
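# note: the PaliGemma checkpoints are gated on the Hugging Face Hub, so this token
# should belong to an account that has accepted the model license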
### choose a paligemma model
# See https://huggingface.co/collections/google/paligemma-2-release-67500e1e1dbfdd4dee27ba48
model_id = "google/paligemma2-3b-pt-896"
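# the collection also has pretrained checkpoints in 3b/10b/28b sizes at 224/448/896
# input resolutions; the 896-resolution 3b model used here is a sensible default for
# text-heavy images but needs correspondingly more memory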
@st.cache_resource
def model_setup(model_id):
    """
    Sets up the model; @st.cache_resource caches the result across reruns.

    Args:
        model_id: one of the paligemma models
    Returns:
        model: from PaliGemmaForConditionalGeneration.from_pretrained
        processor: from PaliGemmaProcessor.from_pretrained
    """
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=token,
    ).eval()
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
    return model, processor
def run_model(prompt, image):
    """
    Performs inference on the user's prompt and image.

    Uses the module-level `model` and `processor` created by model_setup below.
    Args:
        prompt: user prompt or task
        image: user's uploaded image
    Returns:
        output text
    """
    model_inputs = (
        processor(text=prompt, images=image, return_tensors="pt")
        .to(torch.bfloat16)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        # greedy decoding; strip the prompt tokens before decoding the output
        generation = model.generate(**model_inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]
        return processor.decode(generation, skip_special_tokens=True)
def initialize():
    """
    Initializes (or resets) the chat history whenever a new image is uploaded.
    """
    st.session_state.messages = []
### load model
model, processor = model_setup(model_id)
### upload a file
uploaded_file = st.file_uploader("Choose an image", on_change=initialize)
if uploaded_file:
    st.image(uploaded_file)
    image = Image.open(uploaded_file).convert("RGB")
    # tasks: Caption by default. Accept a user prompt only when selected
    task = st.radio(
        "Task",
        ('Caption', 'OCR', 'Segment', 'Enter your prompt'),
        horizontal=True,
    )
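    # 'Caption', 'OCR', and 'Segment' are sent to the model verbatim; PaliGemma's
    # documented task prefixes are lowercase (e.g. "caption en", "ocr",
    # "segment <object>"), so matching those strings exactly may improve results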
    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if task == 'Enter your prompt':
        if prompt := st.chat_input("Type here!", key="user_prompt"):
            # display user message in chat message container
            with st.chat_message("user"):
                st.markdown(prompt)
            # add user message to chat history
            st.session_state.messages.append({"role": "user", "content": prompt})
            # run the VLM
            response = run_model(prompt, image)
            # display assistant response in chat message container
            with st.chat_message("assistant"):
                st.markdown(response)
            # add assistant response to chat history
            st.session_state.messages.append({"role": "assistant", "content": response})
    else:
        # display the selected task in a chat message container
        with st.chat_message("user"):
            st.markdown(task)
        # add the selected task to chat history
        st.session_state.messages.append({"role": "user", "content": task})
        # run the VLM
        response = run_model(task, image)
        # display assistant response in chat message container
        with st.chat_message("assistant"):
            st.markdown(response)
        # add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})