dazpye committed
Commit 24c40ff · verified · 1 Parent(s): ec42764

Update handler.py

Files changed (1): handler.py +34 -19
handler.py CHANGED

@@ -1,47 +1,62 @@
 import torch
 from transformers import CLIPProcessor, CLIPModel
 from PIL import Image
+import requests
 import base64
 import io
 
 class EndpointHandler:
     def __init__(self, model_dir=None):  # AWS expects model_dir
-        print("Loading model...")
+        print("🔄 Loading model...")
         self.model = CLIPModel.from_pretrained("dazpye/clip-image")
         self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")
 
-    def _load_image(self, image_url):
-        """Fetches an image and ensures it is fully loaded."""
+    def _load_image(self, image_data):
+        """Fetches an image from a URL or decodes a base64 image."""
         try:
-            print(f"Fetching image from: {image_url}")
-            response = requests.get(image_url, timeout=5)
-            print(f"HTTP Status Code: {response.status_code}")
-
-            if response.status_code == 200:
-                image_bytes = io.BytesIO(response.content)  # Convert to bytes
-                return Image.open(image_bytes)
-            else:
-                print(f"❌ Failed to fetch image: HTTP {response.status_code}")
+            if isinstance(image_data, str):
+                if image_data.startswith("http"):
+                    # Fetch image from URL
+                    print(f"🌐 Fetching image from: {image_data}")
+                    response = requests.get(image_data, timeout=5)
+                    print(f"✅ HTTP Status Code: {response.status_code}")
+
+                    if response.status_code == 200:
+                        image_bytes = io.BytesIO(response.content)
+                        return Image.open(image_bytes).convert("RGB")
+                    else:
+                        print(f"❌ Failed to fetch image: HTTP {response.status_code}")
+
+                else:
+                    # Handle base64-encoded image
+                    print("📸 Decoding base64 image...")
+                    return Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
+
         except Exception as e:
-            print(f"❌ Exception in image loading: {e}")
+            print(f"⚠️ Exception in image loading: {e}")
+
         return None  # Return None if image loading fails
 
     def __call__(self, data):
         """Main inference function Hugging Face expects."""
-        print("Processing input...")
-
+        print("📥 Processing input...")
+
+        if "inputs" in data:
+            data = data["inputs"]
+
         text = data.get("text", ["default caption"])  # Default text
         images = data.get("images", [])  # List of images
-
+
         # Convert image URLs or base64 strings to PIL images
         pil_images = [self._load_image(img) for img in images if img]
+        pil_images = [img for img in pil_images if img]  # Remove None values
 
         if not pil_images:
-            return {"error": "No valid images provided."}
-
+            return {"error": "❌ No valid images provided. Check URLs or base64 encoding."}
+
         inputs = self.processor(text=text, images=pil_images, return_tensors="pt")
 
-        print("Running inference...")
+        print("🖥️ Running inference...")
         with torch.no_grad():
             outputs = self.model(**inputs)
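
For context, a minimal local smoke test of the updated handler could look like the sketch below. The payload shape (text, images, and the optional "inputs" wrapper that the new __call__ unwraps) is taken directly from the diff; the image URL and the cat.jpg file name are placeholder assumptions, and because the diff is truncated after outputs = self.model(**inputs), the exact shape of the returned result is not shown here.

# Minimal local smoke test for the updated handler (illustrative sketch;
# the URL and cat.jpg below are placeholders, not part of the commit).
import base64

from handler import EndpointHandler  # assumes handler.py is on the import path

handler = EndpointHandler()  # downloads dazpye/clip-image on first use

# One base64 string plus one URL exercises both branches of _load_image.
with open("cat.jpg", "rb") as f:  # any local test image
    b64_image = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {  # __call__ unwraps this nesting
        "text": ["a photo of a cat", "a photo of a dog"],
        "images": [
            "https://example.com/cat.jpg",  # URL branch (fetched via requests)
            b64_image,                      # base64 branch (decoded in memory)
        ],
    }
}

result = handler(payload)
print(result)

Against a deployed endpoint, the same dictionary would arrive as JSON nested under "inputs", which is the case the new 'if "inputs" in data' guard handles.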