import time
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import io


class EndpointHandler:
    def __init__(self, model_dir=None):
        print("Loading model...")
        self.model = CLIPModel.from_pretrained("dazpye/clip-image")
        self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")
    def _load_image(self, image_url):
        """Fetches an image from a URL and returns it as RGB, or None on failure."""
        try:
            print(f"Fetching image from: {image_url}")
            response = requests.get(image_url, timeout=5)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"Image loading failed: {e}")
            return None
    def __call__(self, data):
        """Processes input and runs inference."""
        start_time = time.time()  # Start timer
        print("Processing input...")
        # The endpoint wraps the payload in an "inputs" key; unwrap it if present
        if "inputs" in data:
            data = data["inputs"]
        text = data.get("text", ["default text"])
        image_urls = data.get("images", [])
        images = [self._load_image(url) for url in image_urls if url]
        images = [img for img in images if img]  # Drop images that failed to load
        if not images:
            return {"error": "No valid images provided."}
        # Padding & truncation let variable-length text prompts batch into one tensor
        inputs = self.processor(
            text=text,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
print("π₯οΈ Running inference...")
with torch.no_grad():
outputs = self.model(**inputs)
# Get scores & find best matches
logits_per_image = outputs.logits_per_image
probabilities = logits_per_image.softmax(dim=1)
# Get top categories per image
predictions = []
for i, probs in enumerate(probabilities):
sorted_indices = torch.argsort(probs, descending=True)
best_matches = [(text[idx], probs[idx].item()) for idx in sorted_indices[:3]] # Get top 3 matches
predictions.append({"image_index": i, "top_matches": best_matches})
total_time = time.time() - start_time # Calculate time taken
return {
"predictions": predictions,
"processing_time_seconds": round(total_time, 4)
} |
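

# Minimal local smoke test (an illustrative sketch, not part of the deployed
# handler). It assumes network access; the image URL below is a placeholder
# you would replace with a real, reachable image. The payload mirrors the
# {"inputs": {"text": [...], "images": [...]}} shape that __call__ expects.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "text": ["a photo of a cat", "a photo of a dog"],
            "images": ["https://example.com/cat.jpg"],  # placeholder URL
        }
    }
    print(handler(payload))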