File size: 2,716 Bytes
45f9b3d
44157d6
 
0dc6935
cf83b3d
ef13ec4
44157d6
 
cf83b3d
ef13ec4
 
3373ce1
cf83b3d
 
44157d6
 
 
e0a390e
 
 
 
 
 
44157d6
e0a390e
44157d6
 
 
 
 
 
 
 
 
fa73fe7
cf83b3d
 
 
e0a390e
cf83b3d
 
 
 
 
 
 
 
 
fa73fe7
44157d6
cf83b3d
3373ce1
44157d6
 
cf83b3d
44157d6
 
 
cf83b3d
 
 
44157d6
 
cf83b3d
44157d6
 
 
 
cf83b3d
44157d6
 
 
 
 
cf83b3d
 
44157d6
 
3373ce1
 
44157d6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import torch
import os

# Fetch the PaliGemma checkpoint and its matching processor from the Hugging
# Face Hub. The access token is read from the HF_TOKEN environment variable
# (None when unset, in which case no token is passed explicitly).
model_id = "google/paligemma-3b-mix-224"
HF_TOKEN = os.environ.get('HF_TOKEN')

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    token=HF_TOKEN,
)
model = model.eval()  # inference only: put the model in evaluation mode

processor = AutoProcessor.from_pretrained(
    model_id,
    token=HF_TOKEN,
)

def generate_caption(image, prompt="What is in this image?", max_tokens=100):
    """Generate a text answer from PaliGemma for an image and a prompt.

    Args:
        image: The image to analyse (the UI passes a PIL image because its
            ``gr.Image`` uses ``type="pil"``), or ``None`` if nothing was
            uploaded.
        prompt: Question/instruction for the model. ``None`` or a blank
            string falls back to the default question.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded model output with the prompt tokens stripped off, or a
        user-facing message when no image was supplied.
    """
    if image is None:
        return "Please upload an image."

    # An emptied prompt box would otherwise send only the image token,
    # leaving the model with no instruction at all.
    if prompt is None or not prompt.strip():
        prompt = "What is in this image?"

    # Show a toast in the Gradio UI while the (slow) generation runs.
    gr.Info("Analysis starting. This may take up to 119 seconds.")

    # Prefix the prompt with the image token.
    full_prompt = "<image> " + prompt

    # Preprocess inputs and place them on the model's device, so this keeps
    # working if the model is ever moved off the CPU.
    model_inputs = processor(text=full_prompt, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(model.device)
    # Number of prompt tokens; generate() echoes them before the new tokens.
    input_len = model_inputs["input_ids"].shape[-1]

    # Greedy decoding (do_sample=False) for deterministic answers.
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=False)
        generation = generation[0][input_len:]
        decoded = processor.decode(generation, skip_special_tokens=True)

    return decoded

# Load local example images
def load_local_images(image_files=None, base_dir="."):
    """Load the example images shipped alongside the app.

    Files that do not exist are skipped silently. Files that exist but cannot
    be opened or decoded are reported on stdout and skipped, so a broken
    example never prevents the app from starting.

    Args:
        image_files: File names to load, relative to ``base_dir``. Defaults
            to ``['image1.jpg', 'image2.jpg', 'image3.jpg']``.
        base_dir: Directory containing the images. Defaults to ``"."`` (the
            current working directory), matching the original behavior.

    Returns:
        A list of fully decoded PIL images, in the order of ``image_files``.
    """
    if image_files is None:
        image_files = ['image1.jpg', 'image2.jpg', 'image3.jpg']
    local_images = []
    for img_file in image_files:
        try:
            img_path = os.path.join(base_dir, img_file)
            if os.path.exists(img_path):
                img = Image.open(img_path)
                # Image.open is lazy: it reads only the header and keeps the
                # file open. load() decodes the pixels now, so a truncated or
                # corrupt file is caught here rather than later in the UI.
                # Pillow then releases the handle for single-frame images.
                img.load()
                local_images.append(img)
        except Exception as e:
            # Best-effort: report and skip, don't crash the app.
            print(f"Could not load {img_file}: {e}")
    return local_images

# Prepare example images (may be empty if none could be loaded).
EXAMPLE_IMAGES = load_local_images()

# Default question shared by the prompt box and the example rows.
DEFAULT_PROMPT = "What is in this image?"

# Create Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Image Analysis")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload or Select Image")
            custom_prompt = gr.Textbox(label="Custom Prompt", value=DEFAULT_PROMPT)
            submit_btn = gr.Button("Analyze Image")

        with gr.Column():
            output_text = gr.Textbox(label="Image Description")

    # Connect components
    submit_btn.click(
        fn=generate_caption,
        inputs=[input_image, custom_prompt],
        outputs=output_text,
    )

    # Render the examples panel only when at least one example image loaded.
    # Otherwise an empty examples table would be shown with nothing to click.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[img, DEFAULT_PROMPT] for img in EXAMPLE_IMAGES],
            inputs=[input_image, custom_prompt],
            fn=generate_caption,
            outputs=output_text,
        )

# Launch the app
if __name__ == "__main__":
    demo.launch()