# handler.py: custom inference handler for the dazpye/clip-image CLIP model
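"""Zero-shot image classification endpoint built on CLIP.

The handler accepts a JSON payload of candidate text labels and image URLs,
scores every image against every label, and returns the top matches per image
together with the wall-clock processing time.
"""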
import io
import time

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    def __init__(self, model_dir=None):
        print("🔄 Loading model...")
        # Prefer the locally provisioned weights when the endpoint passes a
        # model directory; fall back to the hub repo otherwise.
        source = model_dir or "dazpye/clip-image"
        self.model = CLIPModel.from_pretrained(source)
        self.processor = CLIPProcessor.from_pretrained(source)
        self.model.eval()  # inference only

    def _load_image(self, image_url):
        """Fetches an image from a URL and converts it to RGB."""
        try:
            print(f"🌐 Fetching image from: {image_url}")
            response = requests.get(image_url, timeout=5)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"❌ Image loading failed: {e}")
            return None

    def __call__(self, data):
        """Parses the request payload and runs zero-shot classification."""
        start_time = time.time()
        print("📥 Processing input...")

        # Hugging Face endpoints wrap the payload in an "inputs" key.
        if "inputs" in data:
            data = data["inputs"]

        text = data.get("text", ["default text"])
        if isinstance(text, str):
            text = [text]  # a bare string would otherwise be indexed per character below
        image_urls = data.get("images", [])

        images = [self._load_image(url) for url in image_urls if url]
        images = [img for img in images if img is not None]  # drop failed downloads
        if not images:
            return {"error": "No valid images provided."}

        # Padding and truncation keep variable-length text prompts batchable.
        inputs = self.processor(
            text=text,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        print("🖥️ Running inference...")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Softmax over the text labels gives per-image match probabilities.
        logits_per_image = outputs.logits_per_image
        probabilities = logits_per_image.softmax(dim=1)

        # Collect the top three labels for each image.
        predictions = []
        for i, probs in enumerate(probabilities):
            sorted_indices = torch.argsort(probs, descending=True)
            best_matches = [(text[idx], probs[idx].item()) for idx in sorted_indices[:3]]
            predictions.append({"image_index": i, "top_matches": best_matches})

        total_time = time.time() - start_time
        return {
            "predictions": predictions,
            "processing_time_seconds": round(total_time, 4),
        }
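

# Minimal local smoke test: a sketch only. The payload mirrors the shape the
# handler parses above; the image URL is a hypothetical placeholder, so the
# handler will return its "No valid images" error unless it is swapped for a
# real, reachable image.
if __name__ == "__main__":
    handler = EndpointHandler()
    sample = {
        "inputs": {
            "text": ["a photo of a cat", "a photo of a dog"],
            "images": ["https://example.com/cat.jpg"],  # placeholder URL
        }
    }
    print(handler(sample))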