import sys

import cv2
import numpy as np
import torch
from PIL import Image

import utils
from utils import convert_state_dict
from models import restormer_arch
from data.preprocess.crop_merge_image import stride_integral

sys.path.append("./data/MBD/")  # make the MBD document-boundary model importable
from data.MBD.infer import net1_net2_infer_single_im


def dewarp_prompt(img):
    # Segment the document with MBD, zero out the background, and build the
    # dewarping prompt: a normalized base coordinate grid plus the document mask.
    mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
    base_coord = utils.getBasecoord(256, 256) / 256
    img[mask == 0] = 0
    mask = cv2.resize(mask, (256, 256)) / 255
    return img, np.concatenate((base_coord, np.expand_dims(mask, -1)), -1)


def deshadow_prompt(img):
    # Estimate the per-channel background (illumination) via dilation followed
    # by median blur; the resized background image is returned as the prompt.
    h, w = img.shape[:2]
    # img = cv2.resize(img, (128, 128))
    img = cv2.resize(img, (1024, 1024))
    rgb_planes = cv2.split(img)
    result_planes = []
    result_norm_planes = []
    bg_imgs = []
    for plane in rgb_planes:
        dilated_img = cv2.dilate(plane, np.ones((7, 7), np.uint8))
        bg_img = cv2.medianBlur(dilated_img, 21)
        bg_imgs.append(bg_img)
        diff_img = 255 - cv2.absdiff(plane, bg_img)
        norm_img = cv2.normalize(
            diff_img,
            None,
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_8UC1,
        )
        result_planes.append(diff_img)
        result_norm_planes.append(norm_img)
    bg_imgs = cv2.merge(bg_imgs)
    bg_imgs = cv2.resize(bg_imgs, (w, h))
    # result = cv2.merge(result_planes)
    result_norm = cv2.merge(result_norm_planes)
    result_norm[result_norm == 0] = 1
    # The shadow map below is computed but not returned; the estimated
    # background is what the model consumes as its prompt.
    shadow_map = np.clip(
        img.astype(float) / result_norm.astype(float) * 255, 0, 255
    ).astype(np.uint8)
    shadow_map = cv2.resize(shadow_map, (w, h))
    shadow_map = cv2.cvtColor(shadow_map, cv2.COLOR_BGR2GRAY)
    shadow_map = cv2.cvtColor(shadow_map, cv2.COLOR_GRAY2BGR)
    # return shadow_map
    return bg_imgs


def deblur_prompt(img):
    # Sobel gradients in x and y, averaged into a 3-channel high-frequency map
    # that serves as the deblurring prompt.
    x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(x)  # convert back to uint8
    absY = cv2.convertScaleAbs(y)
    high_frequency = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_BGR2GRAY)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_GRAY2BGR)
    return high_frequency


def appearance_prompt(img):
    # Appearance prompt: subtract the estimated background from each channel
    # and min-max normalize the difference.
    h, w = img.shape[:2]
    # img = cv2.resize(img, (128, 128))
    img = cv2.resize(img, (1024, 1024))
    rgb_planes = cv2.split(img)
    result_planes = []
    result_norm_planes = []
    for plane in rgb_planes:
        dilated_img = cv2.dilate(plane, np.ones((7, 7), np.uint8))
        bg_img = cv2.medianBlur(dilated_img, 21)
        diff_img = 255 - cv2.absdiff(plane, bg_img)
        norm_img = cv2.normalize(
            diff_img,
            None,
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_8UC1,
        )
        result_planes.append(diff_img)
        result_norm_planes.append(norm_img)
    result_norm = cv2.merge(result_norm_planes)
    result_norm = cv2.resize(result_norm, (w, h))
    return result_norm


def binarization_promptv2(img):
    # Prompt channels: Sauvola threshold map, Sobel high-frequency map, and the
    # hard-thresholded Sauvola result.
    result, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    result[result > 155] = 255
    result[result <= 155] = 0
    x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(x)  # convert back to uint8
    absY = cv2.convertScaleAbs(y)
    high_frequency = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_BGR2GRAY)
    return np.concatenate(
        (
            np.expand_dims(thresh, -1),
            np.expand_dims(high_frequency, -1),
            np.expand_dims(result, -1),
        ),
        -1,
    )


def dewarping(model, im_org, device):
    INPUT_SIZE = 256
    im_masked, prompt_org = dewarp_prompt(im_org.copy())

    h, w = im_masked.shape[:2]
    im_masked = im_masked.copy()
    im_masked = cv2.resize(im_masked, (INPUT_SIZE, INPUT_SIZE))
    im_masked = im_masked / 255.0
    im_masked = torch.from_numpy(im_masked.transpose(2, 0, 1)).unsqueeze(0)
    im_masked = im_masked.float().to(device)

    prompt = torch.from_numpy(prompt_org.transpose(2, 0, 1)).unsqueeze(0)
    prompt = prompt.float().to(device)

    in_im = torch.cat((im_masked, prompt), dim=1)

    # inference
    base_coord = utils.getBasecoord(INPUT_SIZE, INPUT_SIZE) / INPUT_SIZE
    model = model.float()
    with torch.no_grad():
        pred = model(in_im)
        pred = pred[0][:2].permute(1, 2, 0).cpu().numpy()
        pred = pred + base_coord

    # smooth the predicted backward map before remapping
    for _ in range(15):
        pred = cv2.blur(pred, (3, 3), borderType=cv2.BORDER_REPLICATE)
    pred = cv2.resize(pred, (w, h)) * (w, h)
    pred = pred.astype(np.float32)
    out_im = cv2.remap(im_org, pred[:, :, 0], pred[:, :, 1], cv2.INTER_LINEAR)

    prompt_org = (prompt_org * 255).astype(np.uint8)
    prompt_org = cv2.resize(prompt_org, im_org.shape[:2][::-1])
    return prompt_org[:, :, 0], prompt_org[:, :, 1], prompt_org[:, :, 2], out_im


def appearance(model, im_org, device):
    MAX_SIZE = 1600
    # obtain the image and its prompt
    h, w = im_org.shape[:2]
    prompt = appearance_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    # constrain the maximum resolution
    if max(w, h) < MAX_SIZE:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))

    # normalize
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)

    # inference
    in_im = in_im.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

    if max(w, h) < MAX_SIZE:
        out_im = pred[padding_h:, padding_w:]
    else:
        # recover full resolution by dividing the original image by the shadow
        # map implied by the low-resolution prediction, rather than upscaling
        # the prediction directly
        pred[pred == 0] = 1
        shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(
            float
        ) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deshadowing(model, im_org, device):
    MAX_SIZE = 1600
    # obtain the image and its prompt
    h, w = im_org.shape[:2]
    prompt = deshadow_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    # constrain the maximum resolution
    if max(w, h) < MAX_SIZE:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))

    # normalize
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)

    # inference
    in_im = in_im.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

    if max(w, h) < MAX_SIZE:
        out_im = pred[padding_h:, padding_w:]
    else:
        # same full-resolution recovery as in appearance(): divide the original
        # image by the shadow map implied by the low-resolution prediction
        pred[pred == 0] = 1
        shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(
            float
        ) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deblurring(model, im_org, device):
    # set up the input image
    in_im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(in_im)
    in_im = np.concatenate((in_im, prompt), -1)
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.half().to(device)

    # inference
    model.to(device)
    model.eval()
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)
    out_im = pred[padding_h:, padding_w:]
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def binarization(model, im_org, device):
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]
    in_im = np.concatenate((im, prompt), -1)
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.to(device)
    model = model.half()
    in_im = in_im.half()
    with torch.no_grad():
        pred = model(in_im)
        # two-class segmentation: argmax over the first two output channels
        pred = pred[:, :2, :, :]
        pred = torch.max(torch.softmax(pred, 1), 1)[1]
        pred = pred[0].cpu().numpy()
        pred = (pred * 255).astype(np.uint8)
        pred = cv2.resize(pred, (w, h))
    out_im = pred[padding_h:, padding_w:]
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def model_init(model_path, device):
    # prepare model
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        dual_pixel_task=True,
    )

    if device == "cpu":
        state = convert_state_dict(
            torch.load(model_path, map_location="cpu")["model_state"]
        )
    else:
        state = convert_state_dict(
            torch.load(model_path, map_location="cuda:0")["model_state"]
        )
    model.load_state_dict(state)

    model.eval()
    model = model.to(device)
    return model


def resize(image, max_size):
    # Downscale so the longer side is at most max_size, preserving aspect ratio.
    h, w = image.shape[:2]
    if max(h, w) > max_size:
        if h > w:
            h_new = max_size
            w_new = int(w * h_new / h)
        else:
            w_new = max_size
            h_new = int(h * w_new / w)
        pil_image = Image.fromarray(image)
        pil_image = pil_image.resize((w_new, h_new), Image.Resampling.LANCZOS)
        image = np.array(pil_image)
    return image


def inference_one_image(model, image, tasks, device):
    # the image is expected in BGR order (as loaded by cv2.imread)
    if "dewarping" in tasks:
        *_, image = dewarping(model, image, device)
    # if dewarping is the only task, return here
    if len(tasks) == 1 and "dewarping" in tasks:
        return image
    image = resize(image, 1536)
    if "deshadowing" in tasks:
        *_, image = deshadowing(model, image, device)
    if "appearance" in tasks:
        *_, image = appearance(model, image, device)
    if "deblurring" in tasks:
        *_, image = deblurring(model, image, device)
    if "binarization" in tasks:
        *_, image = binarization(model, image, device)
    return image
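

# Minimal usage sketch. The checkpoint path "checkpoints/docres.pkl" and the
# input/output filenames are hypothetical placeholders; substitute the paths
# from your own deployment. Note the half-precision branches above generally
# assume a GPU.
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = model_init("checkpoints/docres.pkl", device)
    bgr_image = cv2.imread("input.jpg")  # cv2 loads in BGR, as expected above
    restored = inference_one_image(
        model, bgr_image, tasks=["dewarping", "appearance"], device=device
    )
    cv2.imwrite("restored.png", restored)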