import os
import imageio
import numpy as np
import glob
import sys
from typing import Any

sys.path.insert(1, '.')

import argparse
from pytorch_lightning import seed_everything
from PIL import Image
import torch
from operators import GaussialBlurOperator
from utils import get_rank
from torchvision.ops import masks_to_boxes
from matfusion import MateralDiffusion
from loguru import logger

__MAX_BATCH__ = 4  # 4 for A10
def init_model(ckpt_path, ddim, gpu_id):
    # find config
    configs = os.listdir(f'{ckpt_path}/configs')
    model_config = [config for config in configs if "project.yaml" in config][0]
    sds_loss_class = MateralDiffusion(device=gpu_id, fp16=True,
                                      config=f'{ckpt_path}/configs/{model_config}',
                                      ckpt=f'{ckpt_path}/checkpoints/last.ckpt', vram_O=False,
                                      t_range=[0.001, 0.02], opt=None, use_ddim=ddim)
    return sds_loss_class
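
# Illustrative usage note (an assumption about the expected layout, not documented
# in this file): init_model appears to expect a Lightning-style run directory such as
#   <ckpt_path>/configs/<name>-project.yaml
#   <ckpt_path>/checkpoints/last.ckpt
# so a call like the following would build the diffusion wrapper on GPU 0 with
# DDIM sampling enabled:
#   model = init_model("./weights/matfusion_run", ddim=True, gpu_id=0)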
def images_spliter(image, seg_h, seg_w, padding_pixel, padding_val, overlaps=1):
    # split the input image along height and width into overlapping padded tiles;
    # return the stacked tiles, their top-left positions (relative to the unpadded
    # crop), and the recomputed tile counts along each axis
    h, w, c = image.shape
    h = h - (h % (seg_h * overlaps))
    w = w - (w % (seg_w * overlaps))
    h_crop = h // seg_h
    w_crop = w // seg_w
    images = []
    positions = []
    img_padded = torch.zeros(h + padding_pixel * 2, w + padding_pixel * 2, 3, device=image.device) + padding_val
    img_padded[padding_pixel:h + padding_pixel, padding_pixel:w + padding_pixel, :] = image[:h, :w]
    # overlapped sampling
    seg_h = np.round((h - h_crop) / h_crop * overlaps).astype(int) + 1
    seg_w = np.round((w - w_crop) / w_crop * overlaps).astype(int) + 1
    h_step = np.round(h_crop / overlaps).astype(int)
    w_step = np.round(w_crop / overlaps).astype(int)
    # print(f"h_step: {h_step}, seg_h: {seg_h}, w_step: {w_step}, seg_w: {seg_w}, img_padded: {img_padded.shape}, image[:h, :w]: {image[:h, :w].shape}")
    for ind_i in range(0, seg_h):
        i = ind_i * h_step
        for ind_j in range(0, seg_w):
            j = ind_j * w_step
            img_ = img_padded[i:i + h_crop + padding_pixel * 2, j:j + w_crop + padding_pixel * 2, :]
            images.append(img_)
            positions.append(torch.FloatTensor([i - padding_pixel, j - padding_pixel]).reshape(2))
    return torch.stack(images, dim=0), torch.stack(positions, dim=0), seg_h, seg_w
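
# Worked sketch of the tiling (illustrative only, not used by the pipeline): for a
# 512x768 RGB image, a 2x3 grid, 8 px padding and overlap factor 2, the recomputed
# grid becomes 3x5, so 15 tiles of 272x272 are returned together with their
# top-left positions in the unpadded frame:
#   tiles, positions, seg_h, seg_w = images_spliter(
#       torch.rand(512, 768, 3), 2, 3, padding_pixel=8, padding_val=1.0, overlaps=2)
#   # tiles.shape == (15, 272, 272, 3); positions.shape == (15, 2); (seg_h, seg_w) == (3, 5)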
class InferenceModel():
    def __init__(self, ckpt_path, use_ddim, gpu_id=0):
        self.model = init_model(ckpt_path, use_ddim, gpu_id=gpu_id)
        self.gpu_id = gpu_id
        self.split_hw = [1, 1]
        self.padding = 0
        self.padding_crop = 0
        self.results_list = None
        self.results_output_list = []
        self.image_sizes_list = []

    def parse_item(self, img_ori, mask_img_ori, guid_images):
        # if mask_img_ori is None:
        #     mask_img_ori = read_img(input_name, read_alpha=True)
        #     # ensure background is white, same as training data
        #     img_ori[~(mask_img_ori[..., 0] > 0.5)] = 1
        img_ori[~(mask_img_ori[..., 0] > 0.5)] = 1
        use_true_mask = (self.split_hw[0] * self.split_hw[1]) <= 1
        self.ori_hw = list(img_ori.shape)
        # mask cropping
        min_max_uv = masks_to_boxes(mask_img_ori[None, ..., -1] > 0.5).long()
        self.min_uv, self.max_uv = min_max_uv[0, ..., [1, 0]], min_max_uv[0, ..., [3, 2]] + 1
        # print(self.min_uv, self.max_uv)
        mask_img = mask_img_ori[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
        img = img_ori[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
        image_size = list(img.shape)
        if not use_true_mask:
            # for cropping the border
            self.max_uv[0] = self.max_uv[0] - ((self.max_uv[0] - self.min_uv[0]) % (self.split_hw[0] * self.split_overlap))
            self.max_uv[1] = self.max_uv[1] - ((self.max_uv[1] - self.min_uv[1]) % (self.split_hw[1] * self.split_overlap))
            mask_img = mask_img_ori[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
            img = img_ori[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
            image_size = list(img.shape)
        if not use_true_mask:
            mask_img = torch.ones_like(mask_img)
        mask_img, _ = images_spliter(mask_img[..., [0, 0, 0]], self.split_hw[0], self.split_hw[1], self.padding, not use_true_mask, self.split_overlap)[:2]
        img, position_indexes, seg_h, seg_w = images_spliter(img, self.split_hw[0], self.split_hw[1], self.padding, 1, self.split_overlap)
        self.split_hw_overlapped = [seg_h, seg_w]
        logger.info(f"Splitting size: {image_size}, splits: {self.split_hw}, overlapped: {self.split_hw_overlapped}")
        if guid_images is None:
            guid_images = torch.zeros_like(img)
        else:
            guid_images = guid_images[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
            guid_images, _ = images_spliter(guid_images, self.split_hw[0], self.split_hw[1], self.padding, 1, self.split_overlap)[:2]
        return guid_images, img, mask_img[..., :1], image_size, position_indexes
    def prepare_batch(self, guid_img, img_ori, mask_img_ori, batch_size):
        input_img = []
        cond_img = []
        mask_img = []
        image_size = []
        position_indexes = []
        for i in range(batch_size):
            _input_img, _cond_img, _mask_img, _image_size, _position_indexes = \
                self.parse_item(img_ori, mask_img_ori, guid_img)
            input_img.append(_input_img)
            cond_img.append(_cond_img)
            mask_img.append(_mask_img)
            position_indexes.append(_position_indexes)
            image_size += [_image_size] * _input_img.shape[0]
        input_img = torch.cat(input_img, dim=0).to(self.gpu_id)
        cond_img = torch.cat(cond_img, dim=0).to(self.gpu_id)
        mask_img = torch.cat(mask_img, dim=0).to(self.gpu_id)
        position_indexes = torch.cat(position_indexes, dim=0).to(self.gpu_id)
        return input_img, cond_img, mask_img, image_size, position_indexes
    def assemble_results(self, img_out, img_hw=None, position_index=None, default_val=1):
        results_img = np.zeros((img_hw[0], img_hw[1], 3))
        weight_img = np.zeros((img_hw[0], img_hw[1], 3)) + 1e-5
        for i in range(position_index.shape[0]):
            # crop out the border
            crop_h, crop_w = img_out[i].shape[:2]
            pathed_img = img_out[i][self.padding_crop:crop_h - self.padding_crop, self.padding_crop:crop_w - self.padding_crop]
            position_index[i] += self.padding_crop
            crop_h, crop_w = pathed_img.shape[:2]
            crop_x, crop_y = max(position_index[i][0], 0), max(position_index[i][1], 0)
            shape_max = results_img[crop_x:crop_x + crop_h, crop_y:crop_y + crop_w].shape[:2]
            start_crop_x, start_crop_y = abs(min(position_index[i][0], 0)), abs(min(position_index[i][1], 0))
            # print(pathed_img[start_crop_x:shape_max[0], start_crop_y:shape_max[1]].shape, crop_x, crop_y, position_index[i])
            results_img[crop_x:crop_x + shape_max[0] - start_crop_x, crop_y:crop_y + shape_max[1] - start_crop_y] += pathed_img[start_crop_x:shape_max[0], start_crop_y:shape_max[1]]
            # accumulate coverage over the same region as results_img above
            weight_img[crop_x:crop_x + shape_max[0] - start_crop_x, crop_y:crop_y + shape_max[1] - start_crop_y] += 1
        img_out = results_img / weight_img
        img_out[weight_img[:, :, 0] < 1] = 255
        # print(img_out.shape, weight_img.shape, np.unique(weight_img), pathed_img.dtype)
        img_out_ = (np.zeros((self.ori_hw[0], self.ori_hw[1], 3)) + default_val) * 255
        img_out_[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]] = img_out
        img_out = img_out_
        return img_out
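
    # Note on the blending above (author intent inferred from the code, not documented):
    # overlapping tiles are summed into results_img while weight_img counts how many
    # tiles covered each pixel, so the division averages the overlaps. For example, a
    # pixel covered by two tiles with values 100 and 120 ends up as (100 + 120) / 2 = 110;
    # pixels covered by no tile keep a near-zero weight and are forced to white (255)
    # before the crop is pasted back into the full-size canvas.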
    def write_batch_img(self, imgs, image_sizes, position_indexes):
        cropped_batch = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        if self.results_list is None or self.results_list.shape[0] == 0:
            self.results_list = imgs
            self.position_indexes = position_indexes
        else:
            self.results_list = torch.cat([self.results_list, imgs], dim=0)
            self.position_indexes = torch.cat([self.position_indexes, position_indexes], dim=0)
        self.image_sizes_list += image_sizes
        valid_len = self.results_list.shape[0] - (self.results_list.shape[0] % cropped_batch)
        out_images = []
        for ind in range(0, valid_len, cropped_batch):
            # assemble results
            img_out = (self.results_list[ind:ind + cropped_batch].detach().cpu().numpy() * 255).astype(np.uint8)
            img_out = self.assemble_results(img_out, self.image_sizes_list[ind], self.position_indexes[ind:ind + cropped_batch].detach().cpu().numpy().astype(int))
            # Image.fromarray(img_out.astype(np.uint8)).save(self.results_output_list[ind])
            out_images.append(img_out.astype(np.uint8))
        self.results_list = self.results_list[valid_len:]
        self.position_indexes = self.position_indexes[valid_len:]
        self.image_sizes_list = self.image_sizes_list[valid_len:]
        return out_images
    def write_batch_input(self, imgs, image_sizes, position_indexes, default_val=1):
        cropped_batch = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        images = []
        valid_len = imgs.shape[0]
        for ind in range(0, valid_len, cropped_batch):
            # assemble results
            img_out = (imgs[ind:ind + cropped_batch].detach().cpu().numpy() * 255).astype(np.uint8)
            img_out = self.assemble_results(img_out, image_sizes[ind], position_indexes.detach().cpu().numpy().astype(int), default_val).astype(np.uint8)
            images.append(img_out)
        return images
    def generation(self, split_hw, split_overlap, guid_img, img_ori, mask_img_ori, dps_scale, uc_score, ddim_steps, batch_size=32, n_samples=1):
        max_batch = __MAX_BATCH__
        operator = GaussialBlurOperator(61, 3.0, self.gpu_id)
        assert batch_size == 1
        self.split_resolution = None
        self.split_overlap = split_overlap
        self.split_hw = split_hw
        # get img hw
        for src_img_id in range(0, 1, batch_size):
            input_img, cond_img, mask_img, image_sizes, position_indexes = self.prepare_batch(guid_img, img_ori, mask_img_ori, 1)
            input_masked = self.write_batch_input(cond_img, image_sizes, position_indexes)
            input_maskes = self.write_batch_input(mask_img, image_sizes, position_indexes, 0)
            results_all = []
            for _ in range(n_samples):
                for batch_id in range(0, input_img.shape[0], max_batch):
                    embeddings = {}
                    embeddings["cond_img"] = cond_img[batch_id:batch_id + max_batch]
                    if (mask_img[batch_id:batch_id + max_batch] > 0.5).sum() == 0:
                        results = torch.ones_like(cond_img[batch_id:batch_id + max_batch])
                    else:
                        results = self.model(embeddings, input_img[batch_id:batch_id + max_batch], mask_img[batch_id:batch_id + max_batch], ddim_steps=ddim_steps,
                                             guidance_scale=uc_score, dps_scale=dps_scale, as_latent=False, grad_scale=1, operator=operator)
                    out_images = self.write_batch_img(results, image_sizes[batch_id:batch_id + max_batch], position_indexes[batch_id:batch_id + max_batch])
                    results_all += out_images
        ret = {
            "input_image": input_masked,
            "input_maskes": input_maskes,
            "out_images": results_all,
        }
        return ret
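

# Hypothetical command-line driver (not part of the original script): a minimal
# sketch of how InferenceModel.generation might be invoked. The flag names, the
# default dps_scale/uc_score values, the [0, 1] image range, the RGBA mask layout,
# and the availability of a CUDA device are all assumptions for illustration.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--input", type=str, required=True, help="RGBA input image (alpha = object mask)")
    parser.add_argument("--output", type=str, default="material.png")
    parser.add_argument("--ddim_steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    seed_everything(args.seed)

    # load image and mask as HxWx3 / HxWx1 float tensors in [0, 1]
    rgba = np.asarray(Image.open(args.input).convert("RGBA"), dtype=np.float32) / 255.0
    rgba_t = torch.from_numpy(rgba)
    img_ori = rgba_t[..., :3]
    mask_ori = rgba_t[..., 3:]

    runner = InferenceModel(args.ckpt_path, use_ddim=True, gpu_id=0)
    outputs = runner.generation(split_hw=[1, 1], split_overlap=1, guid_img=None,
                                img_ori=img_ori, mask_img_ori=mask_ori,
                                dps_scale=0.0, uc_score=3.0,
                                ddim_steps=args.ddim_steps, batch_size=1, n_samples=1)
    # out_images is a list of uint8 HxWx3 arrays; save the first sample
    Image.fromarray(outputs["out_images"][0]).save(args.output)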