# NOTE: removed pasted editor-status residue ("Spaces:", "Paused", "Paused") that was not part of the source.
# -*- coding: utf-8 -*-
"""
# File name: hair_editor.py
# Time: 2021/11/18 17:21
# Author: [email protected]
# Description: basic hair-editing module built around SEAN; see HairEditor below.
"""
import os | |
import pickle | |
from glob import glob | |
import cv2 | |
import numpy as np | |
import torch | |
import my_torchlib | |
from color_texture_branch.solver import Solver as SolveFeature | |
from external_code.face_parsing.my_parsing_util import FaceParsing | |
from global_value_utils import HAIR_IDX, PARSING_LABEL_LIST | |
from poisson_blending import poisson_blending | |
from sean_codes.models.pix2pix_model import Pix2PixModel | |
from sean_codes.options.test_options import TestOptions | |
from shape_branch.solver import Solver as SolverMask | |
from util.imutil import write_rgb | |
# adaptor_root_dir = '/data1/guoxuyang/myWorkSpace/hair_editing' | |
# sys.path.append(adaptor_root_dir) | |
# sys.path.append(os.path.join(adaptor_root_dir, 'external_code/face_3DDFA')) | |
def change_status(model, new_status):
    """Set the ``status`` attribute on every sub-module of ``model`` that defines one.

    Modules without a ``status`` attribute are left untouched.
    """
    for module in model.modules():
        if hasattr(module, 'status'):
            setattr(module, 'status', new_status)
class HairEditor:
    """
    Basic editing module that implements the core hair-editing tasks.
    ``Backend`` in ui/hair_swap.py inherits from this class.
    """
    def __init__(self, load_feature_model=True, load_mask_model=True):
        """
        Build the SEAN generator and optionally the two editing branches.

        :param load_feature_model: when True, load the color/texture branch
            (encoder, generator, RGB predictor and learned texture directions)
        :param load_mask_model: when True, load the shape branch
            (mask generator and learned shape directions)
        """
        self.opt = TestOptions().parse()
        self.opt.status = 'test'
        self.sean_model = Pix2PixModel(self.opt)
        self.sean_model.eval()
        self.img_size = 256
        # NOTE(review): hard-coded to GPU 0 — assumes a CUDA device is available.
        self.device = torch.device('cuda', 0)
        if load_feature_model:
            from color_texture_branch.config import cfg as cfg_feature
            self.solver_feature = SolveFeature(cfg_feature, device=self.device, local_rank=-1, training=False)
            self.feature_encoder = self.solver_feature.dis
            self.feature_generator = self.solver_feature.gen
            self.feature_rgb_predictor = self.solver_feature.rgb_model
            # self.feature_curliness_predictor = self.solver_feature.curliness_model
            # ckpt_dir = 'external_model_params/disentangle_checkpoints/' + cfg_app.experiment_name + '/checkpoints'
            ckpt_dir = 'model_trained/color_texture/' + cfg_feature.experiment_name + '/checkpoints'
            ckpt = my_torchlib.load_checkpoint(ckpt_dir)
            # Checkpoints saved from DataParallel/DDP prefix every key with
            # 'module.'; strip the first 7 characters so plain modules load them.
            for model_name in ['Model_G', 'Model_D']:
                cur_model = ckpt[model_name]
                if list(cur_model)[0].startswith('module'):
                    ckpt[model_name] = {kk[7:]: cur_model[kk] for kk in cur_model}
            self.feature_generator.load_state_dict(ckpt['Model_G'], strict=True)
            self.feature_encoder.load_state_dict(ckpt['Model_D'], strict=True)
            # if 'curliness' in cfg_feature.predictor:
            #     ckpt = my_torchlib.load_checkpoint(cfg_feature.predictor.curliness.root_dir + '/checkpoints')
            #     self.feature_curliness_predictor.load_state_dict(ckpt['Predictor'], strict=True)
            if 'rgb' in cfg_feature.predictor:
                ckpt = my_torchlib.load_checkpoint(cfg_feature.predictor.rgb.root_dir + '/checkpoints')
                self.feature_rgb_predictor.load_state_dict(ckpt['Predictor'], strict=True)
            # Load unsupervised texture editing directions: one pickled tensor
            # per file, sorted by file name so the ordering is stable.
            existing_dirs_dir = os.path.join('model_trained/color_texture', cfg_feature.experiment_name,
                                             'texture_dir_used')
            if os.path.exists(existing_dirs_dir):
                existing_dirs_list = os.listdir(existing_dirs_dir)
                existing_dirs_list.sort()
                existing_dirs = []
                for dd in existing_dirs_list:
                    with open(os.path.join(existing_dirs_dir, dd), 'rb') as f:
                        existing_dirs.append(pickle.load(f).to(self.device))
                # NOTE(review): this attribute only exists when the directory is
                # present on disk — callers must tolerate its absence.
                self.texture_dirs = existing_dirs
        if load_mask_model:
            from shape_branch.config import cfg as cfg_mask
            self.solver_mask = SolverMask(cfg_mask, device=self.device, local_rank=-1, training=False)
            self.mask_generator = self.solver_mask.gen
            ##############################################
            #       change to your checkpoints dir       #
            ##############################################
            ckpt_dir = 'model_trained/shape/' + cfg_mask.experiment_name + '/checkpoints'
            ckpt = my_torchlib.load_checkpoint(ckpt_dir)
            # Same 'module.' prefix stripping as the feature branch above
            # (only Model_G is loaded here).
            for model_name in ['Model_G', 'Model_D']:
                cur_model = ckpt[model_name]
                if list(cur_model)[0].startswith('module'):
                    ckpt[model_name] = {kk[7:]: cur_model[kk] for kk in cur_model}
            self.mask_generator.load_state_dict(ckpt['Model_G'], strict=True)
            # Load unsupervised shape editing directions (same layout as above).
            existing_dirs_dir = os.path.join('model_trained/shape', cfg_mask.experiment_name, 'shape_dir_used')
            if os.path.exists(existing_dirs_dir):
                existing_dirs_list = os.listdir(existing_dirs_dir)
                existing_dirs_list.sort()
                existing_dirs = []
                for dd in existing_dirs_list:
                    with open(os.path.join(existing_dirs_dir, dd), 'rb') as f:
                        existing_dirs.append(pickle.load(f).to(self.device))
                # NOTE(review): conditional attribute, same caveat as texture_dirs.
                self.shape_dirs = existing_dirs
