File size: 3,946 Bytes
64dbbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import cv2
import tensorflow as tf
import numpy as np
from source.facelib.facer import FaceAna
import source.utils as utils
from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points

# Under TensorFlow 2.x, fall back to the TF1 API (tf.compat.v1) and graph
# mode so the frozen-graph Session code below keeps working.
# NOTE: compare the parsed major version, not the raw string — a plain
# string comparison would classify e.g. '10.0' as < '2.0' (lexicographic).
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_eager_execution()


class Cartoonizer():
    """Photo-to-cartoon converter built on two frozen TensorFlow graphs.

    A background model stylizes the whole image; a head model stylizes an
    aligned crop of each detected face. Each stylized head is warped back
    into the stylized background and blended with a feathered alpha mask.
    """

    def __init__(self, dataroot):
        """Load the face analyzer, both frozen models, and the blend mask.

        Args:
            dataroot: Directory containing 'cartoon_anime_h.pb' (head
                model), 'cartoon_anime_bg.pb' (background model),
                'alpha.jpg' (blend mask), and the files FaceAna needs.
        """
        self.facer = FaceAna(dataroot)
        self.sess_head = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_h.pb'), 'model_head')
        self.sess_bg = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_bg.pb'), 'model_bg')

        # Side length (px) of the square face crop fed to the head model.
        self.box_width = 288
        global_mask = cv2.imread(os.path.join(dataroot, 'alpha.jpg'))
        global_mask = cv2.resize(
            global_mask, (self.box_width, self.box_width),
            interpolation=cv2.INTER_AREA)
        # Single-channel float32 mask in [0, 1]; feathers the stylized
        # head into the stylized background in cartoonize().
        self.global_mask = cv2.cvtColor(
            global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    def load_sess(self, model_path, name):
        """Create a tf.Session and import a frozen graph into it.

        Args:
            model_path: Path to a serialized GraphDef (.pb) file.
            name: Name scope the graph's ops are imported under; tensor
                names are later looked up as '<name>/<tensor>:0'.

        Returns:
            A tf.Session whose graph contains the imported model.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        # Grow GPU memory on demand instead of grabbing it all up front.
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        print(f'loading model from {model_path}')
        with tf.gfile.FastGFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name=name)
            sess.run(tf.global_variables_initializer())
        print(f'load model {model_path} done.')
        return sess

    def detect_face(self, img):
        """Detect faces and return their landmarks.

        Args:
            img: BGR image array (converted to RGB for the detector).

        Returns:
            The landmark array from FaceAna (one entry per face), or
            None when no face is found.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        boxes, landmarks, _ = self.facer.run(rgb)
        if boxes.shape[0] == 0:
            return None
        return landmarks

    def cartoonize(self, img):
        """Cartoonize an image: stylize background, then blend in each face.

        Args:
            img: RGB input image array of shape (H, W, 3).

        Returns:
            Stylized image resized back to the input's original (H, W).
            Note the result is float-valued once at least one face is
            blended (the alpha blend promotes the dtype).
        """
        ori_h, ori_w, _ = img.shape
        # Work at a bounded resolution; restore original size at the end.
        img = utils.resize_size(img, size=720)
        work_w, work_h = img.shape[1], img.shape[0]

        img_bgr = img[:, :, ::-1]

        # Background pass: model needs dimensions divisible by 16, so pad,
        # run, then crop the padding back off.
        pad_bg, pad_h, pad_w = utils.padTo16x(img_bgr)
        bg_res = self.sess_bg.run(
            self.sess_bg.graph.get_tensor_by_name(
                'model_bg/output_image:0'),
            feed_dict={'model_bg/input_image:0': pad_bg})
        res = bg_res[:pad_h, :pad_w, :]

        landmarks = self.detect_face(img_bgr)
        if landmarks is None:
            print('No face detected!')
            return res

        print('%d faces detected!' % len(landmarks))
        for landmark in landmarks:
            # Five facial keypoints (eyes, nose, mouth corners).
            f5p = utils.get_f5p(landmark, img_bgr)

            # Align the face to a canonical square crop; keep the inverse
            # transform so the result can be warped back into place.
            head_img, trans_inv = warp_and_crop_face(
                img,
                f5p,
                ratio=0.75,
                reference_pts=get_reference_facial_points(default_square=True),
                crop_size=(self.box_width, self.box_width),
                return_trans_inv=True)

            # Head pass (model expects the crop with channels reversed).
            head_res = self.sess_head.run(
                self.sess_head.graph.get_tensor_by_name(
                    'model_head/output_image:0'),
                feed_dict={
                    'model_head/input_image:0': head_img[:, :, ::-1]
                })

            # Warp the stylized head and its alpha mask back to image
            # coordinates, then feather-blend over the background result.
            head_trans_inv = cv2.warpAffine(
                head_res,
                trans_inv, (work_w, work_h),
                borderValue=(0, 0, 0))
            mask_trans_inv = cv2.warpAffine(
                self.global_mask,
                trans_inv, (work_w, work_h),
                borderValue=(0, 0, 0))
            mask_trans_inv = np.expand_dims(mask_trans_inv, 2)

            res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res

        res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)

        return res