WolseyTheCat commited on
Commit
7b75353
·
verified ·
1 Parent(s): 9b8cbf2

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +66 -0
handler.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import base64
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ from transparent_background import Remover
6
+
7
+
8
+ class EndpointHandler:
9
+ def __init__(self):
10
+ # Initialize the remover model with desired settings.
11
+ self.remover = Remover(mode='fast')
12
+ # Warm up the model with a dummy image.
13
+ dummy = Image.new("RGB", (256, 256), "white")
14
+ _ = self.remover.process(dummy)
15
+
16
+ def __call__(self, request):
17
+ """
18
+ Expects a dictionary (JSON) with keys:
19
+ - "images": a list of base64-encoded image strings (e.g. "data:image/png;base64,...")
20
+ - "output_type": one of "rgba", "map", "green", "blur", or "overlay"
21
+ - "threshold": a float (0.0 to 1.0)
22
+ - "reverse": a boolean flag
23
+ Returns a dictionary with:
24
+ - "images": list of processed images (base64-encoded with data URI prefix)
25
+ - "processing_times": a string with individual and total processing times
26
+ """
27
+ # Get parameters from the request.
28
+ images_data = request.get("images", [])
29
+ output_type = request.get("output_type", "rgba")
30
+ threshold = request.get("threshold", 0.1)
31
+ reverse = request.get("reverse", False)
32
+
33
+ processed_results = []
34
+ times_list = []
35
+
36
+ global_start = time.time()
37
+
38
+ # Process up to 3 images
39
+ for idx, img_b64 in enumerate(images_data[:3]):
40
+ # Remove data URI prefix if present.
41
+ if img_b64.startswith("data:"):
42
+ img_b64 = img_b64.split(",")[1]
43
+ # Decode the image.
44
+ img_bytes = base64.b64decode(img_b64)
45
+ image = Image.open(BytesIO(img_bytes)).convert("RGB")
46
+
47
+ start_time = time.time()
48
+ result = self.remover.process(image, type=output_type, threshold=threshold, reverse=reverse)
49
+ elapsed = time.time() - start_time
50
+ times_list.append(f"Image {idx+1}: {elapsed:.2f} seconds")
51
+
52
+ # Convert the result to base64.
53
+ buffer = BytesIO()
54
+ result.save(buffer, format="PNG")
55
+ buffer.seek(0)
56
+ result_b64 = base64.b64encode(buffer.read()).decode("utf-8")
57
+ processed_results.append("data:image/png;base64," + result_b64)
58
+
59
+ total_time = time.time() - global_start
60
+ times_list.append(f"Total time: {total_time:.2f} seconds")
61
+ elapsed_str = "\n".join(times_list)
62
+
63
+ return {
64
+ "images": processed_results,
65
+ "processing_times": elapsed_str
66
+ }