import time
import base64
from io import BytesIO

from PIL import Image
from transparent_background import Remover


class EndpointHandler:
    """Callable request handler that runs background removal on images.

    Wraps a ``transparent_background.Remover`` and exposes it through
    ``__call__``, which takes a JSON-like request dictionary and returns the
    processed images as base64 data URIs together with timing information.
    """

    # Maximum number of images processed per request; any further images in
    # the request are ignored.
    MAX_IMAGES = 3

    def __init__(self, path=""):
        """Create the remover and run one warm-up inference.

        Args:
            path: Stored on the instance as ``self.path``; not otherwise used
                by this class.
        """
        self.path = path
        self.remover = Remover(mode='fast')
        # Warm up the model with a dummy image.
        dummy = Image.new("RGB", (256, 256), "white")
        self.remover.process(dummy)

    def __call__(self, request):
        """Process the images in *request* and return them with timings.

        Expects a dictionary (JSON), optionally nested under an "inputs" key,
        with keys:
            - "images": a list of base64-encoded image strings, with or
              without a data URI prefix (e.g. "data:image/png;base64,...").
              A single string is treated as a one-element list. Only the
              first ``MAX_IMAGES`` entries are processed.
            - "output_type": one of "rgba", "map", "green", "blur", "overlay"
              (default "rgba").
            - "threshold": a float (0.0 to 1.0), default 0.0.
            - "reverse": a boolean flag, default False.

        The last three values are forwarded unchanged to
        ``Remover.process``.

        Returns:
            A dictionary with:
            - "images": list of processed images, PNG-encoded as base64 with
              a "data:image/png;base64," prefix.
            - "processing_times": a newline-separated string with the
              per-image and total processing times.

        Raises:
            ValueError: If an image string starts with "data:" but has no
                "," separating the header from the payload.
        """
        # NOTE: pop() removes "inputs" from the caller's dict when present;
        # when absent, the request itself is used as the inputs mapping.
        inputs = request.pop("inputs", request)

        # Tolerate an explicit null and a single bare string for "images".
        images_data = inputs.get("images") or []
        if isinstance(images_data, str):
            images_data = [images_data]

        output_type = inputs.get("output_type", "rgba")
        threshold = inputs.get("threshold", 0.0)
        reverse = inputs.get("reverse", False)

        processed_results = []
        times_list = []
        # perf_counter() is monotonic, so measured durations cannot be skewed
        # by system clock adjustments.
        global_start = time.perf_counter()

        for idx, img_b64 in enumerate(images_data[:self.MAX_IMAGES], start=1):
            image = self._decode_image(img_b64)

            # Only the model call is included in the per-image time;
            # decoding and encoding count towards the total only.
            start_time = time.perf_counter()
            result = self.remover.process(
                image,
                type=output_type,
                threshold=threshold,
                reverse=reverse,
            )
            elapsed = time.perf_counter() - start_time
            times_list.append(f"Image {idx}: {elapsed:.2f} seconds")

            processed_results.append(self._encode_image(result))

        total_time = time.perf_counter() - global_start
        times_list.append(f"Total time: {total_time:.2f} seconds")

        return {
            "images": processed_results,
            "processing_times": "\n".join(times_list),
        }

    @staticmethod
    def _decode_image(img_b64):
        """Decode a base64 string (optionally a data URI) into an RGB image."""
        if img_b64.startswith("data:"):
            # Everything after the first comma is the payload.
            _, separator, payload = img_b64.partition(",")
            if not separator:
                raise ValueError(
                    "Malformed data URI: missing ',' before the base64 payload"
                )
            img_b64 = payload
        img_bytes = base64.b64decode(img_b64)
        return Image.open(BytesIO(img_bytes)).convert("RGB")

    @staticmethod
    def _encode_image(image):
        """Encode *image* as a PNG base64 data URI string."""
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return "data:image/png;base64," + encoded