|
import gradio as gr |
|
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration |
|
from PIL import Image |
|
import torch |
|
import os |
|
|
|
|
|
# HF checkpoint id: 3B-parameter PaliGemma, mixed-task fine-tune, 224px input resolution.
model_id = "google/paligemma-3b-mix-224"

# Access token read from the environment; presumably required because the
# checkpoint is gated on the Hub — TODO confirm the Space/host sets HF_TOKEN.
HF_TOKEN = os.getenv('HF_TOKEN')

# Load weights once at import time; .eval() switches off dropout for inference.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, token=HF_TOKEN).eval()

# Companion processor: tokenizes text and preprocesses images for this checkpoint.
processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
|
|
|
def generate_caption(image, prompt="What is in this image?", max_tokens=100):
    """Generate a text description of *image* with PaliGemma.

    Args:
        image: PIL image to analyze; ``None`` when nothing was uploaded.
        prompt: Instruction text appended after the ``<image>`` token.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded model output, or a hint string when no image is given.
    """
    if image is None:
        return "Please upload an image."

    gr.Info("Analysis starting. This may take up to 119 seconds.")

    # PaliGemma prompts must start with the special <image> placeholder token.
    full_prompt = "<image> " + prompt

    model_inputs = processor(text=full_prompt, images=image, return_tensors="pt")
    # Fix: move the input tensors to the model's device. Without this, the
    # generate() call fails whenever the model is loaded on GPU (inputs stay
    # on CPU while the weights live on CUDA). No-op when everything is on CPU.
    model_inputs = model_inputs.to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Greedy decoding (do_sample=False) for deterministic captions.
        generation = model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=False)
        # Slice off the echoed prompt tokens; keep only the new completion.
        generation = generation[0][input_len:]
        decoded = processor.decode(generation, skip_special_tokens=True)

    return decoded
|
|
|
|
|
def load_local_images():
    """Collect the bundled example images found next to this script.

    Missing files are skipped silently; files that fail to open are
    reported on stdout and skipped. Returns a list of PIL images.
    """
    candidates = ('image1.jpg', 'image2.jpg', 'image3.jpg')
    loaded = []
    for name in candidates:
        path = os.path.join('.', name)
        try:
            if os.path.exists(path):
                loaded.append(Image.open(path))
        except Exception as err:
            # Best-effort loading: a bad file should not break startup.
            print(f"Could not load {name}: {err}")
    return loaded
|
|
|
|
|
# Probe the working directory once at import time for bundled demo images.
EXAMPLE_IMAGES = load_local_images()
|
|
|
|
|
# Build the Gradio UI: image + prompt inputs on the left, caption output on
# the right, with an optional gallery of bundled example images.
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Image Analysis")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload or Select Image")
            custom_prompt = gr.Textbox(label="Custom Prompt", value="What is in this image?")
            submit_btn = gr.Button("Analyze Image")

        with gr.Column():
            output_text = gr.Textbox(label="Image Description")

    submit_btn.click(
        fn=generate_caption,
        inputs=[input_image, custom_prompt],
        outputs=output_text
    )

    # Robustness fix: only build the examples gallery when bundled images were
    # actually found. load_local_images() returns [] when none of the expected
    # files exist, and handing gr.Examples an empty examples list makes the app
    # fail at startup instead of simply omitting the gallery.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[img, "What is in this image?"] for img in EXAMPLE_IMAGES],
            inputs=[input_image, custom_prompt],
            fn=generate_caption,
            outputs=output_text
        )
|
|
|
|
|
# Script entry point: start the Gradio server only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    demo.launch()