# transparent-bg / handler.py
# Author: WolseyTheCat
# Last commit: "Update handler.py" (7db0491, verified)
import time
import base64
from io import BytesIO
from PIL import Image
from transparent_background import Remover
class EndpointHandler:
    """Inference-endpoint handler that removes image backgrounds.

    Wraps a single ``transparent_background.Remover`` created at start-up and
    applies it to a small batch of base64-encoded images on each call.
    """

    # Maximum number of images processed per request; extra entries in
    # "images" are ignored (the original handler hard-coded this to 3).
    max_images = 3

    def __init__(self, path=""):
        """Create the remover and run one warm-up inference.

        Args:
            path: Path supplied by the hosting runtime; stored on the
                instance but not otherwise used by this handler.
        """
        self.path = path
        self.remover = Remover(mode='fast')
        # Warm up the model with a dummy image, presumably so the first real
        # request does not absorb one-time start-up overhead.
        dummy = Image.new("RGB", (256, 256), "white")
        self.remover.process(dummy)

    def __call__(self, request):
        """Remove the background from up to ``max_images`` images.

        Args:
            request: JSON-like dict. Parameters are read from
                ``request["inputs"]`` when that key exists, otherwise from
                ``request`` itself. The caller's dict is not modified.
                Recognised keys:

                - "images": list of base64-encoded image strings, optionally
                  carrying a data-URI prefix (e.g. "data:image/png;base64,...").
                  A single string is treated as a one-element list; a missing
                  or null value means no images.
                - "output_type": forwarded to ``Remover.process`` as ``type``
                  (e.g. "rgba", "map", "green", "blur", "overlay");
                  defaults to "rgba".
                - "threshold": forwarded to ``Remover.process``; defaults to 0.0.
                - "reverse": forwarded to ``Remover.process``; defaults to False.

        Returns:
            dict with:
            - "images": processed images as PNG data-URI strings, in input order.
            - "processing_times": newline-separated per-image model times,
              followed by the total time (which also includes decode/encode).
        """
        # .get instead of .pop: reading parameters must not mutate the caller's dict.
        inputs = request.get("inputs", request)
        images_data = inputs.get("images") or []
        if isinstance(images_data, str):
            # A bare string would otherwise be sliced into single characters.
            images_data = [images_data]
        output_type = inputs.get("output_type", "rgba")
        # NOTE(review): transparent_background documents threshold=None as "no
        # threshold" and any number as a hard cut-off, so this 0.0 default may
        # binarise the mask. Kept for backward compatibility — confirm intent.
        threshold = inputs.get("threshold", 0.0)
        reverse = inputs.get("reverse", False)

        processed_results = []
        times_list = []
        # perf_counter is monotonic, so measured durations are immune to clock changes.
        global_start = time.perf_counter()
        for idx, img_b64 in enumerate(images_data[: self.max_images], start=1):
            image = self._decode_image(img_b64)
            start_time = time.perf_counter()
            result = self.remover.process(
                image, type=output_type, threshold=threshold, reverse=reverse
            )
            elapsed = time.perf_counter() - start_time
            times_list.append(f"Image {idx}: {elapsed:.2f} seconds")
            processed_results.append(self._encode_png(result))

        total_time = time.perf_counter() - global_start
        times_list.append(f"Total time: {total_time:.2f} seconds")
        return {
            "images": processed_results,
            "processing_times": "\n".join(times_list),
        }

    @staticmethod
    def _decode_image(img_b64):
        """Decode a base64 string (optionally data-URI prefixed) into an RGB image."""
        # Strip a "data:<mime>;base64," prefix; base64 text itself contains no comma.
        if img_b64.startswith("data:"):
            img_b64 = img_b64.partition(",")[2]
        img_bytes = base64.b64decode(img_b64)
        # convert() returns a new, fully loaded image, so the source can be closed.
        with Image.open(BytesIO(img_bytes)) as src:
            return src.convert("RGB")

    @staticmethod
    def _encode_png(result):
        """Serialise a PIL image to a PNG data-URI string."""
        buffer = BytesIO()
        result.save(buffer, format="PNG")
        result_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return "data:image/png;base64," + result_b64