# extrav / app.py (Hugging Face Space source file)
# Author: disohugface. Last commit: "Update app.py" (4bf7c22, verified). Size: 1.38 kB.
import gradio as gr
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
import numpy as np
# Hugging Face Hub checkpoint of the ColPali retrieval model.
model_name = "vidore/colpali-v1.3-hf"

# Load the weights once at import time in full float32 precision and switch
# the module to inference mode (disables dropout etc.).
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
)
model.eval()

# Matching processor that turns PIL images into model-ready tensors.
processor = ColPaliProcessor.from_pretrained(model_name)
def process_image(image):
    """Embed an uploaded image with ColPali and return a small summary.

    Runs the module-level ``processor`` and ``model`` on a single image and
    returns a JSON-friendly summary rather than the full embedding, which is
    too large to display comfortably.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio ``Image`` component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict with keys:
            ``embedding_sample``: the first 10 entries along the first
                embedding axis, as nested Python lists.
            ``embedding_length``: size of that first axis.
            ``embedding_shape``: full shape of the embedding with the batch
                dimension removed.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # The processor expects 3-channel input; uploads may be RGBA, L, P, ...
    image = image.convert("RGB")

    # Put the inputs wherever the model lives (a no-op for a CPU model).
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        outputs = model(**inputs)

    # Drop only the leading batch dimension (we always send one image). A bare
    # ``squeeze()`` would also collapse any other axis that happens to be 1.
    # NOTE(review): ColPali is documented as returning multi-vector embeddings
    # of shape (batch, seq_len, dim), so rows here are per-token vectors;
    # confirm against the pinned transformers version.
    embeddings = outputs.embeddings[0].cpu()

    return {
        # Truncate for display: keep only the first 10 rows (not 10 scalar
        # values); the full embedding is far too large for the JSON panel.
        "embedding_sample": embeddings[:10].tolist(),
        "embedding_length": embeddings.shape[0],
        # Read the shape from the tensor instead of rebuilding a NumPy array
        # from the full nested-list conversion.
        "embedding_shape": list(embeddings.shape),
    }
# Build the Gradio UI: one PIL image upload in, the JSON summary produced by
# process_image out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
    title="ColPali Image Embedding Generator",
    description="Upload an image to generate its embedding using the ColPali model."
)

# Start the server only when this file is run as a script (``python app.py``),
# so importing the module (e.g. from tests) does not block on launch().
if __name__ == "__main__":
    demo.launch()