Upload directory
Browse files
aligners/differentiable_face_aligner/dfa/preprocessor.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
class Preprocessor():
    """Preprocess image batches to a fixed square size.

    Pipeline: pad the shorter spatial side to make the image square, add an
    optional constant border, then bilinearly resize to ``output_size``.

    Supports two input dtypes:
      * ``torch.float32`` -- assumed normalized to [-1, 1] (``'zero'`` padding
        therefore fills with -1.0 -- TODO confirm against callers),
      * ``torch.uint8``   -- raw 0..255 pixels (``'zero'`` padding fills with 0).
    """

    def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
        # output_size: side length (pixels) of the square output.
        # padding: extra border as a fraction of the squared image side.
        # padding_val: 'zero' (background fill) or 'mean' (per-batch mean fill).
        self.output_size = output_size
        self.padding = padding
        self.padding_val = padding_val

    def _resolve_padding_val(self, imgs):
        """Return the scalar fill value for this batch.

        Raises:
            ValueError: if ``imgs.dtype`` is unsupported or ``self.padding_val``
                is neither ``'zero'`` nor ``'mean'``.
        """
        if imgs.dtype not in (torch.float32, torch.uint8):
            raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
        if self.padding_val == 'zero':
            # float inputs live in [-1, 1], so the "zero" pixel value is -1.0
            return -1.0 if imgs.dtype == torch.float32 else 0
        if self.padding_val == 'mean':
            # .float() is required: torch.mean is not implemented for uint8.
            # .item() turns the 0-dim tensor into the plain number F.pad expects.
            return imgs.float().mean().item()
        raise ValueError('padding_val must be "zero" or "mean"')

    def preprocess_batched(self, imgs, padding_ratio_override=None):
        """Squarify, border-pad, and resize a (B, C, H, W) batch.

        Args:
            imgs: (B, C, H, W) tensor, float32 or uint8.
            padding_ratio_override: if not None, replaces ``self.padding``
                for this call only.

        Returns:
            (B, C, output_size, output_size) tensor with the input's dtype.
        """
        padding_val = self._resolve_padding_val(imgs)

        square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)

        padding = self.padding if padding_ratio_override is None else padding_ratio_override
        padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding,
                                                   padding_val=padding_val)

        size = (self.output_size, self.output_size)
        if imgs.dtype == torch.uint8:
            # interpolate only supports float tensors: round-trip through
            # float32, then clip so the cast back to uint8 cannot wrap around.
            resized = F.interpolate(padded_imgs.to(torch.float32), size=size,
                                    mode='bilinear', align_corners=True)
            return torch.clip(resized, 0, 255).to(torch.uint8)
        return F.interpolate(padded_imgs, size=size, mode='bilinear',
                             align_corners=True)

    def make_square_img_batched(self, imgs, padding_val):
        """Pad the shorter spatial side of a (B, C, H, W) batch to make it square.

        Padding is split evenly (extra pixel on the right/bottom when the
        difference is odd) and filled with ``padding_val``.
        """
        assert imgs.ndim == 4
        h, w = imgs.shape[2:]
        if h > w:
            diff = h - w
            pad_left = diff // 2
            pad_right = diff - pad_left
            # F.pad pads the last dim first: (left, right, top, bottom)
            imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
        elif w > h:
            diff = w - h
            pad_top = diff // 2
            pad_bottom = diff - pad_top
            imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
        assert imgs.shape[2] == imgs.shape[3]
        return imgs

    def make_padded_img_batched(self, imgs, padding, padding_val):
        """Add a constant border of ``int(side * padding)`` pixels on each edge.

        Returns the input unchanged when ``padding`` is 0.
        """
        if padding == 0:
            return imgs
        assert imgs.ndim == 4
        h, w = imgs.shape[2:]
        pad_h = int(h * padding)
        pad_w = int(w * padding)
        return F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)

    def __call__(self, input, padding_ratio_override=None):
        """Preprocess a single (3, H, W) image or a (B, 3, H, W) batch.

        A 3-D input is batched, processed, and un-batched; a 4-D input is
        processed as-is. Any other rank raises ValueError.
        """
        if input.ndim == 3:
            assert input.shape[0] == 3
            batch_input = input.unsqueeze(0)
            return self.preprocess_batched(
                batch_input, padding_ratio_override=padding_ratio_override)[0]
        elif input.ndim == 4:
            assert input.shape[1] == 3
            return self.preprocess_batched(
                input, padding_ratio_override=padding_ratio_override)
        else:
            raise ValueError(f'Invalid input shape: {input.shape}')
|