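"""Custom inference handler for a Hugging Face Inference Endpoint.

Wraps the dazpye/clip-image CLIP model: given a list of text prompts and
a list of image URLs, it returns the top text matches for each image.
"""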
import time
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import io

class EndpointHandler:
    def __init__(self, model_dir=None):
        # model_dir is passed by the endpoint runtime but unused here;
        # the weights are always pulled from the Hub repo below.
        print("🔄 Loading model...")
        self.model = CLIPModel.from_pretrained("dazpye/clip-image")
        self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")
        self.model.eval()  # inference only

    def _load_image(self, image_url):
        """Fetches an image from a URL."""
        try:
            print(f"🌐 Fetching image from: {image_url}")
            response = requests.get(image_url, timeout=5)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"❌ Image loading failed: {e}")
        return None

    def __call__(self, data):
        """Processes input and runs inference."""
        start_time = time.time()  # Start timer

        print("πŸ“₯ Processing input...")

        if "inputs" in data:
            data = data["inputs"]

        text = data.get("text", ["default text"])
        if isinstance(text, str):  # accept a single prompt string as well as a list
            text = [text]
        image_urls = data.get("images", [])

        images = [self._load_image(url) for url in image_urls if url]
        images = [img for img in images if img]  # Remove failed images

        if not images:
            return {"error": "No valid images provided."}

        # Pad and truncate so prompts of different lengths batch into one tensor
        inputs = self.processor(
            text=text, 
            images=images, 
            return_tensors="pt", 
            padding=True, 
            truncation=True
        )

        print("πŸ–₯️ Running inference...")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Image-text similarity scores -> per-image probabilities over the prompts
        logits_per_image = outputs.logits_per_image
        probabilities = logits_per_image.softmax(dim=1)

        # Top-3 matching prompts per image (fewer if fewer prompts were given)
        predictions = []
        for i, probs in enumerate(probabilities):
            top = torch.topk(probs, k=min(3, probs.numel()))
            best_matches = [(text[idx], prob.item()) for prob, idx in zip(top.values, top.indices)]
            predictions.append({"image_index": i, "top_matches": best_matches})

        total_time = time.time() - start_time  # Calculate time taken

        return {
            "predictions": predictions,
            "processing_time_seconds": round(total_time, 4)
        }
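
# --- Local smoke test (illustrative only) ---
# A minimal sketch of how this handler might be exercised outside the
# endpoint runtime. The prompts and image URL below are made-up examples,
# not values from the repo; any publicly reachable image URL would do.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "text": ["a photo of a cat", "a photo of a dog"],
            "images": ["https://example.com/cat.jpg"],  # hypothetical URL
        }
    }
    print(handler(payload))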