def preprocess_img(self, img): | |
img = cv2.resize(img.astype('uint8'), (self.img_size, self.img_size)) | |
return (np.transpose(img, [2, 0, 1]) / 127.5 - 1.0)[None, ...] | |
def preprocess_mask(self, mask_img): | |
mask_img = cv2.resize(mask_img.astype('uint8'), (self.img_size, self.img_size), | |
interpolation=cv2.INTER_NEAREST) | |
return mask_img[None, None, :, :] | |
def load_average_feature(): | |
############### load average features | |
# average_style_code_folder = 'styles_test/mean_style_code/mean/' | |
average_style_code_folder = 'sean_codes/styles_test/mean_style_code/median/' | |
input_style_dic = {} | |
############### hard coding for categories | |
for i in range(19): | |
input_style_dic[str(i)] = {} | |
average_category_folder_list = glob(os.path.join(average_style_code_folder, str(i), '*.npy')) | |
average_category_list = [os.path.splitext(os.path.basename(name))[0] for name in | |
average_category_folder_list] | |
for style_code_path in average_category_list: | |
input_style_dic[str(i)][style_code_path] = torch.from_numpy( | |
np.load(os.path.join(average_style_code_folder, str(i), style_code_path + '.npy'))).cuda() | |
return input_style_dic | |
def get_code(self, hair_img, hair_parsing): | |
# generate style code | |
data = {'label': torch.tensor(hair_parsing, dtype=torch.float32), | |
'instance': torch.tensor(0), | |
'image': torch.tensor(hair_img, dtype=torch.float32), | |
'path': ['temp/temp_npy']} | |
change_status(self.sean_model, 'test') | |
hair_img_code = self.sean_model(data, mode='style_code') | |
return hair_img_code | |
def gen_img(self, code, parsing): | |
# load style code | |
if not isinstance(code, torch.Tensor): | |
code = torch.tensor(code) | |
obj_dic = self.load_average_feature() | |
for idx in range(19): | |
cur_code = code[0, idx] | |
if not torch.all(cur_code == 0): | |
obj_dic[str(idx)]['ACE'] = cur_code | |
temp_face_image = torch.zeros((0, 3, self.img_size, self.img_size)) # place holder | |
data = {'label': torch.tensor(parsing, dtype=torch.float32), | |
'instance': torch.tensor(0), | |
'image': torch.tensor(temp_face_image, dtype=torch.float32), | |
'obj_dic': obj_dic} | |
change_status(self.sean_model, 'UI_mode') | |
# self.model = self.model.to(code.device) | |
generated = self.sean_model(data, mode='UI_mode')[0] | |
return generated | |
def generate_by_sean(self, face_img_code, hair_code, target_seg): | |
""" | |
:param face_img_code: please input with the shape [19, 512] | |
:param hair_code: please input with the shape [512] | |
:param target_seg: | |
:return: | |
""" | |
# load style code | |
obj_dic = self.load_average_feature() | |
for idx in range(19): | |
if idx == HAIR_IDX: | |
cur_code = hair_code | |
# cur_code = face_img_code[0, idx] | |
else: | |
cur_code = face_img_code[idx] | |
if not torch.all(face_img_code == 0): | |
obj_dic[str(idx)]['ACE'] = cur_code | |
data = {'label': torch.tensor(target_seg, dtype=torch.float32), | |
'instance': torch.tensor(0), | |
'obj_dic': obj_dic, | |
'image': None} | |
change_status(self.sean_model, 'UI_mode') | |
generated = self.sean_model(data, mode='UI_mode')[0] | |
return generated | |
def generate_instance_transfer_img(self, face_img, face_parsing, hair_img, hair_parsing, target_seg, edit_data=None, | |
temp_path='temp'): | |
# generate style code | |
data = {'label': torch.tensor(face_parsing, dtype=torch.float32), | |
'instance': torch.tensor(0), | |
'image': torch.tensor(face_img, dtype=torch.float32), | |
'path': ['temp/temp_npy']} | |
face_img_code = self.sean_model(data, mode='style_code') | |
if hair_img is None: | |
hair_img_code = face_img_code | |
else: | |
data = {'label': torch.tensor(hair_parsing, dtype=torch.float32), | |
'instance': torch.tensor(0), | |
'image': torch.tensor(hair_img, dtype=torch.float32), | |
'path': ['temp/temp_npy']} | |
change_status(self.sean_model, 'test') | |
hair_img_code = self.sean_model(data, mode='style_code') | |
hair_code = hair_img_code[0, HAIR_IDX] | |
if edit_data is not None: | |
hair_code = self.solver_feature.edit_infer(hair_code[None, ...], edit_data)[0] | |
return self.generate_by_sean(face_img_code[0], hair_code, target_seg) | |
def get_hair_color(self, img): | |
parsing, _ = FaceParsing.parsing_img(img) | |
parsing = FaceParsing.swap_parsing_label_to_celeba_mask(parsing) | |
parsing = cv2.resize(parsing.astype('uint8'), (1024, 1024), interpolation=cv2.INTER_NEAREST) | |
img = cv2.resize(img.astype('uint8'), (1024, 1024)) | |
hair_mask = (parsing == HAIR_IDX).astype('uint8') | |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(19, 19)) | |
hair_mask = cv2.erode(hair_mask, kernel, iterations=1) | |
points = img[hair_mask.astype('bool')] | |
moment1 = points.mean(axis=0) | |
return moment1 | |
def draw_landmarks(img, lms): | |
lms = lms / 2 | |
lms = lms.astype('int') | |
for idx, point in enumerate(lms): | |
font = cv2.FONT_HERSHEY_SIMPLEX | |
pos = (point[0], point[1]) | |
cv2.circle(img, pos, 2, color=(139, 0, 0)) | |
cv2.putText(img, str(idx + 1), pos, font, 0.18, (255, 0, 0), 1, cv2.LINE_AA) | |
return img | |
    def postprocess_blending(self, face_img, res_img, face_parsing, target_parsing, verbose_print=False, blending=True):
        """
        Blend original face img and result image with poisson blending.
        If not blend, the result image will look slightly different from original image in some details in
        non-hair region, but the image quality will be better.
        :param face_img: original face image (tensor or array; normalized to uint8 HWC here)
        :param res_img: generated result image (tensor or array; normalized to uint8 HWC here)
        :param face_parsing: parsing map of the original face
        :param target_parsing: parsing map of the generated result
        :param verbose_print: print a progress message when True
        :param blending: If `False`, the result image will do some trivial thing like transferring data type
        :return: (blended image, dilated hair mask) when blending, otherwise (result image, None)
        """
        if verbose_print:
            print("Post process for the result image...")

        def from_tensor_order_to_cv2(tensor_img, is_mask=False):
            # Normalize any of the accepted layouts (torch tensor or numpy,
            # NCHW / CHW / HW) into an HWC numpy image; non-mask images are
            # also mapped from [-1, 1] back to [0, 255].
            if isinstance(tensor_img, torch.Tensor):
                tensor_img = tensor_img.detach().cpu().numpy()
            if len(tensor_img.shape) == 4:
                tensor_img = tensor_img[0]  # drop the batch dimension
            if len(tensor_img.shape) == 2:
                tensor_img = tensor_img[None, ...]  # HW -> 1HW so transpose yields HW1
            if tensor_img.shape[2] <= 3:
                return tensor_img  # last dim already channels: treated as HWC, nothing to do
            res = np.transpose(tensor_img, [1, 2, 0])  # CHW -> HWC
            if not is_mask:
                res = res * 127.5 + 127.5  # [-1, 1] -> [0, 255]
            return res

        res_img = from_tensor_order_to_cv2(res_img)
        res_img = res_img.astype('uint8')
        if blending:
            target_parsing = from_tensor_order_to_cv2(target_parsing, is_mask=True)
            face_img = from_tensor_order_to_cv2(face_img)
            face_img = face_img.astype('uint8')
            face_parsing = from_tensor_order_to_cv2(face_parsing, is_mask=True)
            # Union of the hair regions before and after editing: pixels inside
            # it come from the generated image, the rest from the original face.
            res_mask = np.logical_or(target_parsing == HAIR_IDX, face_parsing == HAIR_IDX).astype('uint8')
            kernel13 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(13, 13))
            kernel5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(5, 5))
            res_mask_dilated = cv2.dilate(res_mask, kernel13, iterations=1)[..., None]
            res_mask_dilated5 = cv2.dilate(res_mask, kernel5, iterations=1)[..., None]
            # Where the target parsing is background use the small (5px)
            # dilation, elsewhere the large (13px) one, keeping the hair
            # boundary tight against the background.
            bg_mask = (target_parsing == PARSING_LABEL_LIST.index('background'))
            res_mask_dilated = res_mask_dilated * (1 - bg_mask) + res_mask_dilated5 * bg_mask
            # Poisson-blend the generated hair region into the original face.
            face_to_hair = poisson_blending(face_img, res_img, 1 - res_mask_dilated, with_gamma=True)
            return face_to_hair, res_mask_dilated
        else:
            return res_img, None
def crop_face(self, img_rgb, save_path=None): | |
""" | |
crop the face part in the image to adapt the editing system | |
:param img_rgb: | |
:param save_path: | |
:return: | |
""" | |
from external_code.crop import recreate_aligned_images | |
from external_code.landmarks_util import predictor_dict, detector | |
predictor_68 = predictor_dict[68] | |
bbox = detector(img_rgb, 0)[0] | |
lm_68 = np.array([[p.x, p.y] for p in predictor_68(img_rgb, bbox).parts()]) | |
crop_img_pil, lm_68 = recreate_aligned_images(img_rgb, lm_68, output_size=self.img_size) | |
img_rgb = np.array(crop_img_pil) | |
if save_path is not None: | |
write_rgb(save_path, img_rgb) | |
return img_rgb | |
def get_mask(self, img_rgb): | |
parsing, _ = FaceParsing.parsing_img(img_rgb) | |
parsing = FaceParsing.swap_parsing_label_to_celeba_mask(parsing) | |
mask_img = cv2.resize(parsing.astype('uint8'), (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST) | |
return mask_img | |