File size: 2,496 Bytes
7b75353
 
 
 
 
 
 
da26f81
 
7b75353
 
 
 
 
 
 
 
 
da26f81
7b75353
 
 
 
 
 
7db0491
 
 
 
 
7b75353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import time
import base64
from io import BytesIO
from PIL import Image
from transparent_background import Remover

class EndpointHandler:
    """Inference endpoint that removes image backgrounds via ``transparent_background``.

    The handler is constructed once (loading and warming up the model) and is
    then called with a JSON-decoded request dict per request.
    """

    # Maximum number of images processed per request; extra images are ignored.
    max_images = 3

    def __init__(self, path=""):
        """Load the background-removal model and run one warm-up inference.

        Args:
            path: path supplied by the hosting runtime. It is stored on the
                instance but not used to load the model here.
        """
        self.path = path
        self.remover = Remover(mode='fast')
        # Warm up the model with a dummy image; the result is discarded.
        dummy = Image.new("RGB", (256, 256), "white")
        self.remover.process(dummy)

    def __call__(self, request):
        """Remove the background from up to ``max_images`` base64-encoded images.

        Expects a dictionary (JSON), either the payload itself or wrapped as
        ``{"inputs": {...}}``, with keys:
          - "images": a list of base64-encoded image strings
            (e.g. "data:image/png;base64,..."), or a single such string
          - "output_type": one of "rgba", "map", "green", "blur", "overlay"
            (default "rgba")
          - "threshold": a float (0.0 to 1.0) (default 0.0)
          - "reverse": a boolean flag (default False)
        Images beyond ``max_images`` are silently ignored. The caller's
        ``request`` dict is not modified.

        Returns a dictionary with:
          - "images": list of processed images (base64-encoded PNG with data URI prefix)
          - "processing_times": newline-separated per-image and total processing times

        Raises:
            ValueError: if an image string is a ``data:`` URI without a ',' separator.
        """
        # .get (not .pop) so the caller's request dict is left intact.
        inputs = request.get("inputs", request)
        images_data = inputs.get("images", [])
        if isinstance(images_data, str):
            # Accept a single image string as well as a list of them; otherwise
            # the loop below would iterate over individual characters.
            images_data = [images_data]
        # NOTE(review): output_type comes straight from the client and is passed
        # to Remover.process unvalidated. If the library interprets other
        # strings (e.g. colours or background-image file paths), consider
        # whitelisting the documented values.
        output_type = inputs.get("output_type", "rgba")
        # NOTE(review): confirm how Remover.process treats threshold=0.0. If it
        # hard-thresholds whenever threshold is not None, 0.0 would mark nearly
        # every pixel as foreground, and None may be the intended default.
        threshold = inputs.get("threshold", 0.0)
        reverse = inputs.get("reverse", False)

        processed_results = []
        times_list = []
        # perf_counter is monotonic and high-resolution; time.time() is wall-clock
        # and can jump, which makes it unsuitable for measuring durations.
        global_start = time.perf_counter()

        for idx, img_b64 in enumerate(images_data[: self.max_images], start=1):
            image = self._decode_image(img_b64)

            # Per-image timing covers model inference only (not decode/encode).
            start_time = time.perf_counter()
            result = self.remover.process(
                image, type=output_type, threshold=threshold, reverse=reverse
            )
            elapsed = time.perf_counter() - start_time
            times_list.append(f"Image {idx}: {elapsed:.2f} seconds")

            processed_results.append(self._encode_png(result))

        total_time = time.perf_counter() - global_start
        times_list.append(f"Total time: {total_time:.2f} seconds")

        return {
            "images": processed_results,
            "processing_times": "\n".join(times_list),
        }

    @staticmethod
    def _strip_data_uri(img_b64):
        """Return the base64 payload, dropping a leading ``data:...,`` header if present."""
        if not img_b64.startswith("data:"):
            return img_b64
        _header, sep, payload = img_b64.partition(",")
        if not sep:
            raise ValueError("Malformed data URI: missing ',' separator")
        return payload

    @classmethod
    def _decode_image(cls, img_b64):
        """Decode a (possibly data-URI-prefixed) base64 string into an RGB PIL image."""
        img_bytes = base64.b64decode(cls._strip_data_uri(img_b64))
        return Image.open(BytesIO(img_bytes)).convert("RGB")

    @staticmethod
    def _encode_png(image):
        """Serialize *image* as PNG and return it as a ``data:image/png;base64,`` URI."""
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return "data:image/png;base64," + encoded