Spaces:
Running
on
L40S
Running
on
L40S
| import os, sys | |
| sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}") | |
| import numpy as np | |
| from PIL import Image | |
| from rembg import remove, new_session | |
| from infer.utils import timing_decorator | |
class Removebg:
    """Callable that strips the background from an image using a rembg session.

    The session is created once in ``__init__`` and reused for every call.
    """

    def __init__(self, name="u2net"):
        """Create the rembg session.

        Args:
            name: model name forwarded unchanged to ``new_session``.
        """
        self.session = new_session(name)

    def __call__(self, rgb_maybe, force=True):
        """Remove the background, whiten it, then crop/pad/resize the result.

        Args:
            rgb_maybe: PIL.Image, with RGB mode or RGBA mode.
            force: bool. If the input is in RGBA mode and ``force`` is true,
                the image is converted to RGB and the background is removed
                again. If ``force`` is false, the existing alpha channel of
                an RGBA input is trusted and background removal is skipped.

        Returns:
            The image returned by ``preprocess`` (documented by the original
            author as a PIL.Image in RGBA mode).
        """
        if rgb_maybe.mode == "RGBA" and not force:
            # Work on a copy: white_out_background rewrites pixels in place
            # via putdata, which would otherwise modify the caller's image.
            rgba_img = rgb_maybe.copy()
        else:
            if rgb_maybe.mode == "RGBA":
                # Drop the existing alpha so the matte is recomputed.
                rgb_maybe = rgb_maybe.convert("RGB")
            rgba_img = remove(rgb_maybe, session=self.session)

        rgba_img = white_out_background(rgba_img)
        rgba_img = preprocess(rgba_img)
        return rgba_img
def white_out_background(pil_img):
    """Whiten transparent pixels and tone down near-white opaque ones, in place.

    Each pixel is unpacked as ``(r, g, b, a)``:

    * alpha below 16 is treated as background and becomes ``(255, 255, 255, 0)``;
    * otherwise, if all three colour channels exceed 235, they are clamped
      to 235 and the alpha is kept;
    * every other pixel is left as it was.

    Args:
        pil_img: object providing ``getdata()`` (iterable of 4-tuples) and
            ``putdata(sequence)``, e.g. a PIL.Image in RGBA mode.

    Returns:
        The same ``pil_img`` object, after its pixel data has been replaced.
    """

    def recolor(pixel):
        red, green, blue, alpha = pixel
        if alpha < 16:  # background
            return (255, 255, 255, 0)  # full white color
        if red > 235 and green > 235 and blue > 235:
            return (235, 235, 235, alpha)
        return (red, green, blue, alpha)

    pil_img.putdata([recolor(px) for px in pil_img.getdata()])
    return pil_img
def preprocess(rgba_img, size=(512,512), ratio=1.15):
    """Crop an RGBA image to its foreground, centre it on a white square, resize.

    Args:
        rgba_img: image that ``np.asarray`` turns into an ``(H, W, 4)`` array;
            the fourth channel is used as alpha.
        size: target size handed to ``Image.resize``.
        ratio: side of the padded square relative to the longer side of the
            cropped foreground. Values below 1 are not supported, because the
            crop would no longer fit inside the canvas.

    Returns:
        PIL.Image in RGBA mode, resized to ``size``. The padding is white
        with zero alpha.

    Raises:
        ValueError: if no pixel has alpha above 10 % (nothing to crop to).
    """
    image = np.asarray(rgba_img)
    rgb, alpha = image[:, :, :3], image[:, :, 3]

    # crop: bounding box of every pixel whose alpha exceeds 10 %
    rows, cols = np.nonzero(alpha / 255. > 0.1)
    if rows.size == 0:
        raise ValueError("preprocess: image has no pixel with alpha > 0.1")
    row_min, row_max = rows.min(), rows.max()
    col_min, col_max = cols.min(), cols.max()
    # max() is an inclusive index while a slice end is exclusive, hence +1;
    # without it the last foreground row and column are cut off.
    # The 8-bit values are sliced directly instead of being scaled to [0, 1]
    # and back, which truncated on the way to uint8.
    rgb = rgb[row_min:row_max + 1, col_min:col_max + 1, :].astype("uint8")
    alpha = alpha[row_min:row_max + 1, col_min:col_max + 1].astype("uint8")

    # padding: centre the crop on a square white, fully transparent canvas
    h, w = rgb.shape[:2]
    resize_side = int(max(h, w) * ratio)
    start_h, start_w = (resize_side - h) // 2, (resize_side - w) // 2
    new_rgb = np.full((resize_side, resize_side, 3), 255, dtype=np.uint8)
    new_alpha = np.zeros((resize_side, resize_side), dtype=np.uint8)
    new_rgb[start_h:start_h + h, start_w:start_w + w] = rgb
    new_alpha[start_h:start_h + h, start_w:start_w + w] = alpha

    rgba_array = np.concatenate((new_rgb, new_alpha[:, :, None]), axis=-1)
    rgba_image = Image.fromarray(rgba_array, 'RGBA')
    return rgba_image.resize(size)
if __name__ == "__main__":
    import argparse

    def get_args():
        """Parse the command line: input path, output path and --force flag."""
        cli = argparse.ArgumentParser()
        # Both paths are mandatory string options.
        for flag in ("--rgb_path", "--output_rgba_path"):
            cli.add_argument(flag, type=str, required=True)
        cli.add_argument("--force", default=False, action="store_true")
        return cli.parse_args()

    # Read the input, run background removal, write the result.
    options = get_args()
    source_image = Image.open(options.rgb_path)
    remover = Removebg()
    result = remover(source_image, options.force)
    result.save(options.output_rgba_path)