import io
import time

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    def __init__(self, model_dir=None):
        print("🔄 Loading model...")
        # Prefer the locally provided model directory when the endpoint
        # supplies one; fall back to the hub repository otherwise.
        model_path = model_dir or "dazpye/clip-image"
        self.model = CLIPModel.from_pretrained(model_path)
        self.processor = CLIPProcessor.from_pretrained(model_path)
        self.model.eval()  # Inference only; disables dropout etc.

    def _load_image(self, image_url):
        """Fetches an image from a URL, returning None on failure."""
        try:
            print(f"🌐 Fetching image from: {image_url}")
            response = requests.get(image_url, timeout=5)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"❌ Image loading failed: {e}")
            return None

    def __call__(self, data):
        """Processes input and runs inference."""
        start_time = time.time()  # Start timer
        print("📥 Processing input...")

        if "inputs" in data:
            data = data["inputs"]

        text = data.get("text", ["default text"])
        image_urls = data.get("images", [])

        images = [self._load_image(url) for url in image_urls if url]
        images = [img for img in images if img]  # Drop images that failed to load

        if not images:
            return {"error": "No valid images provided."}

        # Padding & truncation keep text prompts of different lengths
        # batchable as a single tensor.
        inputs = self.processor(
            text=text,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        print("🖥️ Running inference...")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Similarity scores: one row per image, one column per text label.
        logits_per_image = outputs.logits_per_image
        probabilities = logits_per_image.softmax(dim=1)

        # Collect the top categories for each image.
        predictions = []
        for i, probs in enumerate(probabilities):
            sorted_indices = torch.argsort(probs, descending=True)
            best_matches = [
                (text[idx], probs[idx].item())
                for idx in sorted_indices[:3].tolist()  # Top 3 matches
            ]
            predictions.append({"image_index": i, "top_matches": best_matches})

        total_time = time.time() - start_time  # Calculate time taken
        return {
            "predictions": predictions,
            "processing_time_seconds": round(total_time, 4),
        }
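

# --- Local smoke test (illustrative sketch) ------------------------------
# A minimal example of how this handler can be exercised when running the
# file directly. On a deployed Hugging Face Inference Endpoint the toolkit
# instantiates and calls EndpointHandler for you, so this block is only a
# convenience for local testing. The payload shape mirrors the
# {"inputs": {...}} convention handled above; the candidate labels and the
# image URL below are placeholder assumptions, not values from the project.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "text": ["a photo of a cat", "a photo of a dog"],  # assumed labels
            "images": ["https://example.com/cat.jpg"],  # placeholder URL
        }
    }
    print(handler(payload))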