dazpye committed
Commit 7ed3506 · verified · 1 Parent(s): ba90fc6

Update handler.py

Files changed (1):
  1. handler.py +18 -45
handler.py CHANGED
@@ -2,7 +2,6 @@ import torch
 from transformers import CLIPProcessor, CLIPModel
 from PIL import Image
 import requests
-import base64
 import io
 
 class EndpointHandler:
@@ -11,61 +10,35 @@ class EndpointHandler:
         self.model = CLIPModel.from_pretrained("dazpye/clip-image")
         self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")
 
-    def _load_image(self, image_data):
-        """Fetches an image from a URL or decodes a base64 image."""
+    def _load_image(self, image_url):
+        """Simple image loader for URL images."""
         try:
-            if isinstance(image_data, str):
-                if image_data.startswith("http"):
-                    # Fetch image from URL with User-Agent to bypass restrictions
-                    print(f"🌐 Fetching image from: {image_data}")
-                    headers = {
-                        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-                    }
-                    response = requests.get(image_data, headers=headers, timeout=5)
-                    print(f"✅ HTTP Status Code: {response.status_code}")
-
-                    if response.status_code == 200:
-                        image_bytes = io.BytesIO(response.content)
-                        return Image.open(image_bytes).convert("RGB")
-                    else:
-                        print(f"❌ Failed to fetch image: HTTP {response.status_code}")
-
-                else:
-                    # Handle base64-encoded image
-                    print("📸 Decoding base64 image...")
-                    return Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
-
+            print(f"🌐 Fetching image: {image_url}")
+            response = requests.get(image_url, timeout=5)
+            response.raise_for_status()  # Raise error if status is not 200
+            return Image.open(io.BytesIO(response.content)).convert("RGB")
         except Exception as e:
-            print(f"⚠️ Exception in image loading: {e}")
-
+            print(f"❌ Image loading failed: {e}")
         return None  # Return None if image loading fails
 
     def __call__(self, data):
-        """Main inference function Hugging Face expects."""
-        print("📥 Processing input...")
-
-        if "inputs" in data:
-            data = data["inputs"]
-
-        text = data.get("text", ["default caption"])  # Default text
-        images = data.get("images", [])  # List of images
-
-        # Convert image URLs or base64 strings to PIL images
-        pil_images = [self._load_image(img) for img in images if img]
-        pil_images = [img for img in pil_images if img]  # Remove None values
-
-        if not pil_images:
-            print("❌ No valid images provided. Check URLs or base64 encoding.")
-            return {"error": "❌ No valid images provided. Check URLs or base64 encoding."}
-
-        inputs = self.processor(text=text, images=pil_images, return_tensors="pt")
+        """Processes input and runs inference."""
+        print("📥 Received input...")
+
+        text = data.get("inputs", {}).get("text", ["default text"])
+        image_urls = data.get("inputs", {}).get("images", [])
+
+        images = [self._load_image(url) for url in image_urls if url]
+        images = [img for img in images if img]  # Remove failed images
+
+        if not images:
+            print("❌ No valid images provided.")
+            return {"error": "No valid images provided."}
+
+        inputs = self.processor(text=text, images=images, return_tensors="pt")
 
        print("🖥️ Running inference...")
        with torch.no_grad():
            outputs = self.model(**inputs)
 
-        logits_per_image = outputs.logits_per_image
-        probabilities = logits_per_image.softmax(dim=1)
-
-        print("✅ Inference complete!")
-        return {"predictions": probabilities.tolist()}
+        return {"predictions": outputs.logits_per_image.softmax(dim=1).tolist()}