import io
import time

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    def __init__(self, model_dir=None):
        # model_dir is accepted but unused; the checkpoint is loaded
        # from the Hub by its repo id instead.
        print("Loading model...")
        self.model = CLIPModel.from_pretrained("dazpye/clip-image")
        self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")

    def _load_image(self, image_url):
        """Fetches an image from a URL and returns it as an RGB PIL image.

        Returns None if the download or decoding fails.
        """
        try:
            print(f"Fetching image from: {image_url}")
            response = requests.get(image_url, timeout=5)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"Image loading failed: {e}")
            return None

    def __call__(self, data):
        """Processes the request payload and runs inference."""
        start_time = time.time()

        print("Processing input...")

        # Inference Endpoints wrap the payload under an "inputs" key.
        if "inputs" in data:
            data = data["inputs"]

        text = data.get("text", ["default text"])
        if isinstance(text, str):
            # Accept a single label string as well as a list of labels.
            text = [text]
        image_urls = data.get("images", [])

        # Download each image, then drop any that failed to load.
        images = [self._load_image(url) for url in image_urls if url]
        images = [img for img in images if img is not None]

        if not images:
            return {"error": "No valid images provided."}

        inputs = self.processor(
            text=text,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        print("Running inference...")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # logits_per_image has shape (num_images, num_texts); softmax over
        # the text dimension yields per-image probabilities for each label.
        logits_per_image = outputs.logits_per_image
        probabilities = logits_per_image.softmax(dim=1)

        # Report the top three labels (or fewer) for each image.
        predictions = []
        for i, probs in enumerate(probabilities):
            sorted_indices = torch.argsort(probs, descending=True)
            best_matches = [
                (text[int(idx)], probs[idx].item()) for idx in sorted_indices[:3]
            ]
            predictions.append({"image_index": i, "top_matches": best_matches})

        total_time = time.time() - start_time

        return {
            "predictions": predictions,
            "processing_time_seconds": round(total_time, 4),
        }
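

# A minimal local smoke test: a sketch assuming the module is run by hand
# with network access. The labels and image URL below are illustrative
# placeholders, not values from the original handler. On Inference Endpoints
# the platform instantiates EndpointHandler itself, so this block never runs
# there.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "text": ["a photo of a cat", "a photo of a dog"],
            "images": ["https://example.com/cat.jpg"],  # replace with a real image URL
        }
    }
    print(handler(payload))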