diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..00442d1d74440f721fe32cfa189fafbbded3caa3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,44 @@ +FROM python:3.8 + +RUN mkdir /app +RUN mkdir /.cache/ +RUN mkdir /.cache/matplotlib +RUN mkdir /.cache/huggingface +RUN mkdir /.cache/huggingface/hub/ +RUN mkdir /.cache/torch/ +RUN mkdir /.config +RUN mkdir /.config/matplotlib/ + +RUN chmod -R 777 /.cache +RUN chmod -R 777 /.cache/matplotlib +RUN chmod -R 777 /.cache/huggingface/hub +RUN chmod -R 777 /.cache/torch +RUN chmod -R 777 /.config/ +RUN chmod -R 777 /.config/matplotlib +RUN chmod -R 777 /app + + +COPY lama_cleaner ./lama_cleaner +COPY ./app.py ./app.py + + +COPY app/yolov8x-seg.pt /app +COPY big-lama.pt /app +# COPY clickseg_pplnet.pt /app +COPY u2net.onnx /app +COPY u2net.onnx /tmp + +RUN chmod -R a+r /app/yolov8x-seg.pt +RUN chmod -R a+r /app/big-lama.pt +#RUN chmod -R a+r /app/clickseg_pplnet.pt +RUN chmod -R a+r /app/u2net.onnx +RUN chmod -R a+r /tmp/u2net.onnx + + +COPY ./requirements.txt ./requirements.txt +RUN pip install -r ./requirements.txt + +RUN --mount=type=secret,id=SECRET,mode=0444,required=true \ + git clone $(cat /run/secrets/SECRET) + +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f8aa99ac9b0e2a8eb9c5109f9f5754634271772c --- /dev/null +++ b/app.py @@ -0,0 +1,316 @@ +import base64 +import imghdr +import os + +import cv2 +import numpy as np +import torch +from ultralytics import YOLO +from ultralytics.yolo.utils.ops import scale_image +import asyncio +from fastapi import FastAPI, File, UploadFile, Request, Response +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +# from mangum import Mangum +from argparse import ArgumentParser + +import lama_cleaner.server2 as server +from lama_cleaner.helper import ( + load_img, +) + +# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory" + +app = FastAPI() + +# handler = Mangum(app) +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: + """ + Args: + image_numpy: numpy image + ext: image extension + Returns: + image bytes + """ + data = cv2.imencode( + f".{ext}", + image_numpy, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + )[1].tobytes() + return data + + +def get_image_ext(img_bytes): + """ + Args: + img_bytes: image bytes + Returns: + image extension + """ + if not img_bytes: + raise ValueError("Empty input") + header = img_bytes[:32] + w = imghdr.what("", header) + if w is None: + w = "jpeg" + return w + + +def predict_on_image(model, img, conf, retina_masks): + """ + Args: + model: YOLOv8 model + img: image (C, H, W) + conf: confidence threshold + retina_masks: use retina masks or not + Returns: + boxes: box with xyxy format, (N, 4) + masks: masks, (N, H, W) + cls: class of masks, (N, ) + probs: confidence score, (N, 1) + """ + with torch.no_grad(): + result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0] + + boxes, masks, cls, probs = None, None, None, None + + if result.boxes.cls.size(0) > 0: + # detection + cls = result.boxes.cls.cpu().numpy().astype(np.int32) + probs = result.boxes.conf.cpu().numpy() # confidence score, (N, 1) + boxes = result.boxes.xyxy.cpu().numpy() # 
box with xyxy format, (N, 4) + + # segmentation + masks = result.masks.masks.cpu().numpy() # masks, (N, H, W) + masks = np.transpose(masks, (1, 2, 0)) # masks, (H, W, N) + # rescale masks to original image + masks = scale_image(masks.shape[:2], masks, result.masks.orig_shape) + masks = np.transpose(masks, (2, 0, 1)) # masks, (N, H, W) + + return boxes, masks, cls, probs + + +def overlay(image, mask, color, alpha, id, resize=None): + """Overlays a binary mask on an image. + + Args: + image: Image to be overlayed on. + mask: Binary mask to overlay. + color: Color to use for the mask. + alpha: Opacity of the mask. + id: id of the mask + resize: Resize the image to this size. If None, no resizing is performed. + + Returns: + The overlayed image. + """ + color = color[::-1] + colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0) + colored_mask = np.moveaxis(colored_mask, 0, -1) + masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color) + image_overlay = masked.filled() + + imgray = cv2.cvtColor(image_overlay, cv2.COLOR_BGR2GRAY) + + contour_thickness = 8 + _, thresh = cv2.threshold(imgray, 255, 255, 255) + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + imgray = cv2.cvtColor(imgray, cv2.COLOR_GRAY2BGR) + imgray = cv2.drawContours(imgray, contours, -1, (255, 255, 255), contour_thickness) + + imgray = np.where(imgray.any(-1, keepdims=True), (46, 36, 225), 0) + + if resize is not None: + image = cv2.resize(image.transpose(1, 2, 0), resize) + image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize) + + return imgray + + +async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls): + """Process the mask of the image. + + Args: + idx: index of the mask + mask_i: mask of the image + boxes: box with xyxy format, (N, 4) + probs: confidence score, (N, 1) + yolo_model: YOLOv8 model + blank_image: blank image + cls: class of masks, (N, ) + + Returns: + dictionary_seg: dictionary of the mask of the image + """ + dictionary_seg = {} + maskwith_back = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx) + + alpha = np.sum(maskwith_back, axis=-1) > 0 + alpha = np.uint8(alpha * 255) + maskwith_back = np.dstack((maskwith_back, alpha)) + + imgencode = await asyncio.get_running_loop().run_in_executor(None, cv2.imencode, '.png', maskwith_back) + mask = base64.b64encode(imgencode[1]).decode('utf-8') + + dictionary_seg["confi"] = f'{probs[idx] * 100:.2f}' + dictionary_seg["boxe"] = [int(item) for item in list(boxes[idx])] + dictionary_seg["mask"] = mask + dictionary_seg["cls"] = str(yolo_model.names[cls[idx]]) + + return dictionary_seg + + +@app.middleware("http") +async def check_auth_header(request: Request, call_next): + token = request.headers.get('Authorization') + if token != os.environ.get("SECRET"): + return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403) + else: + response = await call_next(request) + return response + + +@app.post("/api/mask") +async def detect_mask(file: UploadFile = File()): + """ + Detects masks in an image uploaded via a POST request and returns a JSON response containing the details of the detected masks. 
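+    Masks come from the YOLOv8 segmentation model and are returned base64-encoded, one entry per detected object.
+
+    Example request (hypothetical local call, assuming the service listens on port 7860 as in the Dockerfile and SECRET matches the Authorization middleware):
+        curl -X POST http://localhost:7860/api/mask -H "Authorization: $SECRET" -F "file=@input.png"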
+ + Args: + None + + Parameters: + - file: a file object containing the input image + + Returns: + A JSON response containing the details of the detected masks: + - code: 200 if objects were detected, 500 if no objects were detected + - msg: a message indicating whether objects were detected or not + - data: a list of dictionaries, where each dictionary contains the following keys: + - confi: the confidence level of the detected object + - boxe: a list containing the coordinates of the bounding box of the detected object + - mask: the mask of the detected object encoded in base64 + - cls: the class of the detected object + + Raises: + 500: No objects detected + """ + file = await file.read() + + img, _ = load_img(file) + + # predict by YOLOv8 + boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True) + + if boxes is None: + return {'code': 500, 'msg': 'No objects detected'} + + # overlay masks on original image + blank_image = np.zeros(img.shape, dtype=np.uint8) + + data = [] + + coroutines = [process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls) for idx, mask_i in + enumerate(masks)] + results = await asyncio.gather(*coroutines) + + for result in results: + data.append(result) + + return {'code': 200, 'msg': "object detected", 'data': data} + + +@app.post("/api/lama/paint") +async def paint(img: UploadFile = File(), mask: UploadFile = File()): + """ + Endpoint to process an image with a given mask using the server's process function. + + Route: '/api/lama/paint' + Method: POST + + Parameters: + img: The input image file (JPEG or PNG format). + mask: The mask file (JPEG or PNG format). + Returns: + A JSON object containing the processed image in base64 format under the "image" key. + """ + img = await img.read() + mask = await mask.read() + return {"image": server.process(img, mask)} + + +@app.post("/api/remove") +async def remove(img: UploadFile = File()): + x = await img.read() + return {"image": server.remove(x)} + +@app.post("/api/lama/model") +def switch_model(new_name: str): + return server.switch_model(new_name) + + +@app.get("/api/lama/model") +def current_model(): + return server.current_model() + + +@app.get("/api/lama/switchmode") +def get_is_disable_model_switch(): + return server.get_is_disable_model_switch() + + +@app.on_event("startup") +def init_data(): + model_device = "cpu" + global yolo_model + # TODO Update for local development + yolo_model = YOLO('yolov8x-seg.pt') + # yolo_model = YOLO('/app/yolov8x-seg.pt') + yolo_model.to(model_device) + print(f"YOLO model yolov8x-seg.pt loaded.") + server.initModel() + + +def create_app(args): + """ + Creates the FastAPI app and adds the endpoints. + + Args: + args: The arguments. 
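+            Only args.host, args.port and args.reload are used by this function.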
+ """ + uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument('--model_name', type=str, default='lama', help='Model name') + parser.add_argument('--host', type=str, default="0.0.0.0") + parser.add_argument('--port', type=int, default=5000) + parser.add_argument('--reload', type=bool, default=True) + parser.add_argument('--model_device', type=str, default='cpu', help='Model device') + parser.add_argument('--disable_model_switch', type=bool, default=False, help='Disable model switch') + parser.add_argument('--gui', type=bool, default=False, help='Enable GUI') + parser.add_argument('--cpu_offload', type=bool, default=False, help='Enable CPU offload') + parser.add_argument('--disable_nsfw', type=bool, default=False, help='Disable NSFW') + parser.add_argument('--enable_xformers', type=bool, default=False, help='Enable xformers') + parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token') + parser.add_argument('--local_files_only', type=bool, default=False, help='Enable local files only') + parser.add_argument('--no_half', type=bool, default=False, help='Disable half') + parser.add_argument('--sd_cpu_textencoder', type=bool, default=False, help='Enable CPU text encoder') + parser.add_argument('--sd_disable_nsfw', type=bool, default=False, help='Disable NSFW') + parser.add_argument('--sd_enable_xformers', type=bool, default=False, help='Enable xformers') + parser.add_argument('--sd_run_local', type=bool, default=False, help='Enable local files only') + + args = parser.parse_args() + create_app(args) diff --git a/app/big-lama.pt b/app/big-lama.pt new file mode 100644 index 0000000000000000000000000000000000000000..e1da5e6a1db6e0e2a155c1f638aa83ad8ce07dc2 --- /dev/null +++ b/app/big-lama.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9 +size 205669692 diff --git a/app/u2net.onnx b/app/u2net.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d5e2c4d942dc1e3d0a5cc5b194516e9ddd70a3ed --- /dev/null +++ b/app/u2net.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d10d2f3bb75ae3b6d527c77944fc5e7dcd94b29809d47a739a7a728a912b491 +size 175997641 diff --git a/app/yolov8x-seg.pt b/app/yolov8x-seg.pt new file mode 100644 index 0000000000000000000000000000000000000000..32ec037daef545f7d3858a80aa52d5159999757d --- /dev/null +++ b/app/yolov8x-seg.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6 +size 144076467 diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..399fff83a34162ab6a279fa906b0a38280de541c --- /dev/null +++ b/lama_cleaner/__init__.py @@ -0,0 +1,11 @@ +import warnings +warnings.simplefilter("ignore", UserWarning) + +from lama_cleaner.parse_args import parse_args + +def entry_point(): + args = parse_args() + # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers + # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 + from lama_cleaner.server import main + main(args) diff --git a/lama_cleaner/__pycache__/__init__.cpython-38.pyc b/lama_cleaner/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e2fdab55571b34818b5fe9a0f918ece0e611500a Binary files /dev/null and b/lama_cleaner/__pycache__/__init__.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/const.cpython-38.pyc b/lama_cleaner/__pycache__/const.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..047ebd57abc6dd8067ad87ab76a891921f3ad558 Binary files /dev/null and b/lama_cleaner/__pycache__/const.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/helper.cpython-38.pyc b/lama_cleaner/__pycache__/helper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2302863aa42fc23868bffb1a718c84840d413d42 Binary files /dev/null and b/lama_cleaner/__pycache__/helper.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc b/lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c84fe3b506e11ab7c94bd278db3468e6d60094d1 Binary files /dev/null and b/lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/model_manager.cpython-38.pyc b/lama_cleaner/__pycache__/model_manager.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..125cc66dfed8def516d0a9ba747d66fe77a4b2e1 Binary files /dev/null and b/lama_cleaner/__pycache__/model_manager.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/parse_args.cpython-38.pyc b/lama_cleaner/__pycache__/parse_args.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f79506386790cc422418cfea04cf615cc8a049f0 Binary files /dev/null and b/lama_cleaner/__pycache__/parse_args.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/runtime.cpython-38.pyc b/lama_cleaner/__pycache__/runtime.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f696cf349166512f5f2a6c09980d066292e4b2bf Binary files /dev/null and b/lama_cleaner/__pycache__/runtime.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/schema.cpython-38.pyc b/lama_cleaner/__pycache__/schema.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f0efc4ec9b2394648650d94a9f4a88d7cb46951 Binary files /dev/null and b/lama_cleaner/__pycache__/schema.cpython-38.pyc differ diff --git a/lama_cleaner/__pycache__/server2.cpython-38.pyc b/lama_cleaner/__pycache__/server2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b439ba920caf6d10af87415a58d3e285b130461 Binary files /dev/null and b/lama_cleaner/__pycache__/server2.cpython-38.pyc differ diff --git a/lama_cleaner/benchmark.py b/lama_cleaner/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a170ec872475c4dfec397c42057ad53376cf81 --- /dev/null +++ b/lama_cleaner/benchmark.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time + +import numpy as np +import nvidia_smi +import psutil +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config, HDStrategy, SDSampler + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +NUM_THREADS = str(4) + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS 
+os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + + +def run_model(model, size): + # RGB + image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8) + mask = np.random.randint(0, 255, size).astype(np.uint8) + + config = Config( + ldm_steps=2, + hd_strategy=HDStrategy.ORIGINAL, + hd_strategy_crop_margin=128, + hd_strategy_crop_trigger_size=128, + hd_strategy_resize_limit=128, + prompt="a fox is sitting on a bench", + sd_steps=5, + sd_sampler=SDSampler.ddim + ) + model(image, mask, config) + + +def benchmark(model, times: int, empty_cache: bool): + sizes = [(512, 512)] + + nvidia_smi.nvmlInit() + device_id = 0 + handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id) + + def format(metrics): + return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}" + + process = psutil.Process(os.getpid()) + # 每个 size 给出显存和内存占用的指标 + for size in sizes: + torch.cuda.empty_cache() + time_metrics = [] + cpu_metrics = [] + memory_metrics = [] + gpu_memory_metrics = [] + for _ in range(times): + start = time.time() + run_model(model, size) + torch.cuda.synchronize() + + # cpu_metrics.append(process.cpu_percent()) + time_metrics.append((time.time() - start) * 1000) + memory_metrics.append(process.memory_info().rss / 1024 / 1024) + gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024) + + print(f"size: {size}".center(80, "-")) + # print(f"cpu: {format(cpu_metrics)}") + print(f"latency: {format(time_metrics)}ms") + print(f"memory: {format(memory_metrics)} MB") + print(f"gpu memory: {format(gpu_memory_metrics)} MB") + + nvidia_smi.nvmlShutdown() + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--name") + parser.add_argument("--device", default="cuda", type=str) + parser.add_argument("--times", default=10, type=int) + parser.add_argument("--empty-cache", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args_parser() + device = torch.device(args.device) + model = ModelManager( + name=args.name, + device=device, + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=True, + hf_access_token="123" + ) + benchmark(model, args.times, args.empty_cache) diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py new file mode 100644 index 0000000000000000000000000000000000000000..06245fb2825e35af94ef7b1f723c1ccd8769c7e2 --- /dev/null +++ b/lama_cleaner/const.py @@ -0,0 +1,68 @@ +import os + +DEFAULT_MODEL = "lama" +AVAILABLE_MODELS = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "sd1.5", + "cv2", + "manga", + "sd2", + "paint_by_example" +] + +AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] +DEFAULT_DEVICE = 'cuda' + +NO_HALF_HELP = """ +Using full precision model. +If your generate result is always black or green, use this argument. (sd/paint_by_exmaple) +""" + +CPU_OFFLOAD_HELP = """ +Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example) +""" + +DISABLE_NSFW_HELP = """ +Disable NSFW checker. (sd/paint_by_example) +""" + +SD_CPU_TEXTENCODER_HELP = """ +Run Stable Diffusion text encoder model on CPU to save GPU memory. +""" + +LOCAL_FILES_ONLY_HELP = """ +Use local files only, not connect to Hugging Face server. (sd/paint_by_example) +""" + +ENABLE_XFORMERS_HELP = """ +Enable xFormers optimizations. Requires xformers package has been installed. 
See: https://github.com/facebookresearch/xformers (sd/paint_by_example) +""" + +DEFAULT_MODEL_DIR = os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache") +) +MODEL_DIR_HELP = """ +Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache +""" + +OUTPUT_DIR_HELP = """ +Only required when --input is directory. Result images will be saved to output directory automatically. +""" + +INPUT_HELP = """ +If input is image, it will be loaded by default. +If input is directory, you can browse and select image in file manager. +""" + +GUI_HELP = """ +Launch Lama Cleaner as desktop app +""" + +NO_GUI_AUTO_CLOSE_HELP = """ +Prevent backend auto close after the GUI window closed. +""" diff --git a/lama_cleaner/file_manager/__init__.py b/lama_cleaner/file_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2499873fc1536ac1077241f2cd361a8d394156 --- /dev/null +++ b/lama_cleaner/file_manager/__init__.py @@ -0,0 +1 @@ +from .file_manager import FileManager diff --git a/lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc b/lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daef4df4ead552f77059463f32fabd8f3717cb14 Binary files /dev/null and b/lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc differ diff --git a/lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc b/lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22ec283103f95349c6fae8fed035ee6a80bea028 Binary files /dev/null and b/lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc differ diff --git a/lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc b/lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7730afb63a4000feae6cddc421d7ddbc84e5fa2f Binary files /dev/null and b/lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc differ diff --git a/lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc b/lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd98cee04b5154888e0f23b8e28dba0a4e2ccd19 Binary files /dev/null and b/lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc differ diff --git a/lama_cleaner/file_manager/file_manager.py b/lama_cleaner/file_manager/file_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4443a4244da25d3f4fcac831afc1706b9685dc --- /dev/null +++ b/lama_cleaner/file_manager/file_manager.py @@ -0,0 +1,252 @@ +# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py +import os +import time +from datetime import datetime +from io import BytesIO +from pathlib import Path + +import cv2 +import numpy as np +from PIL import Image, ImageOps, PngImagePlugin +from loguru import logger +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer + +LARGE_ENOUGH_NUMBER = 100 +PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2) +from .storage_backends import FilesystemStorageBackend +from .utils import aspect_to_string, generate_filename, glob_img + + +class FileManager(FileSystemEventHandler): + def __init__(self, app=None): + self.app = app + self._default_root_directory = "media" + 
self._default_thumbnail_directory = "media" + self._default_root_url = "/" + self._default_thumbnail_root_url = "/" + self._default_format = "JPEG" + self.output_dir: Path = None + + if app is not None: + self.init_app(app) + + self.image_dir_filenames = [] + self.output_dir_filenames = [] + + self.image_dir_observer = None + self.output_dir_observer = None + + self.modified_time = { + "image": datetime.utcnow(), + "output": datetime.utcnow(), + } + + def start(self): + self.image_dir_filenames = self._media_names(self.root_directory) + self.output_dir_filenames = self._media_names(self.output_dir) + + logger.info(f"Start watching image directory: {self.root_directory}") + self.image_dir_observer = Observer() + self.image_dir_observer.schedule(self, self.root_directory, recursive=False) + self.image_dir_observer.start() + + logger.info(f"Start watching output directory: {self.output_dir}") + self.output_dir_observer = Observer() + self.output_dir_observer.schedule(self, self.output_dir, recursive=False) + self.output_dir_observer.start() + + def on_modified(self, event): + if not os.path.isdir(event.src_path): + return + if event.src_path == str(self.root_directory): + logger.info(f"Image directory {event.src_path} modified") + self.image_dir_filenames = self._media_names(self.root_directory) + self.modified_time['image'] = datetime.utcnow() + elif event.src_path == str(self.output_dir): + logger.info(f"Output directory {event.src_path} modified") + self.output_dir_filenames = self._media_names(self.output_dir) + self.modified_time['output'] = datetime.utcnow() + + def init_app(self, app): + if self.app is None: + self.app = app + app.thumbnail_instance = self + + if not hasattr(app, "extensions"): + app.extensions = {} + + if "thumbnail" in app.extensions: + raise RuntimeError("Flask-thumbnail extension already initialized") + + app.extensions["thumbnail"] = self + + app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory) + app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory) + app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url) + app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url) + app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format) + + def save_to_output_directory(self, image: np.ndarray, filename: str): + fp = Path(filename) + new_name = fp.stem + f"_{int(time.time())}" + fp.suffix + if image.shape[2] == 3: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + elif image.shape[2] == 4: + image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA) + + cv2.imwrite(str(self.output_dir / new_name), image) + + @property + def root_directory(self): + path = self.app.config["THUMBNAIL_MEDIA_ROOT"] + + if os.path.isabs(path): + return path + else: + return os.path.join(self.app.root_path, path) + + @property + def thumbnail_directory(self): + path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] + + if os.path.isabs(path): + return path + else: + return os.path.join(self.app.root_path, path) + + @property + def root_url(self): + return self.app.config["THUMBNAIL_MEDIA_URL"] + + @property + def media_names(self): + # return self.image_dir_filenames + return self._media_names(self.root_directory) + + @property + def output_media_names(self): + return self._media_names(self.output_dir) + # return self.output_dir_filenames + + @staticmethod + def _media_names(directory: Path): + names = sorted([it.name for it in glob_img(directory)]) + res = [] + for name in names: + path = 
os.path.join(directory, name) + img = Image.open(path) + res.append({"name": name, "height": img.height, "width": img.width, "ctime": os.path.getctime(path)}) + return res + + @property + def thumbnail_url(self): + return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"] + + def get_thumbnail(self, directory: Path, original_filename: str, width, height, **options): + storage = FilesystemStorageBackend(self.app) + crop = options.get("crop", "fit") + background = options.get("background") + quality = options.get("quality", 90) + + original_path, original_filename = os.path.split(original_filename) + original_filepath = os.path.join(directory, original_path, original_filename) + image = Image.open(BytesIO(storage.read(original_filepath))) + + # keep ratio resize + if width is not None: + height = int(image.height * width / image.width) + else: + width = int(image.width * height / image.height) + + thumbnail_size = (width, height) + + thumbnail_filename = generate_filename( + original_filename, aspect_to_string(thumbnail_size), crop, background, quality + ) + + thumbnail_filepath = os.path.join( + self.thumbnail_directory, original_path, thumbnail_filename + ) + thumbnail_url = os.path.join(self.thumbnail_url, original_path, thumbnail_filename) + + if storage.exists(thumbnail_filepath): + return thumbnail_url, (width, height) + + try: + image.load() + except (IOError, OSError): + self.app.logger.warning("Thumbnail not load image: %s", original_filepath) + return thumbnail_url, (width, height) + + # get original image format + options["format"] = options.get("format", image.format) + + image = self._create_thumbnail(image, thumbnail_size, crop, background=background) + + raw_data = self.get_raw_data(image, **options) + storage.save(thumbnail_filepath, raw_data) + + return thumbnail_url, (width, height) + + def get_raw_data(self, image, **options): + data = { + "format": self._get_format(image, **options), + "quality": options.get("quality", 90), + } + + _file = BytesIO() + image.save(_file, **data) + return _file.getvalue() + + @staticmethod + def colormode(image, colormode="RGB"): + if colormode == "RGB" or colormode == "RGBA": + if image.mode == "RGBA": + return image + if image.mode == "LA": + return image.convert("RGBA") + return image.convert(colormode) + + if colormode == "GRAY": + return image.convert("L") + + return image.convert(colormode) + + @staticmethod + def background(original_image, color=0xFF): + size = (max(original_image.size),) * 2 + image = Image.new("L", size, color) + image.paste( + original_image, + tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))), + ) + + return image + + def _get_format(self, image, **options): + if options.get("format"): + return options.get("format") + if image.format: + return image.format + + return self.app.config["THUMBNAIL_DEFAULT_FORMAT"] + + def _create_thumbnail(self, image, size, crop="fit", background=None): + try: + resample = Image.Resampling.LANCZOS + except AttributeError: # pylint: disable=raise-missing-from + resample = Image.ANTIALIAS + + if crop == "fit": + image = ImageOps.fit(image, size, resample) + else: + image = image.copy() + image.thumbnail(size, resample=resample) + + if background is not None: + image = self.background(image) + + image = self.colormode(image) + + return image diff --git a/lama_cleaner/file_manager/storage_backends.py b/lama_cleaner/file_manager/storage_backends.py new file mode 100644 index 0000000000000000000000000000000000000000..3f453ade4dcd167856efe74f638e11dbc6145462 --- 
/dev/null +++ b/lama_cleaner/file_manager/storage_backends.py @@ -0,0 +1,46 @@ +# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py +import errno +import os +from abc import ABC, abstractmethod + + +class BaseStorageBackend(ABC): + def __init__(self, app=None): + self.app = app + + @abstractmethod + def read(self, filepath, mode="rb", **kwargs): + raise NotImplementedError + + @abstractmethod + def exists(self, filepath): + raise NotImplementedError + + @abstractmethod + def save(self, filepath, data): + raise NotImplementedError + + +class FilesystemStorageBackend(BaseStorageBackend): + def read(self, filepath, mode="rb", **kwargs): + with open(filepath, mode) as f: # pylint: disable=unspecified-encoding + return f.read() + + def exists(self, filepath): + return os.path.exists(filepath) + + def save(self, filepath, data): + directory = os.path.dirname(filepath) + + if not os.path.exists(directory): + try: + os.makedirs(directory) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + if not os.path.isdir(directory): + raise IOError("{} is not a directory".format(directory)) + + with open(filepath, "wb") as f: + f.write(data) diff --git a/lama_cleaner/file_manager/utils.py b/lama_cleaner/file_manager/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe650f55fb7223cc2f92177f9a5d7121cf73a0e --- /dev/null +++ b/lama_cleaner/file_manager/utils.py @@ -0,0 +1,66 @@ +# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py +import os +from pathlib import Path + +from typing import Union + + +def generate_filename(original_filename, *options): + name, ext = os.path.splitext(original_filename) + for v in options: + if v: + name += "_%s" % v + name += ext + + return name + + +def parse_size(size): + if isinstance(size, int): + # If the size parameter is a single number, assume square aspect. + return [size, size] + + if isinstance(size, (tuple, list)): + if len(size) == 1: + # If single value tuple/list is provided, exand it to two elements + return size + type(size)(size) + return size + + try: + thumbnail_size = [int(x) for x in size.lower().split("x", 1)] + except ValueError: + raise ValueError( # pylint: disable=raise-missing-from + "Bad thumbnail size format. Valid format is INTxINT." + ) + + if len(thumbnail_size) == 1: + # If the size parameter only contains a single integer, assume square aspect. 
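+        # e.g. parse_size("128") -> [128, 128]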
+ thumbnail_size.append(thumbnail_size[0]) + + return thumbnail_size + + +def aspect_to_string(size): + if isinstance(size, str): + return size + + return "x".join(map(str, size)) + + +IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'} + + +def glob_img(p: Union[Path, str], recursive: bool = False): + p = Path(p) + if p.is_file() and p.suffix in IMG_SUFFIX: + yield p + else: + if recursive: + files = Path(p).glob("**/*.*") + else: + files = Path(p).glob("*.*") + + for it in files: + if it.suffix not in IMG_SUFFIX: + continue + yield it diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..906fa99a8f2db3814be504464c18e05b947e9165 --- /dev/null +++ b/lama_cleaner/helper.py @@ -0,0 +1,218 @@ +import io +import os +import sys +from typing import List, Optional +from urllib.parse import urlparse + +import cv2 +import numpy as np +import torch +from PIL import Image, ImageOps +from loguru import logger +from torch.hub import download_url_to_file, get_dir + + +def get_cache_path_by_url(url): + parts = urlparse(url) + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + if not os.path.isdir(model_dir): + os.makedirs(model_dir) + filename = os.path.basename(parts.path) + cached_file = os.path.join(model_dir, filename) + return cached_file + + +def download_model(url): + cached_file = get_cache_path_by_url(url) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + download_url_to_file(url, cached_file, hash_prefix, progress=True) + return cached_file + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def load_jit_model(url_or_path, device): + # if os.path.exists(url_or_path): + # model_path = url_or_path + # else: + # model_path = download_model(url_or_path) + model_path = os.getcwd() + logger.info(f"Load model from: {model_path}") + try: + model = torch.jit.load(model_path).to(device) + except: + logger.error( + f"Failed to load {model_path}, delete model and restart lama-cleaner" + ) + exit(-1) + model.eval() + return model + + +def load_model(model: torch.nn.Module, url_or_path, device): + if os.path.exists(url_or_path): + model_path = url_or_path + else: + model_path = download_model(url_or_path) + + try: + state_dict = torch.load(model_path, map_location='cpu') + model.load_state_dict(state_dict, strict=True) + model.to(device) + logger.info(f"Load model from: {model_path}") + except: + logger.error( + f"Failed to load {model_path}, delete model and restart lama-cleaner" + ) + exit(-1) + model.eval() + return model + + +def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: + data = cv2.imencode( + f".{ext}", + image_numpy, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + )[1] + image_bytes = data.tobytes() + return image_bytes + + +def load_img(img_bytes, gray: bool = False): + alpha_channel = None + image = Image.open(io.BytesIO(img_bytes)) + try: + image = ImageOps.exif_transpose(image) + except: + pass + + if gray: + image = image.convert('L') + np_img = np.array(image) + else: + if image.mode == 'RGBA': + np_img = np.array(image) + alpha_channel = np_img[:, :, -1] + np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) + else: + image = image.convert('RGB') + np_img = np.array(image) + + return np_img, alpha_channel + + +def norm_img(np_img): + if len(np_img.shape) == 2: + np_img = np_img[:, :, np.newaxis] + np_img = 
np.transpose(np_img, (2, 0, 1)) + np_img = np_img.astype("float32") / 255 + return np_img + + +def resize_max_size( + np_img, size_limit: int, interpolation=cv2.INTER_CUBIC +) -> np.ndarray: + # Resize image's longer size to size_limit if longer size larger than size_limit + h, w = np_img.shape[:2] + if max(h, w) > size_limit: + ratio = size_limit / max(h, w) + new_w = int(w * ratio + 0.5) + new_h = int(h * ratio + 0.5) + return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation) + else: + return np_img + + +def pad_img_to_modulo( + img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None +): + """ + + Args: + img: [H, W, C] + mod: + square: 是否为正方形 + min_size: + + Returns: + + """ + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + height, width = img.shape[:2] + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + + if min_size is not None: + assert min_size % mod == 0 + out_width = max(min_size, out_width) + out_height = max(min_size, out_height) + + if square: + max_size = max(out_height, out_width) + out_height = max_size + out_width = max_size + + return np.pad( + img, + ((0, out_height - height), (0, out_width - width), (0, 0)), + mode="symmetric", + ) + + +def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]: + """ + Args: + mask: (h, w, 1) 0~255 + + Returns: + + """ + height, width = mask.shape[:2] + _, thresh = cv2.threshold(mask, 127, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + boxes = [] + for cnt in contours: + x, y, w, h = cv2.boundingRect(cnt) + box = np.array([x, y, x + w, y + h]).astype(int) + + box[::2] = np.clip(box[::2], 0, width) + box[1::2] = np.clip(box[1::2], 0, height) + boxes.append(box) + + return boxes + + +def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]: + """ + Args: + mask: (h, w) 0~255 + + Returns: + + """ + _, thresh = cv2.threshold(mask, 127, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + max_area = 0 + max_index = -1 + for i, cnt in enumerate(contours): + area = cv2.contourArea(cnt) + if area > max_area: + max_area = area + max_index = i + + if max_index != -1: + new_mask = np.zeros_like(mask) + return cv2.drawContours(new_mask, contours, max_index, 255, -1) + else: + return mask diff --git a/lama_cleaner/interactive_seg.py b/lama_cleaner/interactive_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..50d0520ea437c92a2db7153fe467166900e006b9 --- /dev/null +++ b/lama_cleaner/interactive_seg.py @@ -0,0 +1,202 @@ +import os +from typing import Tuple, List + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from loguru import logger +from pydantic import BaseModel + +from lama_cleaner.helper import load_jit_model + + +class Click(BaseModel): + # [y, x] + coords: Tuple[float, float] + is_positive: bool + indx: int + + @property + def coords_and_indx(self): + return (*self.coords, self.indx) + + def scale(self, x_ratio: float, y_ratio: float) -> 'Click': + return Click( + coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio), + is_positive=self.is_positive, + indx=self.indx + ) + + +class ResizeTrans: + def __init__(self, size=480): + super().__init__() + self.crop_height = size + self.crop_width = size + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_height, image_width = image_nd.shape[2:4] + self.image_height = image_height + 
self.image_width = image_width + image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True) + + y_ratio = self.crop_height / image_height + x_ratio = self.crop_width / image_width + + clicks_lists_resized = [] + for clicks_list in clicks_lists: + clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list] + clicks_lists_resized.append(clicks_list_resized) + + return image_nd_r, clicks_lists_resized + + def inv_transform(self, prob_map): + new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear', + align_corners=True) + + return new_prob_map + + +class ISPredictor(object): + def __init__( + self, + model, + device, + open_kernel_size: int, + dilate_kernel_size: int, + net_clicks_limit=None, + zoom_in=None, + infer_size=384, + ): + self.model = model + self.open_kernel_size = open_kernel_size + self.dilate_kernel_size = dilate_kernel_size + self.net_clicks_limit = net_clicks_limit + self.device = device + self.zoom_in = zoom_in + self.infer_size = infer_size + + # self.transforms = [zoom_in] if zoom_in is not None else [] + + def __call__(self, input_image: torch.Tensor, clicks: List[Click], prev_mask): + """ + + Args: + input_image: [1, 3, H, W] [0~1] + clicks: List[Click] + prev_mask: [1, 1, H, W] + + Returns: + + """ + transforms = [ResizeTrans(self.infer_size)] + input_image = torch.cat((input_image, prev_mask), dim=1) + + # image_nd resized to infer_size + for t in transforms: + image_nd, clicks_lists = t.transform(input_image, [clicks]) + + # image_nd.shape = [1, 4, 256, 256] + # points_nd.shape = [1, 2, 3] + # clicks_lists[0][0] is a Click instance + points_nd = self.get_points_nd(clicks_lists) + pred_logits = self.model(image_nd, points_nd) + pred = torch.sigmoid(pred_logits) + pred = self.post_process(pred) + + prediction = F.interpolate(pred, mode='bilinear', align_corners=True, + size=image_nd.size()[2:]) + + for t in reversed(transforms): + prediction = t.inv_transform(prediction) + + # if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): + # return self.get_prediction(clicker) + + return prediction.cpu().numpy()[0, 0] + + def post_process(self, pred: torch.Tensor) -> torch.Tensor: + pred_mask = pred.cpu().numpy()[0][0] + # morph_open to remove small noise + kernel_size = self.open_kernel_size + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1) + + # Why dilate: make the region slightly larger to avoid missing pixels; this generally works better + dilate_kernel_size = self.dilate_kernel_size + if dilate_kernel_size > 1: + kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size)) + pred_mask = cv2.dilate(pred_mask, kernel, 1) + return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0) + + def get_points_nd(self, clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + if self.net_clicks_limit is not None: + num_max_points = min(self.net_clicks_limit, num_max_points) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + clicks_list = clicks_list[:self.net_clicks_limit] + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
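+            # pad both click lists with (-1, -1, -1) sentinel points so every sample supplies exactly num_max_points positive and num_max_points negative entries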
+ pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device=self.device) + + +INTERACTIVE_SEG_MODEL_URL = os.environ.get( + "INTERACTIVE_SEG_MODEL_URL", + "https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt", +) + + +class InteractiveSeg: + def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3): + device = torch.device('cpu') + model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device).eval() + self.predictor = ISPredictor(model, device, + infer_size=infer_size, + open_kernel_size=open_kernel_size, + dilate_kernel_size=dilate_kernel_size) + + def __call__(self, image, clicks, prev_mask=None): + """ + + Args: + image: [H,W,C] RGB + clicks: + + Returns: + + """ + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image = torch.from_numpy((image / 255).transpose(2, 0, 1)).unsqueeze(0).float() + if prev_mask is None: + mask = torch.zeros_like(image[:, :1, :, :]) + else: + logger.info('InteractiveSeg run with prev_mask') + mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float() + + pred_probs = self.predictor(image, clicks, mask) + pred_mask = pred_probs > 0.5 + pred_mask = (pred_mask * 255).astype(np.uint8) + + # Find largest contour + # pred_mask = only_keep_largest_contour(pred_mask) + # To simplify frontend process, add mask brush color here + fg = pred_mask == 255 + bg = pred_mask != 255 + pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2BGRA) + # frontend brush color "ffcc00bb" + pred_mask[bg] = 0 + pred_mask[fg] = [255, 203, 0, int(255 * 0.73)] + pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_BGRA2RGBA) + return pred_mask diff --git a/lama_cleaner/model/__init__.py b/lama_cleaner/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lama_cleaner/model/__pycache__/__init__.cpython-38.pyc b/lama_cleaner/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2da660901ca327e95d6a50d11e9de5915c3d836 Binary files /dev/null and b/lama_cleaner/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/base.cpython-38.pyc b/lama_cleaner/model/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e40ad484f5de863b0b8e5f4442532062dc1d520 Binary files /dev/null and b/lama_cleaner/model/__pycache__/base.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc b/lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb9bf9724e94df5e1ede7be469f119375fcb4974 Binary files /dev/null and b/lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/fcf.cpython-38.pyc b/lama_cleaner/model/__pycache__/fcf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85dd1137028669ef454afe82bba9c221a2ead1b6 Binary files /dev/null and b/lama_cleaner/model/__pycache__/fcf.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/lama.cpython-38.pyc b/lama_cleaner/model/__pycache__/lama.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..63ba6c7492ffb5467f649f58cd38acd0ca8cb70f Binary files /dev/null and b/lama_cleaner/model/__pycache__/lama.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/ldm.cpython-38.pyc b/lama_cleaner/model/__pycache__/ldm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97989c0c2f77afe207f40943dfd2903538ab1521 Binary files /dev/null and b/lama_cleaner/model/__pycache__/ldm.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/manga.cpython-38.pyc b/lama_cleaner/model/__pycache__/manga.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30660becdc052273b691af9ed623192dff27cd6f Binary files /dev/null and b/lama_cleaner/model/__pycache__/manga.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/mat.cpython-38.pyc b/lama_cleaner/model/__pycache__/mat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..014a693539a86dd623118c667eb0b18739fb7425 Binary files /dev/null and b/lama_cleaner/model/__pycache__/mat.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc b/lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cad8537443b60441c8f29a2d5b34583d04d3f1f4 Binary files /dev/null and b/lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc b/lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e79b8768eaef948010b6a4aceb68505a6f13eff1 Binary files /dev/null and b/lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc b/lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6303dbd97483f73ab8260c052b53de90003541d8 Binary files /dev/null and b/lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/sd.cpython-38.pyc b/lama_cleaner/model/__pycache__/sd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dbfdeb24014608ddec22675f43f53a23b751cdd Binary files /dev/null and b/lama_cleaner/model/__pycache__/sd.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/utils.cpython-38.pyc b/lama_cleaner/model/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81c3fb0762dd1e4dce03a18bf94313d35c1ab44f Binary files /dev/null and b/lama_cleaner/model/__pycache__/utils.cpython-38.pyc differ diff --git a/lama_cleaner/model/__pycache__/zits.cpython-38.pyc b/lama_cleaner/model/__pycache__/zits.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef1ed6803547ad9ba195d1eb613352982c72efd5 Binary files /dev/null and b/lama_cleaner/model/__pycache__/zits.cpython-38.pyc differ diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c4cb3be1744fe8e95fc9e40e23f031d99da4cf --- /dev/null +++ b/lama_cleaner/model/base.py @@ -0,0 +1,247 @@ +import abc +from typing import Optional + +import cv2 +import numpy as np +import torch +from loguru import logger + +from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo +from lama_cleaner.schema import Config, 
HDStrategy + + +class InpaintModel: + min_size: Optional[int] = None + pad_mod = 8 + pad_to_square = False + + def __init__(self, device, **kwargs): + """ + + Args: + device: + """ + self.device = device + self.init_model(device, **kwargs) + + @abc.abstractmethod + def init_model(self, device, **kwargs): + ... + + @staticmethod + @abc.abstractmethod + def is_downloaded() -> bool: + ... + + @abc.abstractmethod + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W, 1] 255 为 masks 区域 + return: BGR IMAGE + """ + ... + + def _pad_forward(self, image, mask, config: Config): + origin_height, origin_width = image.shape[:2] + pad_image = pad_img_to_modulo( + image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + pad_mask = pad_img_to_modulo( + mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + + logger.info(f"final forward pad size: {pad_image.shape}") + + result = self.forward(pad_image, pad_mask, config) + result = result[0:origin_height, 0:origin_width, :] + + result, image, mask = self.forward_post_process(result, image, mask, config) + + mask = mask[:, :, np.newaxis] + result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255)) + return result + + def forward_post_process(self, result, image, mask, config): + return result, image, mask + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + inpaint_result = None + logger.info(f"hd_strategy: {config.hd_strategy}") + if config.hd_strategy == HDStrategy.CROP: + if max(image.shape) > config.hd_strategy_crop_trigger_size: + logger.info(f"Run crop strategy") + boxes = boxes_from_mask(mask) + crop_result = [] + for box in boxes: + crop_image, crop_box = self._run_box(image, mask, box, config) + crop_result.append((crop_image, crop_box)) + + inpaint_result = image[:, :, ::-1] + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + elif config.hd_strategy == HDStrategy.RESIZE: + if max(image.shape) > config.hd_strategy_resize_limit: + origin_size = image.shape[:2] + downsize_image = resize_max_size( + image, size_limit=config.hd_strategy_resize_limit + ) + downsize_mask = resize_max_size( + mask, size_limit=config.hd_strategy_resize_limit + ) + + logger.info( + f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}" + ) + inpaint_result = self._pad_forward( + downsize_image, downsize_mask, config + ) + + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + original_pixel_indices = mask < 127 + inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + original_pixel_indices + ] + + if inpaint_result is None: + inpaint_result = self._pad_forward(image, mask, config) + + return inpaint_result + + def _crop_box(self, image, mask, box, config: Config): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE, (l, r, r, b) + """ + box_h = box[3] - box[1] + box_w = box[2] - box[0] + cx = (box[0] + box[2]) // 2 + cy = (box[1] + box[3]) // 2 + img_h, img_w = image.shape[:2] + + w = box_w + config.hd_strategy_crop_margin * 2 + h = box_h + config.hd_strategy_crop_margin * 2 + + _l = cx - w // 2 + _r = cx + w // 2 + _t = cy - h // 2 + _b = cy + h // 2 + 
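+        # clamp the margin-expanded box to the image bounds; if it spills past an edge, the overflow is shifted to the opposite side below to preserve the crop size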
+ l = max(_l, 0) + r = min(_r, img_w) + t = max(_t, 0) + b = min(_b, img_h) + + # try to get more context when crop around image edge + if _l < 0: + r += abs(_l) + if _r > img_w: + l -= _r - img_w + if _t < 0: + b += abs(_t) + if _b > img_h: + t -= _b - img_h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + + logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}") + + return crop_img, crop_mask, [l, t, r, b] + + def _calculate_cdf(self, histogram): + cdf = histogram.cumsum() + normalized_cdf = cdf / float(cdf.max()) + return normalized_cdf + + def _calculate_lookup(self, source_cdf, reference_cdf): + lookup_table = np.zeros(256) + lookup_val = 0 + for source_index, source_val in enumerate(source_cdf): + for reference_index, reference_val in enumerate(reference_cdf): + if reference_val >= source_val: + lookup_val = reference_index + break + lookup_table[source_index] = lookup_val + return lookup_table + + def _match_histograms(self, source, reference, mask): + transformed_channels = [] + for channel in range(source.shape[-1]): + source_channel = source[:, :, channel] + reference_channel = reference[:, :, channel] + + # only calculate histograms for non-masked parts + source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256]) + reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256]) + + source_cdf = self._calculate_cdf(source_histogram) + reference_cdf = self._calculate_cdf(reference_histogram) + + lookup = self._calculate_lookup(source_cdf, reference_cdf) + + transformed_channels.append(cv2.LUT(source_channel, lookup)) + + result = cv2.merge(transformed_channels) + result = cv2.convertScaleAbs(result) + + return result + + def _apply_cropper(self, image, mask, config: Config): + img_h, img_w = image.shape[:2] + l, t, w, h = ( + config.croper_x, + config.croper_y, + config.croper_width, + config.croper_height, + ) + r = l + w + b = t + h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + return crop_img, crop_mask, (l, t, r, b) + + def _run_box(self, image, mask, box, config: Config): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE + """ + crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config) + + return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b] diff --git a/lama_cleaner/model/ddim_sampler.py b/lama_cleaner/model/ddim_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f4dc098e129ae45890de144ff9cb4bce3dd592bb --- /dev/null +++ b/lama_cleaner/model/ddim_sampler.py @@ -0,0 +1,192 @@ +import numpy as np +import torch +from loguru import logger +from tqdm import tqdm + +from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear"): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + # array([1]) + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, 
+ ) + alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000]) + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample(self, steps, conditioning, batch_size, shape): + self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # samples: 1,3,128,128 + return self.ddim_sampling( + conditioning, + size, + quantize_denoised=False, + ddim_use_original_steps=False, + noise_dropout=0, + temperature=1.0, + ) + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + ddim_use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + device = self.model.betas.device + b = shape[0] + img = torch.randn(shape, device=device, dtype=cond.dtype) + timesteps = ( + self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + ) + + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + logger.info(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + ) + img, _ = outs + + return img + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + b, *_, device = *x.shape, x.device + e_t 
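# Worked example of the sigma formula behind make_ddim_sampling_parameters /
# ddim_sigmas_for_original_num_steps above, with made-up alphas_cumprod values.
# With eta = 0 (as used by sample()) every sigma is 0 and sampling is deterministic.
import torch

alphas_cumprod = torch.tensor([0.999, 0.99, 0.95, 0.90])
a_t, a_prev = alphas_cumprod[2], alphas_cumprod[1]
eta = 1.0
sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
print(sigma_t)  # ~0.09 for these made-up values; exactly 0 when eta == 0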
= self.model.apply_model(x, t, c) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: # 没用 + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: # 没用 + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py new file mode 100644 index 0000000000000000000000000000000000000000..99f74e25e0d437c656f65295ab17c3af3e762318 --- /dev/null +++ b/lama_cleaner/model/fcf.py @@ -0,0 +1,1212 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import torch.fft as fft +import torch.nn.functional as F +from torch import conv2d, nn + +from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \ + MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d +from lama_cleaner.schema import Config + + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + assert isinstance(x, torch.Tensor) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + + # Setup filter. 
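# A small demonstration of the zero-insertion upsampling step in _upfirdn2d_ref
# above: expose a unit axis per pixel, pad it with zeros, and fold it back in.
# The 2x2 input values are dummy data.
import torch
import torch.nn.functional as F

x = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)
upx = upy = 2
b, c, h, w = x.shape
x = x.reshape(b, c, h, 1, w, 1)
x = F.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
x = x.reshape(b, c, h * upy, w * upx)
print(x[0, 0])
# tensor([[0., 0., 1., 0.],
#         [0., 0., 0., 0.],
#         [2., 0., 3., 0.],
#         [0., 0., 0., 0.]])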
+ f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +class EncoderEpilogue(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.cmap_dim = cmap_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + if architecture == 'skip': + self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation) + self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, + num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None + self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, + conv_clamp=conv_clamp) + self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation) + self.dropout = torch.nn.Dropout(p=0.5) + + def forward(self, x, cmap, force_fp32=False): + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + # FromRGB. + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.mbstd is not None: + x = self.mbstd(x) + const_e = self.conv(x) + x = self.fc(const_e.flatten(1)) + x = self.dropout(x) + + # Conditioning. + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + assert x.dtype == dtype + return x, const_e + + +class EncoderBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + freeze_layers=0, # Freeze-D: Number of layers to freeze. 
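# A minimal sketch of the per-channel ("grouped") filtering step above: one 2D
# low-pass kernel is repeated across channels and applied with groups=num_channels
# so channels never mix. The [1, 3, 3, 1] taps match the resample_filter defaults
# used later in this file; the normalization here is illustrative.
import torch
from torch.nn.functional import conv2d

f1d = torch.tensor([1.0, 3.0, 3.0, 1.0])
f = torch.outer(f1d, f1d)
f = f / f.sum()                                  # unit DC gain
x = torch.randn(1, 3, 8, 8)
weight = f[None, None].repeat(3, 1, 1, 1)        # (channels, 1, kh, kw)
y = conv2d(x, weight, groups=3, padding=2)
print(y.shape)                                   # torch.Size([1, 3, 9, 9])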
+ ): + assert in_channels in [0, tmp_channels] + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + 1 + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', setup_filter(resample_filter)) + + self.num_layers = 0 + + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = (layer_idx >= freeze_layers) + self.num_layers += 1 + yield trainable + + trainable_iter = trainable_gen() + + if in_channels == 0: + self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, + channels_last=self.channels_last) + + self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, + channels_last=self.channels_last) + + self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, + channels_last=self.channels_last) + + if architecture == 'resnet': + self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, + channels_last=self.channels_last) + + def forward(self, x, img, force_fp32=False): + # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + dtype = torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + + # Input. + if x is not None: + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0: + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None + + # Main layers. + if self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + feat = x.clone() + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + feat = x.clone() + x = self.conv1(x) + + assert x.dtype == dtype + return x, img, feat + + +class EncoderNetwork(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + z_dim, # Input latent (Z) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture='orig', # Architecture: 'orig', 'skip', 'resnet'. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for EncoderEpilogue. 
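# Tiny sketch of the Freeze-D bookkeeping in EncoderBlock.__init__ above: a
# generator hands out one trainable flag per layer from a running global layer
# index, so the first `freeze_layers` layers stay frozen. Numbers are examples.
def trainable_flags(first_layer_idx, freeze_layers):
    layer_idx = first_layer_idx
    while True:
        yield layer_idx >= freeze_layers
        layer_idx += 1

gen = trainable_flags(first_layer_idx=3, freeze_layers=5)
print([next(gen) for _ in range(4)])  # [False, False, True, True]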
+ ): + super().__init__() + self.c_dim = c_dim + self.z_dim = z_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + use_fp16 = False + block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, + **common_kwargs) + + def forward(self, img, c, **block_kwargs): + x = None + feats = {} + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img, feat = block(x, img, **block_kwargs) + feats[res] = feat + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x, const_e = self.b4(x, cmap) + feats[4] = const_e + + B, _ = x.shape + z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype, + device=x.device) ## Noise for Co-Modulation + return x, z, feats + + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims + 1:]) + assert x.shape == shape + return x + + +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise=None, # Optional noise tensor to add to the output activations. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + padding=0, # Padding with respect to the upsampled image. + resample_filter=None, + # Low-pass filter to apply when resampling activations. 
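# Worked example of the resolution/channel bookkeeping in EncoderNetwork.__init__
# above, using img_resolution=512 and the default channel_base/channel_max.
import numpy as np

img_resolution, channel_base, channel_max = 512, 16384, 512
log2 = int(np.log2(img_resolution))                       # 9
block_resolutions = [2 ** i for i in range(log2, 2, -1)]  # [512, 256, 128, 64, 32, 16, 8]
channels_dict = {res: min(channel_base // res, channel_max)
                 for res in block_resolutions + [4]}
print(channels_dict)  # {512: 32, 256: 64, 128: 128, 64: 256, 32: 512, 16: 512, 8: 512, 4: 512}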
Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate=True, # Apply weight demodulation? + flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], + keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, + padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + batch_size = int(batch_size) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, + groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + + +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=True, # Enable noise input? + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last=False, # Use channels_last format for the weights? 
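# A minimal numeric sketch of the modulation/demodulation math in
# modulated_conv2d above: per-sample styles scale the weight's input channels,
# and the demodulation coefficients rescale every output filter back to unit
# norm. Tensor sizes are dummy values.
import torch

weight = torch.randn(4, 3, 3, 3)                              # [out, in, kh, kw]
styles = torch.rand(2, 3) + 0.5                               # [batch, in]
w = weight.unsqueeze(0) * styles.reshape(2, 1, -1, 1, 1)      # [N, out, in, kh, kw]
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()       # [N, out]
w = w * dcoefs.reshape(2, -1, 1, 1, 1)
print(w.square().sum(dim=[2, 3, 4]))                          # ~1 for every (sample, filter)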
+ ): + super().__init__() + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], + device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, + fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = F.leaky_relu(x, negative_slope=0.2, inplace=False) + if act_gain != 1: + x = x * act_gain + if act_clamp is not None: + x = x.clamp(-act_clamp, act_clamp) + return x + + +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): + super().__init__() + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + +class SynthesisForeword(torch.nn.Module): + def __init__(self, + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + in_channels, + img_channels, # Number of input color channels. + architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. 
+ + ): + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation) + self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4) + + if architecture == 'skip': + self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3) + + def forward(self, x, ws, feats, img, force_fp32=False): + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + x_global = x.clone() + # ToRGB. + x = self.fc(x) + x = x.view(-1, self.z_dim // 2, 4, 4) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + x_skip = feats[4].clone() + x = x + x_skip + + mod_vector = [] + mod_vector.append(ws[:, 0]) + mod_vector.append(x_global.clone()) + mod_vector = torch.cat(mod_vector, dim=1) + + x = self.conv(x, mod_vector) + + mod_vector = [] + mod_vector.append(ws[:, 2 * 2 - 3]) + mod_vector.append(x_global.clone()) + mod_vector = torch.cat(mod_vector, dim=1) + + if self.architecture == 'skip': + img = self.torgb(x, mod_vector) + img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) + + assert x.dtype == dtype + return x, img + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=False), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FourierUnit(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', + spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) + self.relu = torch.nn.ReLU(inplace=False) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, + align_corners=False) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 
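# A small round-trip sketch of the FFT plumbing in FourierUnit.forward above:
# rfftn, stack real/imag as channels (where the 1x1 conv acts), recombine as a
# complex tensor, then irfftn back to the original spatial size.
import torch
import torch.fft as fft

x = torch.randn(1, 2, 8, 8)
ffted = fft.rfftn(x, dim=(-2, -1), norm="ortho")
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)         # (1, 2, 8, 5, 2)
# ... the 1x1 conv over flattened real/imag channels happens here ...
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
out = fft.irfftn(ffted, s=x.shape[-2:], dim=(-2, -1), norm="ortho")
print(torch.allclose(out, x, atol=1e-5))                      # True: the transform itself is lossless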
1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(ffted) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + + def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels // + 2, kernel_size=1, groups=groups, bias=False), + # nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True) + ) + self.fu = FourierUnit( + out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit( + out_channels // 2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat(torch.split( + x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() + xs = torch.cat(torch.split(xs, split_s, dim=-1), + dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class FFC(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride=1, padding=0, + dilation=1, groups=1, bias=False, enable_lfu=True, + padding_type='reflect', gated=False, **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." 
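# Shape walk-through of the local Fourier unit (LFU) branch in
# SpectralTransform.forward above, with dummy sizes: a quarter of the channels is
# cut into 2x2 spatial tiles, stacked on the channel axis, and tiled back out.
import torch

x = torch.randn(1, 8, 16, 16)
split_no = 2
split_s = x.shape[-2] // split_no
xs = torch.cat(torch.split(x[:, :8 // 4], split_s, dim=-2), dim=1)  # (1, 4, 8, 16)
xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1)             # (1, 8, 8, 8)
xs = xs.repeat(1, 1, split_no, split_no)                            # (1, 8, 16, 16)
print(xs.shape)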
+ self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + # groups_g = 1 if groups == 1 else int(groups * ratio_gout) + # groups_l = 1 if groups == 1 else groups - groups_g + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module(in_cl, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module(in_cl, out_cg, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module(in_cg, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x, fname=None): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + spec_x = self.convg2g(x_g) + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + spec_x + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, ratio_gin, ratio_gout, + stride=1, padding=0, dilation=1, groups=1, bias=False, + norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC(in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride, padding, dilation, + groups, bias, enable_lfu, padding_type=padding_type, **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + # self.bn_l = lnorm(out_channels - global_channels) + # self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x, fname=None): + x_l, x_g = self.ffc(x, fname=fname, ) + x_l = self.act_l(x_l) + x_g = self.act_g(x_g) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, + spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75): + super().__init__() + self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, ratio_gout=ratio_gout) + self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + 
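# Worked example of the local/global channel split in FFC.__init__ above, with
# in_channels = out_channels = 64 and the 0.75 ratios used elsewhere in this file.
in_channels = out_channels = 64
ratio_gin = ratio_gout = 0.75
in_cg = int(in_channels * ratio_gin)      # 48 global input channels
in_cl = in_channels - in_cg               # 16 local input channels
out_cg = int(out_channels * ratio_gout)   # 48 global output channels
out_cl = out_channels - out_cg            # 16 local output channels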
norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, ratio_gout=ratio_gout) + self.inline = inline + + def forward(self, x, fname=None): + if self.inline: + x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g), fname=fname) + x_l, x_g = self.conv2((x_l, x_g), fname=fname) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCBlock(torch.nn.Module): + def __init__(self, + dim, # Number of output/input channels. + kernel_size, # Width and height of the convolution kernel. + padding, + ratio_gin=0.75, + ratio_gout=0.75, + activation='linear', # Activation function: 'relu', 'lrelu', etc. + ): + super().__init__() + if activation == 'linear': + self.activation = nn.Identity + else: + self.activation = nn.ReLU + self.padding = padding + self.kernel_size = kernel_size + self.ffc_block = FFCResnetBlock(dim=dim, + padding_type='reflect', + norm_layer=nn.SyncBatchNorm, + activation_layer=self.activation, + dilation=1, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout) + + self.concat_layer = ConcatTupleLayer() + + def forward(self, gen_ft, mask, fname=None): + x = gen_ft.float() + + x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:] + id_l, id_g = x_l, x_g + + x_l, x_g = self.ffc_block((x_l, x_g), fname=fname) + x_l, x_g = id_l + x_l, id_g + x_g + x = self.concat_layer((x_l, x_g)) + + return x + gen_ft.float() + + +class FFCSkipLayer(torch.nn.Module): + def __init__(self, + dim, # Number of input/output channels. + kernel_size=3, # Convolution kernel size. + ratio_gin=0.75, + ratio_gout=0.75, + ): + super().__init__() + self.padding = kernel_size // 2 + + self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU, + padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout) + + def forward(self, gen_ft, mask, fname=None): + x = self.ffc_act(gen_ft, mask, fname=fname) + return x + + +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + **layer_kwargs, # Arguments for SynthesisLayer. 
+ ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1} + + if in_channels != 0 and resolution >= 8: + self.ffc_skip = nn.ModuleList() + for _ in range(self.res_ffc[resolution]): + self.ffc_skip.append(FFCSkipLayer(dim=out_channels)) + + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2, + resample_filter=resample_filter, conv_clamp=conv_clamp, + channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs): + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + dtype = torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + x = x.to(dtype=dtype, memory_format=memory_format) + x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs) + if len(self.ffc_skip) > 0: + mask = F.interpolate(mask, size=x_skip.shape[2:], ) + z = x + x_skip + for fres in self.ffc_skip: + z = fres(z, mask) + x = x + z + else: + x = x + x_skip + x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs) + if len(self.ffc_skip) > 0: + mask = F.interpolate(mask, size=x_skip.shape[2:], ) + z = x + x_skip + for fres in self.ffc_skip: + z = fres(z, mask) + x = x + z + else: + x = x + x_skip + x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs) + # ToRGB. 
+ if img is not None: + img = upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + x = x.to(dtype=dtype) + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + z_dim, # Output Latent (Z) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max), + z_dim=z_dim * 2, resolution=4) + + self.num_ws = self.img_resolution_log2 * 2 - 2 + for res in self.block_resolutions: + if res // 2 in channels_dict.keys(): + in_channels = channels_dict[res // 2] if res > 4 else 0 + else: + in_channels = min(channel_base // (res // 2), channel_max) + out_channels = channels_dict[res] + use_fp16 = (res >= fp16_resolution) + use_fp16 = False + is_last = (res == self.img_resolution) + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) + setattr(self, f'b{res}', block) + + def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs): + + img = None + + x, img = self.foreword(x_global, ws, feats, img) + + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + mod_vector0 = [] + mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5]) + mod_vector0.append(x_global.clone()) + mod_vector0 = torch.cat(mod_vector0, dim=1) + + mod_vector1 = [] + mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4]) + mod_vector1.append(x_global.clone()) + mod_vector1 = torch.cat(mod_vector1, dim=1) + + mod_vector_rgb = [] + mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3]) + mod_vector_rgb.append(x_global.clone()) + mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1) + x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs) + return img + + +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. 
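# Worked example of the W-index bookkeeping in SynthesisNetwork above for
# img_resolution=512: num_ws = 2 * log2(resolution) - 2 = 16, and each block
# resolution reads three indices (conv0, conv1, torgb) from that range.
import numpy as np

img_resolution = 512
num_ws = int(np.log2(img_resolution)) * 2 - 2      # 16
for res in [8, 16, 512]:
    k = int(np.log2(res))
    print(res, (k * 2 - 5, k * 2 - 4, k * 2 - 3))  # (1, 2, 3), (3, 4, 5), (13, 14, 15)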
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + encoder_kwargs={}, # Arguments for EncoderNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. 
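# A small sketch of the truncation trick at the end of MappingNetwork.forward
# above: every w is pulled toward the tracked average w_avg by lerp. Shapes and
# psi are example values; w_avg is zero here, so the result is just psi * w.
import torch

w_avg = torch.zeros(512)
w = torch.randn(2, 16, 512)                        # [batch, num_ws, w_dim]
truncation_psi = 0.7
w_truncated = w_avg.lerp(w, truncation_psi)        # (1 - psi) * w_avg + psi * w
print(torch.allclose(w_truncated, truncation_psi * w))  # True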
+ ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution, + img_channels=img_channels, **encoder_kwargs) + self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution, + img_channels=img_channels, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): + mask = img[:, -1].unsqueeze(1) + x_global, z, feats = self.encoder(img, c) + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) + img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs) + return img + + +FCF_MODEL_URL = os.environ.get( + "FCF_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth", +) + + +class FcF(InpaintModel): + min_size = 512 + pad_mod = 512 + pad_to_square = True + + def init_model(self, device, **kwargs): + seed = 0 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256} + G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3, + synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2}) + self.model = load_model(G, FCF_MODEL_URL, device) + self.label = torch.zeros([1, self.model.c_dim], device=device) + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL)) + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + if image.shape[0] == 512 and image.shape[1] == 512: + return self._pad_forward(image, mask, config) + + boxes = boxes_from_mask(mask) + crop_result = [] + config.hd_strategy_crop_margin = 128 + for box in boxes: + crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config) + origin_size = crop_image.shape[:2] + resize_image = resize_max_size(crop_image, size_limit=512) + resize_mask = resize_max_size(crop_mask, size_limit=512) + inpaint_result = self._pad_forward(resize_image, resize_mask, config) + + # only paste masked area result + inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC) + + original_pixel_indices = crop_mask < 127 + inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices] + + crop_result.append((inpaint_result, crop_box)) + + inpaint_result = image[:, :, ::-1] + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + return inpaint_result + + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + mask = (mask > 120) * 255 + mask = norm_img(mask) + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = 
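# Minimal sketch of the "only paste masked area result" step in FcF.__call__
# above: wherever the crop mask is below 127 the original crop pixels are copied
# back (RGB flipped to BGR), so the model output only survives inside the mask.
import numpy as np

crop_image = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)       # RGB crop
inpaint_result = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)   # BGR model output
crop_mask = np.zeros((4, 4), dtype=np.uint8)
crop_mask[1:3, 1:3] = 255                                               # area to inpaint

keep = crop_mask < 127
inpaint_result[keep] = crop_image[:, :, ::-1][keep]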
torch.from_numpy(mask).unsqueeze(0).to(self.device) + + erased_img = image * (1 - mask) + input_image = torch.cat([0.5 - mask, erased_img], dim=1) + + output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none') + output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4ec535eb971d524c58d40af4fef53111cfedac --- /dev/null +++ b/lama_cleaner/model/lama.py @@ -0,0 +1,61 @@ +import os + +import cv2 +import numpy as np +import torch +from loguru import logger + +from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +LAMA_MODEL_URL = os.environ.get( + "LAMA_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", +) + + +class LaMa(InpaintModel): + pad_mod = 8 + + def init_model(self, device, **kwargs): + if os.environ.get("LAMA_MODEL"): + model_path = os.environ.get("LAMA_MODEL") + if not os.path.exists(model_path): + raise FileNotFoundError( + f"lama torchscript model not found: {model_path}" + ) + else: + model_path = download_model(LAMA_MODEL_URL) + # TODO used to create a lambda docker image + # model_path = '../app/big-lama.pt' + logger.info(f"Load LaMa model from: {model_path}") + model = torch.jit.load(model_path, map_location="cpu") + model = model.to(device) + model.eval() + self.model = model + self.model_path = model_path + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL)) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W] + return: BGR IMAGE + """ + image = norm_img(image) + mask = norm_img(mask) + + mask = (mask > 0) * 1 + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + inpainted_image = self.model(image, mask) + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + return cur_res \ No newline at end of file diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..a13dfb92e1ba507e700d77d3082d703c1e9502f4 --- /dev/null +++ b/lama_cleaner/model/ldm.py @@ -0,0 +1,310 @@ +import os + +import numpy as np +import torch + +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.model.ddim_sampler import DDIMSampler +from lama_cleaner.model.plms_sampler import PLMSSampler +from lama_cleaner.schema import Config, LDMSampler + +torch.manual_seed(42) +import torch.nn as nn +from lama_cleaner.helper import ( + norm_img, + get_cache_path_by_url, + load_jit_model, +) +from lama_cleaner.model.utils import ( + make_beta_schedule, + timestep_embedding, +) + +LDM_ENCODE_MODEL_URL = os.environ.get( + "LDM_ENCODE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt", +) + +LDM_DECODE_MODEL_URL = os.environ.get( + "LDM_DECODE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt", +) + +LDM_DIFFUSION_MODEL_URL = 
os.environ.get( + "LDM_DIFFUSION_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt", +) + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + device, + timesteps=1000, + beta_schedule="linear", + linear_start=0.0015, + linear_end=0.0205, + cosine_s=0.008, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + parameterization="eps", # all assuming fixed variance schedules + use_positional_encodings=False, + ): + super().__init__() + self.device = device + self.parameterization = parameterization + self.use_positional_encodings = use_positional_encodings + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + self.register_schedule( + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + self.device, + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. 
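# A numeric sketch of the schedule quantities registered above. The exact
# make_beta_schedule helper lives in lama_cleaner.model.utils and is assumed to
# follow the usual latent-diffusion convention of being linear in sqrt(beta).
import numpy as np

timesteps, linear_start, linear_end = 1000, 0.0015, 0.0205
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
print(betas[0], betas[-1])  # 0.0015 0.0205 (up to float rounding)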
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + +class LatentDiffusion(DDPM): + def __init__( + self, + diffusion_model, + device, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs, + ): + self.num_timesteps_cond = 1 + self.scale_by_std = scale_by_std + super().__init__(device, *args, **kwargs) + self.diffusion_model = diffusion_model + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.num_downs = 2 + self.scale_factor = scale_factor + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def apply_model(self, x_noisy, t, cond): + # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128 + t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False) + x_recon = self.diffusion_model(x_noisy, t_emb, cond) + return x_recon + + +class LDM(InpaintModel): + pad_mod = 32 + + def __init__(self, device, fp16: bool = True, **kwargs): + self.fp16 = fp16 + super().__init__(device) + self.device = device + + def init_model(self, device, **kwargs): + self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device) + self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device) + self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device) + if self.fp16 and "cuda" in str(device): + self.diffusion_model = self.diffusion_model.half() + self.cond_stage_model_decode = self.cond_stage_model_decode.half() + self.cond_stage_model_encode = self.cond_stage_model_encode.half() + + self.model = LatentDiffusion(self.diffusion_model, device) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + 
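# A minimal sketch of the sinusoidal embedding consumed by apply_model above.
# The real helper is lama_cleaner.model.utils.timestep_embedding; this version
# assumes the standard diffusion convention (half cos / half sin, base 10000).
import math
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 250, 999]), 256)
print(emb.shape)  # torch.Size([3, 256])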
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL), + get_cache_path_by_url(LDM_DECODE_MODEL_URL), + get_cache_path_by_url(LDM_ENCODE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + @torch.cuda.amp.autocast() + def forward(self, image, mask, config: Config): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + # image [1,3,512,512] float32 + # mask: [1,1,512,512] float32 + # masked_image: [1,3,512,512] float32 + if config.ldm_sampler == LDMSampler.ddim: + sampler = DDIMSampler(self.model) + elif config.ldm_sampler == LDMSampler.plms: + sampler = PLMSSampler(self.model) + else: + raise ValueError() + + steps = config.ldm_steps + image = norm_img(image) + mask = norm_img(mask) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + masked_image = (1 - mask) * image + + mask = self._norm(mask) + masked_image = self._norm(masked_image) + + c = self.cond_stage_model_encode(masked_image) + torch.cuda.empty_cache() + + cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128 + c = torch.cat((c, cc), dim=1) # 1,4,128,128 + + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim = sampler.sample( + steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape + ) + torch.cuda.empty_cache() + x_samples_ddim = self.cond_stage_model_decode( + samples_ddim + ) # samples_ddim: 1, 3, 128, 128 float32 + torch.cuda.empty_cache() + + # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0) + inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + # inpainted = (1 - mask) * image + mask * predicted_image + inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 + inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1] + return inpainted_image + + def _norm(self, tensor): + return tensor * 2.0 - 1.0 diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py new file mode 100644 index 0000000000000000000000000000000000000000..180350f0effabdd8a3dd05576724e6d43a1dfca0 --- /dev/null +++ b/lama_cleaner/model/manga.py @@ -0,0 +1,130 @@ +import os +import random +import time + +import cv2 +import numpy as np +import torch +from loguru import logger + +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +# def norm(np_img): +# return np_img / 255 * 2 - 1.0 +# +# +# @torch.no_grad() +# def run(): +# name = 'manga_1080x740.jpg' +# img_p = f'/Users/qing/code/github/MangaInpainting/examples/test/imgs/{name}' +# mask_p = f'/Users/qing/code/github/MangaInpainting/examples/test/masks/mask_{name}' +# erika_model = torch.jit.load('erika.jit') +# manga_inpaintor_model = torch.jit.load('manga_inpaintor.jit') +# +# img = cv2.imread(img_p) +# gray_img = cv2.imread(img_p, cv2.IMREAD_GRAYSCALE) +# mask = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE) +# +# kernel = np.ones((9, 9), dtype=np.uint8) +# mask = cv2.dilate(mask, kernel, 2) +# # cv2.imwrite("mask.jpg", mask) +# # cv2.imshow('dilated_mask', cv2.hconcat([mask, dilated_mask])) +# # cv2.waitKey(0) +# # exit() +# +# # img = pad(img) +# gray_img = pad(gray_img).astype(np.float32) +# mask = pad(mask) +# +# # pad_mod = 16 +# import time +# start = time.time() +# y = erika_model(torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :])) +# y = torch.clamp(y, 0, 
255) +# lines = y.cpu().numpy() +# print(f"erika_model time: {time.time() - start}") +# +# cv2.imwrite('lines.png', lines[0][0]) +# +# start = time.time() +# masks = torch.from_numpy(mask[np.newaxis, np.newaxis, :, :]) +# masks = torch.where(masks > 0.5, torch.tensor(1.0), torch.tensor(0.0)) +# noise = torch.randn_like(masks) +# +# images = torch.from_numpy(norm(gray_img)[np.newaxis, np.newaxis, :, :]) +# lines = torch.from_numpy(norm(lines)) +# +# outputs = manga_inpaintor_model(images, lines, masks, noise) +# print(f"manga_inpaintor_model time: {time.time() - start}") +# +# outputs_merged = (outputs * masks) + (images * (1 - masks)) +# outputs_merged = outputs_merged * 127.5 + 127.5 +# outputs_merged = outputs_merged.permute(0, 2, 3, 1)[0].detach().cpu().numpy().astype(np.uint8) +# cv2.imwrite(f'output_{name}', outputs_merged) + + +MANGA_INPAINTOR_MODEL_URL = os.environ.get( + "MANGA_INPAINTOR_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit" +) +MANGA_LINE_MODEL_URL = os.environ.get( + "MANGA_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/erika.jit" +) + + +class Manga(InpaintModel): + pad_mod = 16 + + def init_model(self, device, **kwargs): + self.inpaintor_model = load_jit_model(MANGA_INPAINTOR_MODEL_URL, device) + self.line_model = load_jit_model(MANGA_LINE_MODEL_URL, device) + self.seed = 42 + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL), + get_cache_path_by_url(MANGA_LINE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def forward(self, image, mask, config: Config): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + seed = self.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + gray_img = torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)).to(self.device) + start = time.time() + lines = self.line_model(gray_img) + torch.cuda.empty_cache() + lines = torch.clamp(lines, 0, 255) + logger.info(f"erika_model time: {time.time() - start}") + + mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device) + mask = mask.permute(0, 3, 1, 2) + mask = torch.where(mask > 0.5, 1.0, 0.0) + noise = torch.randn_like(mask) + ones = torch.ones_like(mask) + + gray_img = gray_img / 255 * 2 - 1.0 + lines = lines / 255 * 2 - 1.0 + + start = time.time() + inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones) + logger.info(f"image_inpaintor_model time: {time.time() - start}") + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8) + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR) + return cur_res diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py new file mode 100644 index 0000000000000000000000000000000000000000..67020bcad43cc7063f409fd5b14ca920c3b2a9cb --- /dev/null +++ b/lama_cleaner/model/mat.py @@ -0,0 +1,1444 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \ + upsample2d, activation_funcs, 
MinibatchStdLayer, to_2tuple, normalize_2nd_moment +from lama_cleaner.schema import Config + + +class ModulatedConv2d(nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + style_dim, # dimension of the style code + demodulate=True, # perfrom demodulation + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + ): + super().__init__() + self.demodulate = demodulate + + self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])) + self.out_channels = out_channels + self.kernel_size = kernel_size + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.padding = self.kernel_size // 2 + self.up = up + self.down = down + self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + + self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1) + + def forward(self, x, style): + batch, in_channels, height, width = x.shape + style = self.affine(style).view(batch, 1, in_channels, 1, 1) + weight = self.weight * self.weight_gain * style + + if self.demodulate: + decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() + weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) + + weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size) + x = x.view(1, batch * in_channels, height, width) + x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down, + padding=self.padding, groups=batch) + out = x.view(batch, self.out_channels, *x.shape[2:]) + + return out + + +class StyleConv(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + style_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=False, # Enable noise input? + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
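# Illustrative standalone sketch (arbitrary example shapes, not part of the patch
# itself): ModulatedConv2d above follows the StyleGAN2 modulate/demodulate scheme --
# a per-sample style scales the input channels of the shared weight, then each
# output filter is rescaled back to roughly unit L2 norm.
import torch

batch, cin, cout, k = 2, 4, 8, 3
weight = torch.randn(1, cout, cin, k, k)                # shared base weight
style = torch.randn(batch, 1, cin, 1, 1)                # per-sample input-channel scales
w = weight * style                                      # modulate -> [batch, cout, cin, k, k]
decoefs = (w.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [batch, cout]
w = w * decoefs.view(batch, cout, 1, 1, 1)              # demodulate
print(w.pow(2).sum(dim=[2, 3, 4]))                      # ~1.0 for every (sample, filter)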
+ demodulate=True, # perform demodulation + ): + super().__init__() + + self.conv = ModulatedConv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + up=up, + resample_filter=resample_filter, + conv_clamp=conv_clamp) + + self.use_noise = use_noise + self.resolution = resolution + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.activation = activation + self.act_gain = activation_funcs[activation].def_gain + self.conv_clamp = conv_clamp + + def forward(self, x, style, noise_mode='random', gain=1): + x = self.conv(x, style) + + assert noise_mode in ['random', 'const', 'none'] + + if self.use_noise: + if noise_mode == 'random': + xh, xw = x.size()[-2:] + noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \ + * self.noise_strength + if noise_mode == 'const': + noise = self.noise_const * self.noise_strength + x = x + noise + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) + + return out + + +class ToRGB(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + style_dim, + kernel_size=1, + resample_filter=[1, 3, 3, 1], + conv_clamp=None, + demodulate=False): + super().__init__() + + self.conv = ModulatedConv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + resample_filter=resample_filter, + conv_clamp=conv_clamp) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + + def forward(self, x, style, skip=None): + x = self.conv(x, style) + out = bias_act(x, self.bias, clamp=self.conv_clamp) + + if skip is not None: + if skip.shape != out.shape: + skip = upsample2d(skip, self.resample_filter) + out = out + skip + + return out + + +def get_style_code(a, b): + return torch.cat([a, b], dim=1) + + +class DecBlockFirst(nn.Module): + def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + super().__init__() + self.fc = FullyConnectedLayer(in_features=in_channels * 2, + out_features=in_channels * 4 ** 2, + activation=activation) + self.conv = StyleConv(in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB(in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode='random'): + x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = x + E_features[2] + style = get_style_code(ws[:, 0], gs) + x = self.conv(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlockFirstV2(nn.Module): + def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + super().__init__() + self.conv0 = Conv2dLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv(in_channels=in_channels, + 
out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB(in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode='random'): + # x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = self.conv0(x) + x = x + E_features[2] + style = get_style_code(ws[:, 0], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlock(nn.Module): + def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, + img_channels): # res = 2, ..., resolution_log2 + super().__init__() + self.res = res + + self.conv0 = StyleConv(in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv(in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB(in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, ws, gs, E_features, noise_mode='random'): + style = get_style_code(ws[:, self.res * 2 - 5], gs) + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + E_features[self.res] + style = get_style_code(ws[:, self.res * 2 - 4], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, self.res * 2 - 3], gs) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class MappingNet(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. 
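# Illustrative sketch (arbitrary values) of two things MappingNet.forward does below:
# the single mapped W code is broadcast to num_ws copies, and the truncation trick
# pulls each code toward the tracked average via lerp.
import torch

w = torch.tensor([[2.0, -2.0, 4.0, 0.0]])      # one mapped code, w_dim = 4
num_ws = 3
ws = w.unsqueeze(1).repeat([1, num_ws, 1])     # broadcast -> shape [1, 3, 4]

w_avg = torch.zeros(4)                         # tracked moving average of W
for psi in (1.0, 0.7, 0.0):
    print(psi, w_avg.lerp(ws, psi))            # psi=1 keeps ws, psi=0 collapses to w_avg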
+ ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + + return x + + +class DisFromRGB(nn.Module): + def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 + super().__init__() + self.conv = Conv2dLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + + def forward(self, x): + return self.conv(x) + + +class DisBlock(nn.Module): + def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 + super().__init__() + self.conv0 = Conv2dLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = Conv2dLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + down=2, + activation=activation, + ) + self.skip = Conv2dLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + down=2, + bias=False, + ) + + def forward(self, x): + skip = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + out = skip + x + + return out + + +class Discriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. 
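# Illustrative sketch: the channel widths used by the discriminator blocks below come
# from nf(stage) = clip(channel_base / 2**(stage * channel_decay), 1, channel_max).
# With the defaults (32768, 1, 512) this reproduces the NF table hard-coded in the
# module-level nf() helper further down (nf_sketch here is a re-derivation, not that
# helper).
import numpy as np

def nf_sketch(stage, channel_base=32768, channel_decay=1, channel_max=512):
    return int(np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max))

for res in (4, 8, 16, 32, 64, 128, 256, 512):
    print(res, nf_sketch(int(np.log2(res))))   # 512 512 512 512 512 256 128 64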
+ channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation='lrelu', + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + self.resolution_log2 = resolution_log2 + + def nf(stage): + return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max) + + if cmap_dim == None: + cmap_dim = nf(2) + if c_dim == 0: + cmap_dim = 0 + self.cmap_dim = cmap_dim + + if c_dim > 0: + self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None) + + Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] + for res in range(resolution_log2, 2, -1): + Dis.append(DisBlock(nf(res), nf(res - 1), activation)) + + if mbstd_num_channels > 0: + Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) + Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation)) + self.Dis = nn.Sequential(*Dis) + + self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, images_in, masks_in, c): + x = torch.cat([masks_in - 0.5, images_in], dim=1) + x = self.Dis(x) + x = self.fc1(self.fc0(x.flatten(start_dim=1))) + + if self.c_dim > 0: + cmap = self.mapping(None, c) + + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + return x + + +def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): + NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} + return NF[2 ** stage] + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu') + self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + # B = windows.shape[0] / (H * W / window_size / window_size) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class Conv2dLayerPartial(nn.Module): + def 
__init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + trainable=True, # Update the weights of this layer during training? + ): + super().__init__() + self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter, + conv_clamp, trainable) + + self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) + self.slide_winsize = kernel_size ** 2 + self.stride = down + self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 + + def forward(self, x, mask=None): + if mask is not None: + with torch.no_grad(): + if self.weight_maskUpdater.type() != x.type(): + self.weight_maskUpdater = self.weight_maskUpdater.to(x) + update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, + padding=self.padding) + mask_ratio = self.slide_winsize / (update_mask + 1e-8) + update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1 + mask_ratio = torch.mul(mask_ratio, update_mask) + x = self.conv(x) + x = torch.mul(x, mask_ratio) + return x, update_mask + else: + x = self.conv(x) + return x, None + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = FullyConnectedLayer(in_features=dim, out_features=dim) + self.k = FullyConnectedLayer(in_features=dim, out_features=dim) + self.v = FullyConnectedLayer(in_features=dim, out_features=dim) + self.proj = FullyConnectedLayer(in_features=dim, out_features=dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask_windows=None, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + norm_x = F.normalize(x, p=2.0, dim=-1) + q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1) + v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k) * self.scale + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + if mask_windows is not None: + attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1) + attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill( + attn_mask_windows == 1, float(0.0)) + with torch.no_grad(): + mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1) + + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x, mask_windows + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + if self.shift_size > 0: + down_ratio = 1 + self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop) + + self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu') + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size, mask=None): + # H, W = self.input_resolution + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + if mask is not None: + mask = mask.view(B, H, W, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + if mask is not None: + shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + if mask is not None: + shifted_mask = mask + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + if mask is not None: + mask_windows = window_partition(shifted_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1) + else: + mask_windows = None + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows, mask_windows = self.attn(x_windows, mask_windows, + mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows, mask_windows = 
self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to( + x.device)) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + if mask is not None: + mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1) + shifted_mask = window_reverse(mask_windows, self.window_size, H, W) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + if mask is not None: + mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + if mask is not None: + mask = shifted_mask + x = x.view(B, H * W, C) + if mask is not None: + mask = mask.view(B, H * W, 1) + + # FFN + x = self.fuse(torch.cat([shortcut, x], dim=-1)) + x = self.mlp(x) + + return x, mask + + +class PatchMerging(nn.Module): + def __init__(self, in_channels, out_channels, down=2): + super().__init__() + self.conv = Conv2dLayerPartial(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation='lrelu', + down=down, + ) + self.down = down + + def forward(self, x, x_size, mask=None): + x = token2feature(x, x_size) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(x, mask) + if self.down != 1: + ratio = 1 / self.down + x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio)) + x = feature2token(x) + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class PatchUpsampling(nn.Module): + def __init__(self, in_channels, out_channels, up=2): + super().__init__() + self.conv = Conv2dLayerPartial(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation='lrelu', + up=up, + ) + self.up = up + + def forward(self, x, x_size, mask=None): + x = token2feature(x, x_size) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(x, mask) + if self.up != 1: + x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up)) + x = feature2token(x) + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1, + mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # patch merging layer + if downsample is not None: + # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample + else: + self.downsample = None + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, down_ratio=down_ratio, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu') + + def forward(self, x, x_size, mask=None): + if self.downsample is not None: + x, x_size, mask = self.downsample(x, x_size, mask) + identity = x + for blk in self.blocks: + if self.use_checkpoint: + x, mask = checkpoint.checkpoint(blk, x, x_size, mask) + else: + x, mask = blk(x, x_size, mask) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(token2feature(x, x_size), mask) + x = feature2token(x) + identity + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class ToToken(nn.Module): + def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1): + super().__init__() + + self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size, + activation='lrelu') + + def forward(self, x, mask): + x, mask = self.proj(x, mask) + + return x, mask + + +class EncFromRGB(nn.Module): + def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 + super().__init__() + self.conv0 = Conv2dLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + self.conv1 = Conv2dLayer(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) + + def forward(self, x): + x = self.conv0(x) + x = self.conv1(x) + + return x + + +class ConvBlockDown(nn.Module): + def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log + super().__init__() + + self.conv0 = Conv2dLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + down=2, + ) + self.conv1 = Conv2dLayer(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) + + def forward(self, x): + x = self.conv0(x) + x = self.conv1(x) + + return x + + +def token2feature(x, x_size): + B, N, C = x.shape + h, w = x_size + x = x.permute(0, 2, 1).reshape(B, C, h, w) + return x + + +def feature2token(x): + B, C, H, W = x.shape + x = x.view(B, C, -1).transpose(1, 2) + return x + + +class Encoder(nn.Module): + def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1): + super().__init__() + + self.resolution = [] + + for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16 + res = 2 ** i + self.resolution.append(res) + if i == res_log2: + block = 
EncFromRGB(img_channels * 2 + 1, nf(i), activation) + else: + block = ConvBlockDown(nf(i + 1), nf(i), activation) + setattr(self, 'EncConv_Block_%dx%d' % (res, res), block) + + def forward(self, x): + out = {} + for res in self.resolution: + res_log2 = int(np.log2(res)) + x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x) + out[res_log2] = x + + return out + + +class ToStyle(nn.Module): + def __init__(self, in_channels, out_channels, activation, drop_rate): + super().__init__() + self.conv = nn.Sequential( + Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, + down=2), + Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, + down=2), + Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, + down=2), + ) + + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = FullyConnectedLayer(in_features=in_channels, + out_features=out_channels, + activation=activation) + # self.dropout = nn.Dropout(drop_rate) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + x = self.fc(x.flatten(start_dim=1)) + # x = self.dropout(x) + + return x + + +class DecBlockFirstV2(nn.Module): + def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + super().__init__() + self.res = res + + self.conv0 = Conv2dLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv(in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB(in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode='random'): + # x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = self.conv0(x) + x = x + E_features[self.res] + style = get_style_code(ws[:, 0], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlock(nn.Module): + def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, + img_channels): # res = 4, ..., resolution_log2 + super().__init__() + self.res = res + + self.conv0 = StyleConv(in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv(in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB(in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, ws, gs, E_features, noise_mode='random'): + style = get_style_code(ws[:, self.res * 2 - 9], gs) + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + E_features[self.res] + style = get_style_code(ws[:, self.res * 2 - 8], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, self.res * 2 - 7], gs) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class Decoder(nn.Module): + def __init__(self, res_log2, activation, style_dim, use_noise, 
demodulate, img_channels): + super().__init__() + self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels) + for res in range(5, res_log2 + 1): + setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res), + DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels)) + self.res_log2 = res_log2 + + def forward(self, x, ws, gs, E_features, noise_mode='random'): + x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) + for res in range(5, self.res_log2 + 1): + block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res)) + x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) + + return img + + +class DecStyleBlock(nn.Module): + def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + super().__init__() + self.res = res + + self.conv0 = StyleConv(in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv(in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB(in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, style, skip, noise_mode='random'): + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + skip + x = self.conv1(x, style, noise_mode=noise_mode) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class FirstStage(nn.Module): + def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True, + activation='lrelu'): + super().__init__() + res = 64 + + self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3, + activation=activation) + self.enc_conv = nn.ModuleList() + down_time = int(np.log2(img_resolution // res)) + # 根据图片尺寸构建 swim transformer 的层数 + for i in range(down_time): # from input size to 64 + self.enc_conv.append( + Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation) + ) + + # from 64 -> 16 -> 64 + depths = [2, 3, 4, 3, 2] + ratios = [1, 1 / 2, 1 / 2, 2, 2] + num_heads = 6 + window_sizes = [8, 16, 16, 16, 8] + drop_path_rate = 0.1 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.tran = nn.ModuleList() + for i, depth in enumerate(depths): + res = int(res * ratios[i]) + if ratios[i] < 1: + merge = PatchMerging(dim, dim, down=int(1 / ratios[i])) + elif ratios[i] > 1: + merge = PatchUpsampling(dim, dim, up=ratios[i]) + else: + merge = None + self.tran.append( + BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads, + window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=merge) + ) + + # global style + down_conv = [] + for i in range(int(np.log2(16))): + down_conv.append( + Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)) + down_conv.append(nn.AdaptiveAvgPool2d((1, 1))) + self.down_conv = nn.Sequential(*down_conv) + self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation) + self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation) + 
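# Illustrative sketch of the resolution schedule built in FirstStage.__init__ above
# (the Chinese comment there says the swin-transformer stages are derived from the
# image size). With the default img_resolution=256 the convolutional encoder halves
# the input int(log2(256 // 64)) = 2 times down to 64x64, and the five transformer
# stages then follow ratios [1, 1/2, 1/2, 2, 2] -- the "from 64 -> 16 -> 64" comment.
res = 64
ratios = [1, 1 / 2, 1 / 2, 2, 2]
stages = []
for r in ratios:
    res = int(res * r)
    stages.append(res)
print(stages)   # [64, 32, 16, 32, 64]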
self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation) + + style_dim = dim * 3 + self.dec_conv = nn.ModuleList() + for i in range(down_time): # from 64 to input size + res = res * 2 + self.dec_conv.append( + DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels)) + + def forward(self, images_in, masks_in, ws, noise_mode='random'): + x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1) + + skips = [] + x, mask = self.conv_first(x, masks_in) # input size + skips.append(x) + for i, block in enumerate(self.enc_conv): # input size to 64 + x, mask = block(x, mask) + if i != len(self.enc_conv) - 1: + skips.append(x) + + x_size = x.size()[-2:] + x = feature2token(x) + mask = feature2token(mask) + mid = len(self.tran) // 2 + for i, block in enumerate(self.tran): # 64 to 16 + if i < mid: + x, x_size, mask = block(x, x_size, mask) + skips.append(x) + elif i > mid: + x, x_size, mask = block(x, x_size, None) + x = x + skips[mid - i] + else: + x, x_size, mask = block(x, x_size, None) + + mul_map = torch.ones_like(x) * 0.5 + mul_map = F.dropout(mul_map, training=True) + ws = self.ws_style(ws[:, -1]) + add_n = self.to_square(ws).unsqueeze(1) + add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze( + -1) + x = x * mul_map + add_n * (1 - mul_map) + gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)) + style = torch.cat([gs, ws], dim=1) + + x = token2feature(x, x_size).contiguous() + img = None + for i, block in enumerate(self.dec_conv): + x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode) + + # ensemble + img = img * (1 - masks_in) + images_in * masks_in + + return img + + +class SynthesisNet(nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels=3, # Number of color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_decay=1.0, + channel_max=512, # Maximum number of channels in any layer. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. 
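# Illustrative sketch of the mul_map trick used in FirstStage.forward above:
# F.dropout(p=0.5) applied to a constant 0.5 map yields exactly 0.0 or
# 0.5 / (1 - 0.5) = 1.0 per element, i.e. a random binary mask that swaps individual
# features for the style-derived noise add_n. SynthesisNet.forward below repeats the
# same trick on the 16x16 encoder feature. Shapes here are arbitrary.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
mul_map = torch.ones_like(x) * 0.5
mul_map = F.dropout(mul_map, training=True)
print(mul_map.unique())                      # typically tensor([0., 1.])
add_n = torch.randn_like(x)
mixed = x * mul_map + add_n * (1 - mul_map)  # per-element choice between x and add_n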
+ drop_rate=0.5, + use_noise=False, + demodulate=True, + ): + super().__init__() + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + + self.num_layers = resolution_log2 * 2 - 3 * 2 + self.img_resolution = img_resolution + self.resolution_log2 = resolution_log2 + + # first stage + self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False, + demodulate=demodulate) + + # second stage + self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16) + self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation) + self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate) + style_dim = w_dim + nf(2) * 2 + self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels) + + def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False): + out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode) + + # encoder + x = images_in * masks_in + out_stg1 * (1 - masks_in) + x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1) + E_features = self.enc(x) + + fea_16 = E_features[4] + mul_map = torch.ones_like(fea_16) * 0.5 + mul_map = F.dropout(mul_map, training=True) + add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1) + add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False) + fea_16 = fea_16 * mul_map + add_n * (1 - mul_map) + E_features[4] = fea_16 + + # style + gs = self.to_style(fea_16) + + # decoder + img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode) + + # ensemble + img = img * (1 - masks_in) + images_in * masks_in + + if not return_stg1: + return img + else: + return img, out_stg1 + + +class Generator(nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # resolution of generated image + img_channels, # Number of input color channels. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + self.synthesis = SynthesisNet(w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs) + self.mapping = MappingNet(z_dim=z_dim, + c_dim=c_dim, + w_dim=w_dim, + num_ws=self.synthesis.num_layers, + **mapping_kwargs) + + def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, + noise_mode='none', return_stg1=False): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, + skip_w_avg_update=skip_w_avg_update) + img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode) + return img + + +class Discriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. 
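# Illustrative sketch (toy arrays) of the mask convention used by SynthesisNet.forward
# above: masks_in == 1 keeps the original pixel and 0 marks the hole, so generator
# output is only pasted into the hole region. MAT.forward further down inverts the
# incoming 255-valued user mask (mask = 255 - mask) to match this convention.
import numpy as np

images_in = np.full((1, 1, 4, 4), 0.7, dtype=np.float32)
masks_in = np.ones_like(images_in)
masks_in[..., 1:3, 1:3] = 0                        # a 2x2 hole
generated = np.zeros_like(images_in)               # stand-in for the network output
composed = generated * (1 - masks_in) + images_in * masks_in
print(composed[0, 0])                              # 0.0 inside the hole, 0.7 elsewhere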
+ activation='lrelu', + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + self.resolution_log2 = resolution_log2 + + if cmap_dim == None: + cmap_dim = nf(2) + if c_dim == 0: + cmap_dim = 0 + self.cmap_dim = cmap_dim + + if c_dim > 0: + self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None) + + Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] + for res in range(resolution_log2, 2, -1): + Dis.append(DisBlock(nf(res), nf(res - 1), activation)) + + if mbstd_num_channels > 0: + Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) + Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation)) + self.Dis = nn.Sequential(*Dis) + + self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) + + # for 64x64 + Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)] + for res in range(resolution_log2, 2, -1): + Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation)) + + if mbstd_num_channels > 0: + Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) + Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation)) + self.Dis_stg1 = nn.Sequential(*Dis_stg1) + + self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation) + self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, images_in, masks_in, images_stg1, c): + x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1)) + x = self.fc1(self.fc0(x.flatten(start_dim=1))) + + x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1)) + x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1))) + + if self.c_dim > 0: + cmap = self.mapping(None, c) + + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + return x, x_stg1 + + +MAT_MODEL_URL = os.environ.get( + "MAT_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth", +) + + +class MAT(InpaintModel): + min_size = 512 + pad_mod = 512 + pad_to_square = True + + def init_model(self, device, **kwargs): + seed = 240 # pick up a random number + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3) + self.model = load_model(G, MAT_MODEL_URL, device) + self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512] + self.label = torch.zeros([1, self.model.c_dim], device=device) + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL)) + + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = 
norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + + mask = (mask > 127) * 255 + mask = 255 - mask + mask = norm_img(mask) + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none') + output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/lama_cleaner/model/opencv2.py b/lama_cleaner/model/opencv2.py new file mode 100644 index 0000000000000000000000000000000000000000..9a887763fd56494c765fb0f9545dd52f5a385a82 --- /dev/null +++ b/lama_cleaner/model/opencv2.py @@ -0,0 +1,25 @@ +import cv2 + +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +flag_map = { + "INPAINT_NS": cv2.INPAINT_NS, + "INPAINT_TELEA": cv2.INPAINT_TELEA +} + +class OpenCV2(InpaintModel): + pad_mod = 1 + + @staticmethod + def is_downloaded() -> bool: + return True + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag]) + return cur_res diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..b946a08a593b59fa9b177c15e2f97b444af0f967 --- /dev/null +++ b/lama_cleaner/model/paint_by_example.py @@ -0,0 +1,122 @@ +import random + +import PIL +import PIL.Image +import cv2 +import numpy as np +import torch +from diffusers import DiffusionPipeline +from loguru import logger + +from lama_cleaner.helper import resize_max_size +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + + +class PaintByExample(InpaintModel): + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + fp16 = not kwargs.get('no_half', False) + use_gpu = device == torch.device('cuda') and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} + + if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + logger.info("Disable Paint By Example Model NSFW checker") + model_kwargs.update(dict( + safety_checker=None, + requires_safety_checker=False + )) + + self.model = DiffusionPipeline.from_pretrained( + "Fantasy-Studio/Paint-by-Example", + torch_dtype=torch_dtype, + **model_kwargs + ) + + self.model.enable_attention_slicing() + if kwargs.get('enable_xformers', False): + self.model.enable_xformers_memory_efficient_attention() + + # TODO: gpu_id + if kwargs.get('cpu_offload', False) and use_gpu: + self.model.image_encoder = self.model.image_encoder.to(device) + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + seed = config.paint_by_example_seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + output = self.model( + image=PIL.Image.fromarray(image), + mask_image=PIL.Image.fromarray(mask[:, :, 
-1], mode="L"), + example_image=config.paint_by_example_example_image, + num_inference_steps=config.paint_by_example_steps, + output_type='np.array', + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def _scaled_pad_forward(self, image, mask, config: Config): + longer_side_length = int(config.sd_scale * max(image.shape[:2])) + origin_size = image.shape[:2] + downsize_image = resize_max_size(image, size_limit=longer_side_length) + downsize_mask = resize_max_size(mask, size_limit=longer_side_length) + logger.info( + f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}" + ) + inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + original_pixel_indices = mask < 127 + inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices] + return inpaint_result + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + if config.use_croper: + crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) + crop_image = self._scaled_pad_forward(crop_img, crop_mask, config) + inpaint_result = image[:, :, ::-1] + inpaint_result[t:b, l:r, :] = crop_image + else: + inpaint_result = self._scaled_pad_forward(image, mask, config) + + return inpaint_result + + def forward_post_process(self, result, image, mask, config): + if config.paint_by_example_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.paint_by_example_mask_blur != 0: + k = 2 * config.paint_by_example_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model/plms_sampler.py b/lama_cleaner/model/plms_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2e822abfe0ddd082b8ed4dd64de36774efdbb2 --- /dev/null +++ b/lama_cleaner/model/plms_sampler.py @@ -0,0 +1,226 @@ +# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +import numpy as np +import torch +from tqdm import tqdm + +from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', 
to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + steps, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
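+ # (when supplied together with unconditional_guidance_scale > 1, p_sample_plms
+ # evaluates the model with and without conditioning and blends the two noise
+ # predictions, i.e. classifier-free guidance)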
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + return img + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py new file mode 100644 index 0000000000000000000000000000000000000000..08585bfe556e105e60ec93c6648253934803a2cc --- /dev/null +++ b/lama_cleaner/model/sd.py @@ -0,0 +1,188 @@ +import random + +import PIL.Image +import cv2 +import numpy as np +import torch +from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \ + EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler +from loguru import logger + +from lama_cleaner.helper import resize_max_size +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.model.utils import torch_gc +from lama_cleaner.schema import Config, SDSampler + + +class CPUTextEncoderWrapper: + def __init__(self, text_encoder, torch_dtype): + self.config = text_encoder.config + self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True) + self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) + self.torch_dtype = torch_dtype + del text_encoder + torch_gc() + + def __call__(self, x, **kwargs): + input_device = x.device + return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)] + + +class SD(InpaintModel): + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline + fp16 = not kwargs.get('no_half', False) + + model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])} + if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update(dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False + )) + + use_gpu = device == torch.device('cuda') and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + self.model = StableDiffusionInpaintPipeline.from_pretrained( + self.model_id_or_path, + revision="fp16" if use_gpu and fp16 else "main", + torch_dtype=torch_dtype, + use_auth_token=kwargs["hf_access_token"], + **model_kwargs + ) + + # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing + self.model.enable_attention_slicing() + # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention + if kwargs.get('enable_xformers', False): + 
self.model.enable_xformers_memory_efficient_attention() + + if kwargs.get('cpu_offload', False) and use_gpu: + # TODO: gpu_id + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs['sd_cpu_textencoder']: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype) + + self.callback = kwargs.pop("callback", None) + + def _scaled_pad_forward(self, image, mask, config: Config): + longer_side_length = int(config.sd_scale * max(image.shape[:2])) + origin_size = image.shape[:2] + downsize_image = resize_max_size(image, size_limit=longer_side_length) + downsize_mask = resize_max_size(mask, size_limit=longer_side_length) + logger.info( + f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}" + ) + inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + original_pixel_indices = mask < 127 + inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices] + return inpaint_result + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + + scheduler_config = self.model.scheduler.config + + if config.sd_sampler == SDSampler.ddim: + scheduler = DDIMScheduler.from_config(scheduler_config) + elif config.sd_sampler == SDSampler.pndm: + scheduler = PNDMScheduler.from_config(scheduler_config) + elif config.sd_sampler == SDSampler.k_lms: + scheduler = LMSDiscreteScheduler.from_config(scheduler_config) + elif config.sd_sampler == SDSampler.k_euler: + scheduler = EulerDiscreteScheduler.from_config(scheduler_config) + elif config.sd_sampler == SDSampler.k_euler_a: + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) + elif config.sd_sampler == SDSampler.dpm_plus_plus: + scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) + else: + raise ValueError(config.sd_sampler) + + self.model.scheduler = scheduler + + seed = config.sd_seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] + + img_h, img_w = image.shape[:2] + + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np.array", + callback=self.callback, + height=img_h, + width=img_w, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + # boxes = boxes_from_mask(mask) + if config.use_croper: + crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) + crop_image = self._scaled_pad_forward(crop_img, crop_mask, config) + inpaint_result = image[:, :, ::-1] + inpaint_result[t:b, l:r, :] = crop_image + else: + inpaint_result = 
self._scaled_pad_forward(image, mask, config) + + return inpaint_result + + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True + + +class SD15(SD): + model_id_or_path = "runwayml/stable-diffusion-inpainting" + + +class SD2(SD): + model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f2d51163b1974b834c567522263441687d141c --- /dev/null +++ b/lama_cleaner/model/utils.py @@ -0,0 +1,714 @@ +import collections +import math +from itertools import repeat +from typing import Any + +import numpy as np +import torch +from torch import conv2d, conv_transpose2d + + +def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2).to(device) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def noise_like(shape, device, 
repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=device) + + args = timesteps[:, None].float() * freqs[None] + + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +###### MAT and FcF ####### + + +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 
+ See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + + +def _get_filter_size(f): + if f is None: + return 1, 1 + + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + + fw = int(fw) + fh = int(fh) + assert fw >= 1 and fh >= 1 + return fw, fh + + +def _get_weight_shape(w): + shape = [int(sz) for sz in w.shape] + return shape + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. 
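+ # Normalization divides by the tap sum so a constant (DC) input keeps its
+ # magnitude. The gain is applied as gain ** (ndim / 2): a separable (1D)
+ # filter applied once per axis accumulates the full gain, while a
+ # non-separable (2D) filter receives it in a single pass.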
+ if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, + ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, + def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', + has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', + has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', + has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', + has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, + ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', + has_2nd_grad=True), +} + + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
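+ Example (illustrative sketch; this trimmed copy always uses the reference
+ implementation and expects `padding` as a 4-element list):
+ >>> x = torch.randn(1, 3, 64, 64)
+ >>> f = setup_filter([1, 3, 3, 1])
+ >>> y = upfirdn2d(x, f, up=2, padding=[1, 1, 1, 1]) # 2x upsample + low-pass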
+ """ + # assert isinstance(x, torch.Tensor) + # assert impl in ['ref', 'cuda'] + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + # upx, upy = _parse_scaling(up) + # downx, downy = _parse_scaling(down) + + upx, upy = up, up + downx, downy = down, down + + # padx0, padx1, pady0, pady1 = _parse_padding(padding) + padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + downx, downy = _parse_scaling(down) + # padx0, padx1, pady0, pady1 = _parse_padding(padding) + padx0, padx1, pady0, pady1 = padding, padding, padding, padding + + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + # upx, upy = up, up + padx0, padx1, pady0, pady1 = _parse_padding(padding) + # padx0, padx1, pady0, pady1 = padding, padding, padding, padding + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) + + +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + G = torch.min(torch.as_tensor(self.group_size), + torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + y = x.reshape(G, -1, F, c, H, + W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + + +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=1, # Learning rate multiplier. + bias_init=0, # Initial value for the additive bias. 
+ ): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.activation = activation + + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight * self.weight_gain + b = self.bias + if b is not None and self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + # out = torch.addmm(b.unsqueeze(0), x, w.t()) + x = x.matmul(w.t()) + out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]) + else: + x = x.matmul(w.t()) + out = bias_act(x, b, act=self.activation, dim=x.ndim - 1) + return out + + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv_transpose2d if transpose else conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + + +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. 
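+ # Note: this trimmed copy expands `padding` as a single int on all four
+ # sides; the list/tuple forms described in the docstring are not parsed here.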
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + # px0, px1, py0, py1 = _parse_padding(padding) + px0, px1, py0, py1 = padding, padding, padding, padding + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True, + flip_weight=(not flip_weight)) + x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2, + flip_filter=flip_filter) + if down > 1: + x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2, + flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. 
+ bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + channels_last=False, # Expect the input to have memory_format=channels_last? + trainable=True, # Update the weights of this layer during training? + ): + super().__init__() + self.activation = activation + self.up = up + self.down = down + self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down, + padding=self.padding) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) + return out + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py new file mode 100644 index 0000000000000000000000000000000000000000..0360669ee4af51feb8d1c1d1f4b8111ee8ede48b --- /dev/null +++ b/lama_cleaner/model/zits.py @@ -0,0 +1,427 @@ +import os +import time + +import cv2 +import numpy as np +import skimage +import torch +import torch.nn.functional as F +from skimage import color, feature + +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +ZITS_INPAINT_MODEL_URL = os.environ.get( + "ZITS_INPAINT_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt", +) + +ZITS_EDGE_LINE_MODEL_URL = os.environ.get( + "ZITS_EDGE_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt", +) + +ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get( + "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt", +) + +ZITS_WIRE_FRAME_MODEL_URL = os.environ.get( + "ZITS_WIRE_FRAME_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt", +) + + +def resize(img, height, width, center_crop=False): + imgh, imgw = img.shape[0:2] + + if center_crop and imgh != imgw: + # center crop + side = np.minimum(imgh, imgw) + j = (imgh - side) // 2 + i = (imgw - side) // 2 + img = img[j : j + side, i : i + side, ...] 
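+ # INTER_AREA is used when shrinking and INTER_LINEAR when enlarging. Note
+ # that cv2.resize expects (width, height); the callers in this file only pass
+ # square sizes, so the argument order does not matter here.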
+ + if imgh > height and imgw > width: + inter = cv2.INTER_AREA + else: + inter = cv2.INTER_LINEAR + img = cv2.resize(img, (height, width), interpolation=inter) + + return img + + +def to_tensor(img, scale=True, norm=False): + if img.ndim == 2: + img = img[:, :, np.newaxis] + c = img.shape[-1] + + if scale: + img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255) + else: + img_t = torch.from_numpy(img).permute(2, 0, 1).float() + + if norm: + mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) + std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) + img_t = (img_t - mean) / std + return img_t + + +def load_masked_position_encoding(mask): + ones_filter = np.ones((3, 3), dtype=np.float32) + d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32) + d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32) + d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32) + d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32) + str_size = 256 + pos_num = 128 + + ori_mask = mask.copy() + ori_h, ori_w = ori_mask.shape[0:2] + ori_mask = ori_mask / 255 + mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA) + mask[mask > 0] = 255 + h, w = mask.shape[0:2] + mask3 = mask.copy() + mask3 = 1.0 - (mask3 / 255.0) + pos = np.zeros((h, w), dtype=np.int32) + direct = np.zeros((h, w, 4), dtype=np.int32) + i = 0 + while np.sum(1 - mask3) > 0: + i += 1 + mask3_ = cv2.filter2D(mask3, -1, ones_filter) + mask3_[mask3_ > 0] = 1 + sub_mask = mask3_ - mask3 + pos[sub_mask == 1] = i + + m = cv2.filter2D(mask3, -1, d_filter1) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 0] = 1 + + m = cv2.filter2D(mask3, -1, d_filter2) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 1] = 1 + + m = cv2.filter2D(mask3, -1, d_filter3) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 2] = 1 + + m = cv2.filter2D(mask3, -1, d_filter4) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 3] = 1 + + mask3 = mask3_ + + abs_pos = pos.copy() + rel_pos = pos / (str_size / 2) # to 0~1 maybe larger than 1 + rel_pos = (rel_pos * pos_num).astype(np.int32) + rel_pos = np.clip(rel_pos, 0, pos_num - 1) + + if ori_w != w or ori_h != h: + rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) + rel_pos[ori_mask == 0] = 0 + direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) + direct[ori_mask == 0, :] = 0 + + return rel_pos, abs_pos, direct + + +def load_image(img, mask, device, sigma256=3.0): + """ + Args: + img: [H, W, C] RGB + mask: [H, W] 255 为 masks 区域 + sigma256: + + Returns: + + """ + h, w, _ = img.shape + imgh, imgw = img.shape[0:2] + img_256 = resize(img, 256, 256) + + mask = (mask > 127).astype(np.uint8) * 255 + mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA) + mask_256[mask_256 > 0] = 255 + + mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA) + mask_512[mask_512 > 0] = 255 + + # original skimage implemention + # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny + # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max. + # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max. 
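+ # Edges are detected on the full (unmasked) 256x256 grayscale image; the
+ # masked region is only zeroed out later in sample_edge_line_logits. sigma256
+ # controls the Gaussian smoothing applied before hysteresis thresholding.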
+ gray_256 = color.rgb2gray(img_256) + edge_256 = feature.canny(gray_256, sigma=sigma256, mask=None).astype(float) + # cv2.imwrite("skimage_gray.jpg", (_gray_256*255).astype(np.uint8)) + # cv2.imwrite("skimage_edge.jpg", (_edge_256*255).astype(np.uint8)) + + # gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY) + # gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(3,3), sigmaX=sigma256, sigmaY=sigma256) + # edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2)) + # cv2.imwrite("edge.jpg", edge_256) + + # line + img_512 = resize(img, 512, 512) + + rel_pos, abs_pos, direct = load_masked_position_encoding(mask) + + batch = dict() + batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device) + batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device) + batch["masks"] = to_tensor(mask).unsqueeze(0).to(device) + batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device) + batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device) + batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device) + batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device) + batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device) + batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device) + batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device) + batch["h"] = imgh + batch["w"] = imgw + + return batch + + +def to_device(data, device): + if isinstance(data, torch.Tensor): + return data.to(device) + if isinstance(data, dict): + for key in data: + if isinstance(data[key], torch.Tensor): + data[key] = data[key].to(device) + return data + if isinstance(data, list): + return [to_device(d, device) for d in data] + + +class ZITS(InpaintModel): + min_size = 256 + pad_mod = 32 + pad_to_square = True + + def __init__(self, device, **kwargs): + """ + + Args: + device: + """ + super().__init__(device) + self.device = device + self.sample_edge_line_iterations = 1 + + def init_model(self, device, **kwargs): + self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device) + self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device) + self.structure_upsample = load_jit_model( + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device + ) + self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL), + get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL), + get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL), + get_cache_path_by_url(ZITS_INPAINT_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def wireframe_edge_and_line(self, items, enable: bool): + # 最终向 items 中添加 edge 和 line key + if not enable: + items["edge"] = torch.zeros_like(items["masks"]) + items["line"] = torch.zeros_like(items["masks"]) + return + + start = time.time() + try: + line_256 = self.wireframe_forward( + items["img_512"], + h=256, + w=256, + masks=items["mask_512"], + mask_th=0.85, + ) + except: + line_256 = torch.zeros_like(items["mask_256"]) + + print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms") + + # np_line = (line[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line.jpg", np_line) + + start = time.time() + edge_pred, line_pred = self.sample_edge_line_logits( + context=[items["img_256"], items["edge_256"], line_256], + mask=items["mask_256"].clone(), + iterations=self.sample_edge_line_iterations, + add_v=0.05, + mul_v=4, + ) + print(f"sample_edge_line_logits time: 
{(time.time() - start) * 1000:.2f}ms") + + # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("edge_pred.jpg", np_edge_pred) + # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line_pred.jpg", np_line_pred) + # exit() + + input_size = min(items["h"], items["w"]) + if input_size != 256 and input_size > 256: + while edge_pred.shape[2] < input_size: + edge_pred = self.structure_upsample(edge_pred) + edge_pred = torch.sigmoid((edge_pred + 2) * 2) + + line_pred = self.structure_upsample(line_pred) + line_pred = torch.sigmoid((line_pred + 2) * 2) + + edge_pred = F.interpolate( + edge_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + line_pred = F.interpolate( + line_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + + # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred) + # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line_pred_upsample.jpg", np_line_pred) + # exit() + + items["edge"] = edge_pred.detach() + items["line"] = line_pred.detach() + + @torch.no_grad() + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] + return: BGR IMAGE + """ + mask = mask[:, :, 0] + items = load_image(image, mask, device=self.device) + + self.wireframe_edge_and_line(items, config.zits_wireframe) + + inpainted_image = self.inpaint( + items["images"], + items["masks"], + items["edge"], + items["line"], + items["rel_pos"], + items["direct"], + ) + + inpainted_image = inpainted_image * 255.0 + inpainted_image = ( + inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) + ) + inpainted_image = inpainted_image[:, :, ::-1] + + # cv2.imwrite("inpainted.jpg", inpainted_image) + # exit() + + return inpainted_image + + def wireframe_forward(self, images, h, w, masks, mask_th=0.925): + lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1) + lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1) + images = images * 255.0 + # the masks value of lcnn is 127.5 + masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5 + masked_images = (masked_images - lcnn_mean) / lcnn_std + + def to_int(x): + return tuple(map(int, x)) + + lines_tensor = [] + lmap = np.zeros((h, w)) + + output_masked = self.wireframe(masked_images) + + output_masked = to_device(output_masked, "cpu") + if output_masked["num_proposals"] == 0: + lines_masked = [] + scores_masked = [] + else: + lines_masked = output_masked["lines_pred"].numpy() + lines_masked = [ + [line[1] * h, line[0] * w, line[3] * h, line[2] * w] + for line in lines_masked + ] + scores_masked = output_masked["lines_score"].numpy() + + for line, score in zip(lines_masked, scores_masked): + if score > mask_th: + rr, cc, value = skimage.draw.line_aa( + *to_int(line[0:2]), *to_int(line[2:4]) + ) + lmap[rr, cc] = np.maximum(lmap[rr, cc], value) + + lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8) + lines_tensor.append(to_tensor(lmap).unsqueeze(0)) + + lines_tensor = torch.cat(lines_tensor, dim=0) + return lines_tensor.detach().to(self.device) + + def sample_edge_line_logits( + self, context, mask=None, iterations=1, add_v=0, mul_v=4 + ): + [img, edge, line] = context + + img = img * (1 - mask) + edge = edge * (1 - mask) + line = line * (1 - mask) + + for i in range(iterations): + edge_logits, line_logits 
= self.edge_line(img, edge, line, masks=mask) + + edge_pred = torch.sigmoid(edge_logits) + line_pred = torch.sigmoid((line_logits + add_v) * mul_v) + edge = edge + edge_pred * mask + edge[edge >= 0.25] = 1 + edge[edge < 0.25] = 0 + line = line + line_pred * mask + + b, _, h, w = edge_pred.shape + edge_pred = edge_pred.reshape(b, -1, 1) + line_pred = line_pred.reshape(b, -1, 1) + mask = mask.reshape(b, -1) + + edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1) + line_probs = torch.cat([1 - line_pred, line_pred], dim=-1) + edge_probs[:, :, 1] += 0.5 + line_probs[:, :, 1] += 0.5 + edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100) + line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100) + + indices = torch.sort( + edge_max_probs + line_max_probs, dim=-1, descending=True + )[1] + + for ii in range(b): + keep = int((i + 1) / iterations * torch.sum(mask[ii, ...])) + + assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!" + mask[ii][indices[ii, :keep]] = 0 + + mask = mask.reshape(b, 1, h, w) + edge = edge * (1 - mask) + line = line * (1 - mask) + + edge, line = edge.to(torch.float32), line.to(torch.float32) + return edge, line diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f749af132adb27aba2813ee0907b0d05cc2c4727 --- /dev/null +++ b/lama_cleaner/model_manager.py @@ -0,0 +1,56 @@ +import gc + +import torch + +from lama_cleaner.model.fcf import FcF +from lama_cleaner.model.lama import LaMa +from lama_cleaner.model.ldm import LDM +from lama_cleaner.model.manga import Manga +from lama_cleaner.model.mat import MAT +from lama_cleaner.model.opencv2 import OpenCV2 +from lama_cleaner.model.paint_by_example import PaintByExample +from lama_cleaner.model.sd import SD15, SD2 +from lama_cleaner.model.zits import ZITS +from lama_cleaner.schema import Config + +models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga, + "sd2": SD2, "paint_by_example": PaintByExample} + + +class ModelManager: + def __init__(self, model_device, **kwargs): + self.name = "lama" + self.device = model_device + self.kwargs = kwargs + self.model = self.init_model(self.name, model_device, **kwargs) + + def init_model(self, name: str, device, **kwargs): + if name in models: + model = models[name](device, **kwargs) + else: + raise NotImplementedError(f"Not supported model: {name}") + return model + + def is_downloaded(self, name: str) -> bool: + if name in models: + return models[name].is_downloaded() + else: + raise NotImplementedError(f"Not supported model: {name}") + + def __call__(self, image, mask, config: Config): + return self.model(image, mask, config) + + def switch(self, new_name: str): + if new_name == self.name: + return + try: + if (torch.cuda.memory_allocated() > 0): + # Clear current loaded model from memory + torch.cuda.empty_cache() + del self.model + gc.collect() + + self.model = self.init_model(new_name, self.device, **self.kwargs) + self.name = new_name + except NotImplementedError as e: + raise e diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py new file mode 100644 index 0000000000000000000000000000000000000000..fc78328e2fc36ee6ab7dcc6f362daebaf55843bf --- /dev/null +++ b/lama_cleaner/parse_args.py @@ -0,0 +1,128 @@ +import argparse +import imghdr +import os +from pathlib import Path + +from loguru import logger + +from lama_cleaner.const import AVAILABLE_MODELS, NO_HALF_HELP, CPU_OFFLOAD_HELP, 
DISABLE_NSFW_HELP, \ + SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, AVAILABLE_DEVICES, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, \ + OUTPUT_DIR_HELP, INPUT_HELP, GUI_HELP, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR +from lama_cleaner.runtime import dump_environment_info + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", default=8080, type=int) + + parser.add_argument("--config-installer", action="store_true", + help="Open config web page, mainly for windows installer") + parser.add_argument("--load-installer-config", action="store_true", + help="Load all cmd args from installer config file") + parser.add_argument("--installer-config", default=None, help="Config file for windows installer") + + parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS) + parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) + parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) + parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) + parser.add_argument("--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP) + parser.add_argument("--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP) + parser.add_argument("--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP) + parser.add_argument("--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES) + parser.add_argument("--gui", action="store_true", help=GUI_HELP) + parser.add_argument("--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP) + parser.add_argument( + "--gui-size", + default=[1600, 1000], + nargs=2, + type=int, + help="Set window size for GUI", + ) + parser.add_argument("--input", type=str, default=None, help=INPUT_HELP) + parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP) + parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP) + parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend") + parser.add_argument("--debug", action="store_true") + + # useless args + parser.add_argument( + "--hf_access_token", + default="", + help="SD model no more need token: https://github.com/huggingface/diffusers/issues/1447", + ) + parser.add_argument( + "--sd-disable-nsfw", + action="store_true", + help="Disable Stable Diffusion NSFW checker", + ) + parser.add_argument( + "--sd-run-local", + action="store_true", + help="SD model no more need token, use --local-files-only to set not connect to huggingface server", + ) + parser.add_argument( + "--sd-enable-xformers", + action="store_true", + help="Enable xFormers optimizations. Requires that xformers package has been installed. 
See: https://github.com/facebookresearch/xformers" + ) + + args = parser.parse_args() + + # collect system info to help debug + dump_environment_info() + + if args.config_installer: + if args.installer_config is None: + parser.error(f"args.config_installer==True, must set args.installer_config to store config file") + from lama_cleaner.web_config import main + logger.info(f"Launching installer web config page") + main(args.installer_config) + exit() + + if args.load_installer_config: + from lama_cleaner.web_config import load_config + if args.installer_config and not os.path.exists(args.installer_config): + parser.error(f"args.installer_config={args.installer_config} not exists") + + logger.info(f"Loading installer config from {args.installer_config}") + _args = load_config(args.installer_config) + for k, v in vars(_args).items(): + if k in vars(args): + setattr(args, k, v) + + if args.device == "cuda": + import torch + if torch.cuda.is_available() is False: + parser.error( + "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") + + if args.model_dir and args.model_dir is not None: + if os.path.isfile(args.model_dir): + parser.error(f"invalid --model-dir: {args.model_dir} is a file") + + if not os.path.exists(args.model_dir): + logger.info(f"Create model cache directory: {args.model_dir}") + Path(args.model_dir).mkdir(exist_ok=True, parents=True) + + os.environ["XDG_CACHE_HOME"] = args.model_dir + + if args.input and args.input is not None: + if not os.path.exists(args.input): + parser.error(f"invalid --input: {args.input} not exists") + if os.path.isfile(args.input): + if imghdr.what(args.input) is None: + parser.error(f"invalid --input: {args.input} is not a valid image file") + else: + if args.output_dir is None: + parser.error(f"invalid --input: {args.input} is a directory, --output-dir is required") + else: + output_dir = Path(args.output_dir) + if not output_dir.exists(): + logger.info(f"Creating output directory: {output_dir}") + output_dir.mkdir(parents=True) + else: + if not output_dir.is_dir(): + parser.error(f"invalid --output-dir: {output_dir} is not a directory") + + return args diff --git a/lama_cleaner/plugins/__pycache__/base_plugin.cpython-38.pyc b/lama_cleaner/plugins/__pycache__/base_plugin.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e86908cc3b81210de471ab86f46ea2ea840aa3c9 Binary files /dev/null and b/lama_cleaner/plugins/__pycache__/base_plugin.cpython-38.pyc differ diff --git a/lama_cleaner/plugins/__pycache__/remove_bg.cpython-38.pyc b/lama_cleaner/plugins/__pycache__/remove_bg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93e12c7784546c9f2e0468f4be667260e145f16 Binary files /dev/null and b/lama_cleaner/plugins/__pycache__/remove_bg.cpython-38.pyc differ diff --git a/lama_cleaner/plugins/base_plugin.py b/lama_cleaner/plugins/base_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..39d7491f92a3990667a28a927da20a1bcf70198a --- /dev/null +++ b/lama_cleaner/plugins/base_plugin.py @@ -0,0 +1,15 @@ +from loguru import logger + + +class BasePlugin: + def __init__(self): + err_msg = self.check_dep() + if err_msg: + logger.error(err_msg) + exit(-1) + + def __call__(self, rgb_np_img, files, form): + ... + + def check_dep(self): + ... 
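+ # check_dep returns an error message string when a required dependency is missing; __init__ logs it and exits. Subclasses implement __call__ to run inference on an RGB numpy image.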
diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py new file mode 100644 index 0000000000000000000000000000000000000000..8be125caef389f066769063c96f6c8bd376bc78f --- /dev/null +++ b/lama_cleaner/plugins/remove_bg.py @@ -0,0 +1,46 @@ +import os +from typing import Tuple + +import cv2 +import numpy as np +from torch.hub import get_dir +from loguru import logger + +from lama_cleaner.plugins.base_plugin import BasePlugin + + +class RemoveBG(BasePlugin): + name = "RemoveBG" + + def __init__(self): + super().__init__() + from rembg import new_session + + # TODO Update for local development + + hub_dir = get_dir() + # model_dir = os.path.join(hub_dir, "checkpoints") + model_dir = os.getcwd() + # os.environ["U2NET_HOME"] = model_dir + # os.environ["U2NET_HOME"] = os.getcwd() + os.environ["U2NET_HOME"] = '/tmp/' + + logger.info(f"Load remove model from: {model_dir}") + self.session = new_session(model_name="u2net") + + def __call__(self, rgb_np_img, files=None, form=None): + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + return self.forward(bgr_np_img) + + def forward(self, bgr_np_img) -> np.ndarray: + from rembg import remove + output = remove(bgr_np_img, session=self.session) + return output + + def check_dep(self): + try: + import rembg + except ImportError: + return ( + "RemoveBG is not installed, please install it first. pip install rembg" + ) diff --git a/lama_cleaner/runtime.py b/lama_cleaner/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f8a78afb16b8f48f44ec911182a8f04c4191b5 --- /dev/null +++ b/lama_cleaner/runtime.py @@ -0,0 +1,47 @@ +# https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py +import platform +import sys +from typing import Dict, Any + +import packaging.version +from rich import print + +_PY_VERSION: str = sys.version.split()[0].rstrip("+") + +if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"): + import importlib_metadata # type: ignore +else: + import importlib.metadata as importlib_metadata # type: ignore + +_package_versions = {} + +_CANDIDATES = [ + "torch", + "Pillow", + "diffusers", + "transformers", + "opencv-python", + "xformers", + "accelerate", + "lama-cleaner" +] +# Check once at runtime +for name in _CANDIDATES: + _package_versions[name] = "N/A" + try: + _package_versions[name] = importlib_metadata.version(name) + except importlib_metadata.PackageNotFoundError: + pass + + +def dump_environment_info() -> Dict[str, str]: + """Dump information about the machine to help debugging issues. """ + + # Generic machine info + info: Dict[str, Any] = { + "Platform": platform.platform(), + "Python version": platform.python_version(), + } + info.update(_package_versions) + print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") + return info diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2a3e111cfd40a8928bb1f59bec1d30e8cf8e01 --- /dev/null +++ b/lama_cleaner/schema.py @@ -0,0 +1,90 @@ +from enum import Enum + +from PIL.Image import Image +from pydantic import BaseModel + + +class HDStrategy(str, Enum): + # Use original image size + ORIGINAL = "Original" + # Resize the longer side of the image to a specific size(hd_strategy_resize_limit), + # then do inpainting on the resized image. Finally, resize the inpainting result to the original size. + # The area outside the mask will not lose quality. 
+ RESIZE = "Resize" + # Crop masking area(with a margin controlled by hd_strategy_crop_margin) from the original image to do inpainting + CROP = "Crop" + + +class LDMSampler(str, Enum): + ddim = "ddim" + plms = "plms" + + +class SDSampler(str, Enum): + ddim = "ddim" + pndm = "pndm" + k_lms = "k_lms" + k_euler = 'k_euler' + k_euler_a = 'k_euler_a' + dpm_plus_plus = 'dpm++' + + +class Config(BaseModel): + class Config: + arbitrary_types_allowed = True + + # Configs for ldm model + ldm_steps: int + ldm_sampler: str = LDMSampler.plms + + # Configs for zits model + zits_wireframe: bool = True + + # Configs for High Resolution Strategy(different way to preprocess image) + hd_strategy: str # See HDStrategy Enum + hd_strategy_crop_margin: int + # If the longer side of the image is larger than this value, use crop strategy + hd_strategy_crop_trigger_size: int + hd_strategy_resize_limit: int + + # Configs for Stable Diffusion 1.5 + prompt: str = "" + negative_prompt: str = "" + # Crop image to this size before doing sd inpainting + # The value is always on the original image scale + use_croper: bool = False + croper_x: int = None + croper_y: int = None + croper_height: int = None + croper_width: int = None + + # Resize the image before doing sd inpainting, the area outside the mask will not lose quality. + # Used by sd models and paint_by_example model + sd_scale: float = 1.0 + # Blur the edge of mask area. The higher the number the smoother blend with the original image + sd_mask_blur: int = 0 + # Ignore this value, it's useless for inpainting + sd_strength: float = 0.75 + # The number of denoising steps. More denoising steps usually lead to a + # higher quality image at the expense of slower inference. + sd_steps: int = 50 + # Higher guidance scale encourages to generate images that are closely linked + # to the text prompt, usually at the expense of lower image quality. 
+ sd_guidance_scale: float = 7.5 + sd_sampler: str = SDSampler.ddim + # -1 mean random seed + sd_seed: int = 42 + sd_match_histograms: bool = False + + # Configs for opencv inpainting + # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 + cv2_flag: str = 'INPAINT_NS' + cv2_radius: int = 4 + + # Paint by Example + paint_by_example_steps: int = 50 + paint_by_example_guidance_scale: float = 7.5 + paint_by_example_mask_blur: int = 0 + paint_by_example_seed: int = 42 + paint_by_example_match_histograms: bool = False + paint_by_example_example_image: Image = None diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py new file mode 100644 index 0000000000000000000000000000000000000000..a189029e0e2d2159443cd7188b4b9b57b86d7cf5 --- /dev/null +++ b/lama_cleaner/server.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 + +import imghdr +import io +import json +import logging +import multiprocessing +import os +import random +import time +from pathlib import Path +from typing import Union + +import cv2 +import numpy as np +import torch +from PIL import Image +from loguru import logger +from watchdog.events import FileSystemEventHandler + +from lama_cleaner.file_manager import FileManager +from lama_cleaner.interactive_seg import InteractiveSeg, Click +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify + +# Disable ability for Flask to display warning about using a development server in a production environment. 
+# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 +cli.show_server_banner = lambda *_: None +from flask_cors import CORS + +from lama_cleaner.helper import ( + load_img, + numpy_to_bytes, + resize_max_size, +) + +NUM_THREADS = str(multiprocessing.cpu_count()) + +# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + +BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") + + +class NoFlaskwebgui(logging.Filter): + def filter(self, record): + return "flaskwebgui-keep-server-alive" not in record.getMessage() + + +logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) + +app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) +app.config["JSON_AS_ASCII"] = False +CORS(app, expose_headers=["Content-Disposition"]) + +model: ModelManager = None +thumb: FileManager = None +interactive_seg_model: InteractiveSeg = None +device = None +input_image_path: str = None +is_disable_model_switch: bool = False +is_enable_file_manager: bool = False +is_desktop: bool = False + + +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + +def diffuser_callback(i, t, latents): + pass + # socketio.emit('diffusion_step', {'diffusion_step': step}) + + +@app.route("/save_image", methods=["POST"]) +def save_image(): + # save the uploaded image into the output directory + input = request.files + origin_image_bytes = input["image"].read() # RGB + image, _ = load_img(origin_image_bytes) + thumb.save_to_output_directory(image, request.form["filename"]) + return 'ok', 200 + + +@app.route("/medias/<tab>") +def medias(tab): + if tab == 'image': + response = make_response(jsonify(thumb.media_names), 200) + else: + response = make_response(jsonify(thumb.output_media_names), 200) + # response.last_modified = thumb.modified_time[tab] + # response.cache_control.no_cache = True + # response.cache_control.max_age = 0 + # response.make_conditional(request) + return response + + +@app.route('/media/<tab>/<filename>') +def media_file(tab, filename): + if tab == 'image': + return send_from_directory(thumb.root_directory, filename) + return send_from_directory(thumb.output_dir, filename) + + +@app.route('/media_thumbnail/<tab>/<filename>') +def media_thumbnail_file(tab, filename): + args = request.args + width = args.get('width') + height = args.get('height') + if width is None and height is None: + width = 256 + if width: + width = int(float(width)) + if height: + height = int(float(height)) + + directory = thumb.root_directory + if tab == 'output': + directory = thumb.output_dir + thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height) + thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}" + + response = make_response(send_file(thumb_filepath)) + response.headers["X-Width"] = str(width) + response.headers["X-Height"] = str(height) + return response + + +@app.route("/inpaint", methods=["POST"]) +def process(): + logger.info(f"/inpaint request form: {request.form}") + + input = request.files + # RGB + origin_image_bytes = input["image"].read() + image, alpha_channel = 
load_img(origin_image_bytes) + + + mask, _ = load_img(input["mask"].read(), gray=True) + mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] + + if image.shape[:2] != mask.shape[:2]: + return f"Mask shape {mask.shape[:2]} is not equal to image shape {image.shape[:2]}", 400 + + original_shape = image.shape + interpolation = cv2.INTER_CUBIC + + form = request.form + size_limit: Union[int, str] = form.get("sizeLimit", "1080") + logger.info(size_limit) + + if size_limit == "Original": + size_limit = max(image.shape) + else: + size_limit = int(size_limit) + + if "paintByExampleImage" in input: + paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read()) + paint_by_example_example_image = Image.fromarray(paint_by_example_example_image) + else: + paint_by_example_example_image = None + + config = Config( + ldm_steps=form["ldmSteps"], + ldm_sampler=form["ldmSampler"], + hd_strategy=form["hdStrategy"], + zits_wireframe=form["zitsWireframe"], + hd_strategy_crop_margin=form["hdStrategyCropMargin"], + hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"], + hd_strategy_resize_limit=form["hdStrategyResizeLimit"], + prompt=form["prompt"], + negative_prompt=form["negativePrompt"], + use_croper=form["useCroper"], + croper_x=form["croperX"], + croper_y=form["croperY"], + croper_height=form["croperHeight"], + croper_width=form["croperWidth"], + sd_scale=form["sdScale"], + sd_mask_blur=form["sdMaskBlur"], + sd_strength=form["sdStrength"], + sd_steps=form["sdSteps"], + sd_guidance_scale=form["sdGuidanceScale"], + sd_sampler=form["sdSampler"], + sd_seed=form["sdSeed"], + sd_match_histograms=form["sdMatchHistograms"], + cv2_flag=form["cv2Flag"], + cv2_radius=form['cv2Radius'], + paint_by_example_steps=form["paintByExampleSteps"], + paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"], + paint_by_example_mask_blur=form["paintByExampleMaskBlur"], + paint_by_example_seed=form["paintByExampleSeed"], + paint_by_example_match_histograms=form["paintByExampleMatchHistograms"], + paint_by_example_example_image=paint_by_example_example_image, + ) + logger.info(f"hd_strategy: {form['hdStrategy']}") + + if config.sd_seed == -1: + config.sd_seed = random.randint(1, 999999999) + if config.paint_by_example_seed == -1: + config.paint_by_example_seed = random.randint(1, 999999999) + + logger.info(f"Origin image shape: {original_shape}") + image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) + logger.info(f"Resized image shape: {image.shape}") + + mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) + + start = time.time() + try: + res_np_img = model(image, mask, config) + except RuntimeError as e: + torch.cuda.empty_cache() + if "CUDA out of memory. " in str(e): + # NOTE: the string may change? 
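+ # Newer PyTorch releases raise torch.cuda.OutOfMemoryError, so this message check is best-effort.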
+ return "CUDA out of memory", 500 + else: + logger.exception(e) + return "Internal Server Error", 500 + finally: + logger.info(f"process time: {(time.time() - start) * 1000}ms") + torch.cuda.empty_cache() + + if alpha_channel is not None: + if alpha_channel.shape[:2] != res_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) + ) + res_np_img = np.concatenate( + (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + + ext = get_image_ext(origin_image_bytes) + + response = make_response( + send_file( + io.BytesIO(numpy_to_bytes(res_np_img, ext)), + mimetype=f"image/{ext}", + ) + ) + response.headers["X-Seed"] = str(config.sd_seed) + return response + + +@app.route("/interactive_seg", methods=["POST"]) +def interactive_seg(): + input = request.files + origin_image_bytes = input["image"].read() # RGB + image, _ = load_img(origin_image_bytes) + if 'mask' in input: + mask, _ = load_img(input["mask"].read(), gray=True) + else: + mask = None + + _clicks = json.loads(request.form["clicks"]) + clicks = [] + for i, click in enumerate(_clicks): + clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)) + + start = time.time() + new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask) + logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms") + response = make_response( + send_file( + io.BytesIO(numpy_to_bytes(new_mask, 'png')), + mimetype=f"image/png", + ) + ) + return response + + +@app.route("/model") +def current_model(): + return model.name, 200 + + +@app.route("/is_disable_model_switch") +def get_is_disable_model_switch(): + res = 'true' if is_disable_model_switch else 'false' + return res, 200 + + +@app.route("/is_enable_file_manager") +def get_is_enable_file_manager(): + res = 'true' if is_enable_file_manager else 'false' + return res, 200 + + +@app.route("/model_downloaded/") +def model_downloaded(name): + return str(model.is_downloaded(name)), 200 + + +@app.route("/is_desktop") +def get_is_desktop(): + return str(is_desktop), 200 + + +@app.route("/model", methods=["POST"]) +def switch_model(): + if is_disable_model_switch: + return "Switch model is disabled", 400 + + new_name = request.form.get("name") + if new_name == model.name: + return "Same model", 200 + + try: + model.switch(new_name) + except NotImplementedError: + return f"{new_name} not implemented", 403 + return f"ok, switch to {new_name}", 200 + + +@app.route("/") +def index(): + return send_file(os.path.join(BUILD_DIR, "index.html"), cache_timeout=0) + + +@app.route("/inputimage") +def set_input_photo(): + if input_image_path: + with open(input_image_path, "rb") as f: + image_in_bytes = f.read() + return send_file( + input_image_path, + as_attachment=True, + attachment_filename=Path(input_image_path).name, + mimetype=f"image/{get_image_ext(image_in_bytes)}", + ) + else: + return "No Input Image" + + +class FSHandler(FileSystemEventHandler): + def on_modified(self, event): + print("File modified: %s" % event.src_path) + + +def main(args): + print("-----------------------------------") + print(args) + global model + global interactive_seg_model + global device + global input_image_path + global is_disable_model_switch + global is_enable_file_manager + global is_desktop + global thumb + + device = torch.device(args.device) + is_disable_model_switch = args.disable_model_switch + is_desktop = args.gui + if is_disable_model_switch: + logger.info(f"Start with --disable-model-switch, model switch on frontend 
is disable") + + if args.input and os.path.isdir(args.input): + logger.info(f"Initialize file manager") + thumb = FileManager(app) + is_enable_file_manager = True + app.config["THUMBNAIL_MEDIA_ROOT"] = args.input + app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails') + thumb.output_dir = Path(args.output_dir) + # thumb.start() + # try: + # while True: + # time.sleep(1) + # finally: + # thumb.image_dir_observer.stop() + # thumb.image_dir_observer.join() + # thumb.output_dir_observer.stop() + # thumb.output_dir_observer.join() + + else: + input_image_path = args.input + + model = ModelManager( + name=args.model, + device=device, + no_half=args.no_half, + hf_access_token=args.hf_access_token, + disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw, + sd_cpu_textencoder=args.sd_cpu_textencoder, + sd_run_local=args.sd_run_local, + local_files_only=args.local_files_only, + cpu_offload=args.cpu_offload, + enable_xformers=args.sd_enable_xformers or args.enable_xformers, + callback=diffuser_callback, + ) + + interactive_seg_model = InteractiveSeg() + + if args.gui: + app_width, app_height = args.gui_size + from flaskwebgui import FlaskUI + + ui = FlaskUI( + app, width=app_width, height=app_height, host=args.host, port=args.port, + close_server_on_exit=not args.no_gui_auto_close + ) + ui.run() + else: + app.run(host=args.host, port=args.port, debug=args.debug) diff --git a/lama_cleaner/server2.py b/lama_cleaner/server2.py new file mode 100644 index 0000000000000000000000000000000000000000..a2c5e9bbf98bae39b336193e04dd5b7e9535d625 --- /dev/null +++ b/lama_cleaner/server2.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +import os +import sys + +# import traceback + +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) + +import base64 + +import logging +import multiprocessing +import os +import random +import time +import imghdr +from pathlib import Path + +import cv2 +import torch +import numpy as np +from loguru import logger + +from lama_cleaner.interactive_seg import InteractiveSeg +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config +from lama_cleaner.file_manager import FileManager +from lama_cleaner.plugins.remove_bg import RemoveBG + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +# Disable ability for Flask to display warning about using a development server in a production environment. 
+# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 +# cli.show_server_banner = lambda *_: None +# from flask_cors import CORS + +from lama_cleaner.helper import ( + load_img, + resize_max_size, +) + +NUM_THREADS = str(multiprocessing.cpu_count()) + +# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + +BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + + +class NoFlaskwebgui(logging.Filter): + def filter(self, record): + return "flaskwebgui-keep-server-alive" not in record.getMessage() + + +logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) + +# app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) +# app.config["JSON_AS_ASCII"] = False +# CORS(app, expose_headers=["Content-Disposition"]) + +model: ModelManager = None +thumb: FileManager = None +device = None +input_image_path: str = None +is_disable_model_switch: bool = False +is_enable_file_manager: bool = False +is_desktop: bool = False +plugins = {} + + +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + +def diffuser_callback(i, t, latents): + pass + # socketio.emit('diffusion_step', {'diffusion_step': step}) + + +config = Config( + ldm_steps=25, + ldm_sampler='plms', + hd_strategy='Resize', # Original, Resize, Crop + zits_wireframe=True, + hd_strategy_crop_margin=196, + hd_strategy_crop_trigger_size=1280, + hd_strategy_resize_limit=2048, + prompt="", + negative_prompt="", + use_croper=False, + croper_x=None, + croper_y=None, + croper_height=None, + croper_width=None, + sd_scale=1, + sd_mask_blur=5, + sd_strength=0.75, + sd_steps=50, + sd_guidance_scale=7.5, + sd_sampler="pndm", + sd_seed=42, + sd_match_histograms=False, + cv2_flag="INPAINT_NS", + cv2_radius=40, + paint_by_example_steps=50, + paint_by_example_guidance_scale=7.5, + paint_by_example_mask_blur=5, + paint_by_example_seed=42, + paint_by_example_match_histograms=False, + paint_by_example_example_image=None, +) + + +def process(origin_image_bytes, mask): + image, alpha_channel = load_img(origin_image_bytes) + + mask, _ = load_img(mask, gray=True) + mask = np.where(mask > 0, 255, 0).astype(np.uint8) + + if image.shape[:2] != mask.shape[:2]: + return f"Mask shape {mask.shape[:2]} not queal to Image shape {image.shape[:2]}", 400 + + original_shape = image.shape + interpolation = cv2.INTER_CUBIC + + size_limit = 2048 + if size_limit == "Original": + size_limit = max(image.shape) + else: + size_limit = int(size_limit) + + if config.sd_seed == -1: + config.sd_seed = random.randint(1, 999999999) + if config.paint_by_example_seed == -1: + config.paint_by_example_seed = random.randint(1, 999999999) + + logger.info(f"Origin image shape: {original_shape}") + image = resize_max_size(image, size_limit=size_limit, + interpolation=interpolation) + logger.info(f"Resized image shape: {image.shape}") + + mask = resize_max_size(mask, size_limit=size_limit, + 
interpolation=interpolation) + + start = time.time() + try: + with torch.no_grad(): + res_np_img = model(image, mask, config) + except RuntimeError as e: + torch.cuda.empty_cache() + if "CUDA out of memory. " in str(e): + # NOTE: the string may change? + return "CUDA out of memory", 500 + else: + logger.exception(e) + return "Internal Server Error", 500 + finally: + torch.cuda.empty_cache() + logger.info(f"process time: {(time.time() - start)}s") + + if alpha_channel is not None: + if alpha_channel.shape[:2] != res_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) + ) + res_np_img = np.concatenate( + (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + + img = cv2.imencode('.jpg', res_np_img)[1] + return base64.b64encode(img).decode('utf-8') + + + +def current_model(): + return model.name, 200 + + +def get_is_disable_model_switch(): + res = 'true' if is_disable_model_switch else 'false' + return res, 200 + + +def switch_model(new_name): + if is_disable_model_switch: + return "Switch model is disabled", 400 + + if new_name == model.name: + return "Same model", 200 + + try: + model.switch(new_name) + except NotImplementedError: + return f"{new_name} not implemented", 403 + return f"ok, switch to {new_name}", 200 + + +def remove(origin_image_bytes): + name = RemoveBG.name + rgb_np_img, alpha_channel = load_img(origin_image_bytes) + + start = time.time() + try: + bgr_res = plugins[name](rgb_np_img) + except RuntimeError as e: + torch.cuda.empty_cache() + if "CUDA out of memory. " in str(e): + return "CUDA out of memory", 500 + else: + logger.exception(e) + return "Internal Server Error", 500 + + logger.info(f"{name} process time: {(time.time() - start) * 1000}ms") + + img = cv2.imencode('.png', bgr_res)[1] + return base64.b64encode(img).decode('utf-8') + + +def initModel(): + global model + global device + global input_image_path + global is_disable_model_switch + global is_enable_file_manager + global is_desktop + global thumb + global plugins + + model_device = "cpu" + device = torch.device(model_device) + is_disable_model_switch = False + is_desktop = False + if is_disable_model_switch: + logger.info( + f"Start with --disable-model-switch, model switch on frontend is disabled") + + model = ModelManager(model_device, callback=diffuser_callback) + plugins[RemoveBG.name] = RemoveBG() diff --git a/lama_cleaner/tests/.gitignore b/lama_cleaner/tests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..82fd7050c53725c2dfe021b94f7980b015cf8124 --- /dev/null +++ b/lama_cleaner/tests/.gitignore @@ -0,0 +1 @@ +*_result.png \ No newline at end of file diff --git a/lama_cleaner/tests/__init__.py b/lama_cleaner/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lama_cleaner/tests/test_interactive_seg.py b/lama_cleaner/tests/test_interactive_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..1802dd46149db978b7f80bbfc2e300117261d20d --- /dev/null +++ b/lama_cleaner/tests/test_interactive_seg.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from lama_cleaner.interactive_seg import InteractiveSeg, Click + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) +img_p = current_dir / "overture-creations-5sI6fQgYIuo.png" + + +def test_interactive_seg(): + interactive_seg_model = 
InteractiveSeg() + img = cv2.imread(str(img_p)) + pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]) + cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred) + + +def test_interactive_seg_with_negative_click(): + interactive_seg_model = InteractiveSeg() + img = cv2.imread(str(img_p)) + pred = interactive_seg_model(img, clicks=[ + Click(coords=(256, 256), indx=0, is_positive=True), + Click(coords=(384, 256), indx=1, is_positive=False) + ]) + cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred) + + +def test_interactive_seg_with_prev_mask(): + interactive_seg_model = InteractiveSeg() + img = cv2.imread(str(img_p)) + mask = np.zeros_like(img)[:, :, 0] + pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask) + cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred) diff --git a/lama_cleaner/tests/test_load_img.py b/lama_cleaner/tests/test_load_img.py new file mode 100644 index 0000000000000000000000000000000000000000..6028a60d8f3cb8c27959c68041b621d7819c4635 --- /dev/null +++ b/lama_cleaner/tests/test_load_img.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from lama_cleaner.helper import load_img + +current_dir = Path(__file__).parent.absolute().resolve() +png_img_p = current_dir / "image.png" +jpg_img_p = current_dir / "bunny.jpeg" + + +def test_load_png_image(): + with open(png_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (256, 256, 3) + assert alpha_channel.shape == (256, 256) + + +def test_load_jpg_image(): + with open(jpg_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (394, 448, 3) + assert alpha_channel is None diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5ef95790371c77a6e1c97b90d503802080d97a --- /dev/null +++ b/lama_cleaner/tests/test_model.py @@ -0,0 +1,196 @@ +from pathlib import Path + +import cv2 +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config, HDStrategy, LDMSampler + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) +device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = torch.device(device) + + +def get_data(fx: float = 1, fy: float = 1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"): + img = cv2.imread(str(img_p)) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) + mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST) + return img, mask + + +def get_config(strategy, **kwargs): + data = dict( + ldm_steps=1, + ldm_sampler=LDMSampler.plms, + hd_strategy=strategy, + hd_strategy_crop_margin=32, + hd_strategy_crop_trigger_size=200, + hd_strategy_resize_limit=200, + ) + data.update(**kwargs) + return Config(**data) + + +def assert_equal(model, config, gt_name, + fx: float = 1, fy: float = 1, + img_p=current_dir / "image.png", + mask_p=current_dir / "mask.png"): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + print(f"Input image shape: {img.shape}") + res = model(img, mask, config) + cv2.imwrite( + str(save_dir / gt_name), + res, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, 
int(cv2.IMWRITE_PNG_COMPRESSION), 0], + ) + + """ + Note that JPEG is lossy compression, so even if it is the highest quality 100, + when the saved images is reloaded, a difference occurs with the original pixel value. + If you want to save the original images as it is, save it as PNG or BMP. + """ + # gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED) + # assert np.array_equal(res, gt) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +def test_lama(strategy): + model = ModelManager(name="lama", device=device) + assert_equal( + model, + get_config(strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + get_config(strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png", + fx=1.3, + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms]) +def test_ldm(strategy, ldm_sampler): + model = ModelManager(name="ldm", device=device) + cfg = get_config(strategy, ldm_sampler=ldm_sampler) + assert_equal( + model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png" + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png", + fx=fx, + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("zits_wireframe", [False, True]) +def test_zits(strategy, zits_wireframe): + model = ModelManager(name="zits", device=device) + cfg = get_config(strategy, zits_wireframe=zits_wireframe) + # os.environ['ZITS_DEBUG_LINE_PATH'] = str(current_dir / 'zits_debug_line.jpg') + # os.environ['ZITS_DEBUG_EDGE_PATH'] = str(current_dir / 'zits_debug_edge.jpg') + assert_equal( + model, + cfg, + f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png", + fx=fx, + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL] +) +def test_mat(strategy): + model = ModelManager(name="mat", device=device) + cfg = get_config(strategy) + + assert_equal( + model, + cfg, + f"mat_{strategy.capitalize()}_result.png", + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL] +) +def test_fcf(strategy): + model = ModelManager(name="fcf", device=device) + cfg = get_config(strategy) + + assert_equal( + model, + cfg, + f"fcf_{strategy.capitalize()}_result.png", + fx=2, + fy=2 + ) + + assert_equal( + model, + cfg, + f"fcf_{strategy.capitalize()}_result.png", + fx=3.8, + fy=2 + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("cv2_flag", ['INPAINT_NS', 'INPAINT_TELEA']) +@pytest.mark.parametrize("cv2_radius", [3, 15]) +def test_cv2(strategy, cv2_flag, cv2_radius): + model = ModelManager( + name="cv2", + device=torch.device(device), + ) + cfg = get_config(strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius) + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]) +def test_manga(strategy): + 
model = ModelManager( + name="manga", + device=torch.device(device), + ) + cfg = get_config(strategy) + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/tests/test_paint_by_example.py b/lama_cleaner/tests/test_paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..c495690c9f3a22ac79b97177e6be26a2c68114a1 --- /dev/null +++ b/lama_cleaner/tests/test_paint_by_example.py @@ -0,0 +1,106 @@ +from pathlib import Path + +import cv2 +import pytest +import torch +from PIL import Image + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import HDStrategy +from lama_cleaner.tests.test_model import get_config, get_data + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) +device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = torch.device(device) + + +def assert_equal( + model, config, gt_name, + fx: float = 1, fy: float = 1, + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + example_p=current_dir / "bunny.jpeg", +): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + + example_image = cv2.imread(str(example_p)) + example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB) + example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) + + print(f"Input image shape: {img.shape}, example_image: {example_image.shape}") + config.paint_by_example_example_image = Image.fromarray(example_image) + res = model(img, mask, config) + cv2.imwrite(str(save_dir / gt_name), res) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) + cfg = get_config(strategy, paint_by_example_steps=30) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3, + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_disable_nsfw(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False) + cfg = get_config(strategy, paint_by_example_steps=30) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_sd_scale(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) + cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_sdscale.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3 + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_cpu_offload(strategy): + model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False) + cfg = get_config(strategy, 
paint_by_example_steps=30, sd_scale=0.85) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_cpu_offload.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_cpu_offload_cpu_device(strategy): + model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True) + cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_cpu_offload_cpu_device.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3 + ) diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d643b2824da022dc282266a0a3ef506a6b0752ac --- /dev/null +++ b/lama_cleaner/tests/test_sd_model.py @@ -0,0 +1,208 @@ +from pathlib import Path + +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import HDStrategy, SDSampler +from lama_cleaner.tests.test_model import get_config, assert_equal + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) +device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = torch.device(device) + + +@pytest.mark.parametrize("sd_device", ['cuda']) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +@pytest.mark.parametrize("cpu_textencoder", [True, False]) +@pytest.mark.parametrize("disable_nsfw", [True, False]) +def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): + def callback(i, t, latents): + print(f"sd_step_{i}") + + if sd_device == 'cuda' and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == 'cuda' else 1 + model = ModelManager(name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + callback=callback) + cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3 + ) + + +@pytest.mark.parametrize("sd_device", ['cuda']) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]) +@pytest.mark.parametrize("cpu_textencoder", [False]) +@pytest.mark.parametrize("disable_nsfw", [True]) +def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): + def callback(i, t, latents): + print(f"sd_step_{i}") + + if sd_device == 'cuda' and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == 'cuda' else 1 + model = ModelManager(name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + callback=callback) + cfg = 
get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3 + ) + + +@pytest.mark.parametrize("sd_device", ['cuda']) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): + def callback(i, t, latents): + pass + + if sd_device == 'cuda' and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == 'cuda' else 1 + model = ModelManager(name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=False, + sd_cpu_textencoder=False, + callback=callback) + cfg = get_config( + strategy, + sd_steps=sd_steps, + prompt='Face of a fox, high resolution, sitting on a park bench', + negative_prompt='orange, yellow, small', + sd_sampler=sampler, + sd_match_histograms=True + ) + + name = f"{sampler}_negative_prompt" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1 + ) + + +@pytest.mark.parametrize("sd_device", ['cuda']) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +@pytest.mark.parametrize("cpu_textencoder", [False]) +@pytest.mark.parametrize("disable_nsfw", [False]) +def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): + if sd_device == 'cuda' and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == 'cuda' else 1 + model = ModelManager(name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder) + cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3 + ) + + +@pytest.mark.parametrize("sd_device", ['cuda']) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler): + if sd_device == 'cuda' and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == 'cuda' else 1 + model = ModelManager(name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True) + cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + 
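+# The next test runs the same cpu_offload configuration entirely on CPU with a single sampling step, so it does not require CUDA.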
+ +@pytest.mark.parametrize("sd_device", ['cpu']) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler): + model = ModelManager(name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=False, + sd_cpu_textencoder=False, + cpu_offload=True) + cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload_cpu_device.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f9730d2fdc770ca0cf52b24664932647f438d8 --- /dev/null +++ b/lama_cleaner/web_config.py @@ -0,0 +1,118 @@ +import json +import os +from datetime import datetime + +import gradio as gr +from loguru import logger +from pydantic import BaseModel + +from lama_cleaner.const import AVAILABLE_MODELS, AVAILABLE_DEVICES, CPU_OFFLOAD_HELP, NO_HALF_HELP, DISABLE_NSFW_HELP, \ + SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, OUTPUT_DIR_HELP, INPUT_HELP, \ + GUI_HELP, DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR + +_config_file = None + + +class Config(BaseModel): + host: str = "127.0.0.1" + port: int = 8080 + model: str = DEFAULT_MODEL + device: str = DEFAULT_DEVICE + gui: bool = False + no_gui_auto_close: bool = False + no_half: bool = False + cpu_offload: bool = False + disable_nsfw: bool = False + sd_cpu_textencoder: bool = False + enable_xformers: bool = False + local_files_only: bool = False + model_dir: str = DEFAULT_MODEL_DIR + input: str = None + output_dir: str = None + + +def load_config(installer_config: str): + if os.path.exists(installer_config): + with open(installer_config, "r", encoding='utf-8') as f: + return Config(**json.load(f)) + else: + return Config() + + +def save_config( + host, port, model, device, gui, no_gui_auto_close, no_half, cpu_offload, + disable_nsfw, sd_cpu_textencoder, enable_xformers, local_files_only, + model_dir, input, output_dir +): + config = Config(**locals()) + print(config) + if config.input and not os.path.exists(config.input): + return "[Error] Input file or directory does not exist" + + current_time = datetime.now().strftime("%H:%M:%S") + msg = f"[{current_time}] Successful save config to: {os.path.abspath(_config_file)}" + logger.info(msg) + try: + with open(_config_file, "w", encoding="utf-8") as f: + json.dump(config.dict(), f, indent=4, ensure_ascii=False) + except Exception as e: + return f"Save failed: {str(e)}" + return msg + + +def close_server(*args): + # TODO: make close both browser and server works + import os, signal + pid = os.getpid() + os.kill(pid, signal.SIGUSR1) + + +def main(config_file: str): + global _config_file + _config_file = config_file + + init_config = load_config(config_file) + + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(scale=1): + save_btn = gr.Button(value="Save configurations") + message = gr.HTML() + # with gr.Column(scale=0, min_width=100): + # exit_btn = gr.Button(value="Close") + # exit_btn.click(close_server) + with gr.Row(): + host = gr.Textbox(init_config.host, label="Host") + port 
= gr.Number(init_config.port, label="Port", precision=0) + with gr.Row(): + model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model) + device = gr.Radio(AVAILABLE_DEVICES, label="Device", value=init_config.device) + gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") + no_gui_auto_close = gr.Checkbox(init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}") + no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") + cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}") + disable_nsfw = gr.Checkbox(init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}") + sd_cpu_textencoder = gr.Checkbox(init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}") + enable_xformers = gr.Checkbox(init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}") + local_files_only = gr.Checkbox(init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}") + model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}") + input = gr.Textbox(init_config.input, label=f"Input file or directory. {INPUT_HELP}") + output_dir = gr.Textbox(init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}") + save_btn.click(save_config, [ + host, + port, + model, + device, + gui, + no_gui_auto_close, + no_half, + cpu_offload, + disable_nsfw, + sd_cpu_textencoder, + enable_xformers, + local_files_only, + model_dir, + input, + output_dir, + ], message) + demo.launch(inbrowser=True, show_api=False) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e00b7f22b0fdc2161592f90517b6cb7632dc5d9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +opencv-python-headless +numpy==1.24.2 +setuptools==60.2.0 +ultralytics==8.0.54 +Pillow>=9.4.0 +tqdm>=4.65.0 +packaging>=23.0 +loguru==0.6.0 +rich==13.3.2 +pydantic +pytest +yacs +markupsafe +scikit-image==0.19.3 +diffusers[torch]==0.12.1 +transformers>=4.25.1 +watchdog==2.2.1 +gradio +piexif==1.1.3 +safetensors +python-dotenv +fastapi +uvicorn[standard] +torch +torchvision +mangum +gunicorn +rembg \ No newline at end of file diff --git a/yolov8x-seg.pt b/yolov8x-seg.pt new file mode 100644 index 0000000000000000000000000000000000000000..32ec037daef545f7d3858a80aa52d5159999757d --- /dev/null +++ b/yolov8x-seg.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6 +size 144076467