minchul committed
Commit aaadf18 · verified · 1 Parent(s): ae8f75f

Upload directory

aligners/differentiable_face_aligner/dfa/preprocessor.py ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn.functional as F


class Preprocessor:
    """Pads images to a square, optionally adds a margin, and resizes them to `output_size`."""

    def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
        self.output_size = output_size
        self.padding = padding
        self.padding_val = padding_val

    def preprocess_batched(self, imgs, padding_ratio_override=None):
        # Choose the constant used for padding based on the input dtype.
        if imgs.dtype == torch.float32:
            if self.padding_val == 'zero':
                # 'zero' maps to -1.0 for float inputs (black, assuming images normalized to [-1, 1]).
                padding_val = -1.0
            elif self.padding_val == 'mean':
                padding_val = imgs.mean().item()
            else:
                raise ValueError('padding_val must be "zero" or "mean"')
        elif imgs.dtype == torch.uint8:
            if self.padding_val == 'zero':
                padding_val = 0
            elif self.padding_val == 'mean':
                # mean() is not defined for uint8 tensors, so compute it in float.
                padding_val = imgs.float().mean().item()
            else:
                raise ValueError('padding_val must be "zero" or "mean"')
        else:
            raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')

        # 1) Pad to a square aspect ratio.
        square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)

        # 2) Optionally add a margin around the squared image.
        if padding_ratio_override is not None:
            padding = padding_ratio_override
        else:
            padding = self.padding
        padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val)

        # 3) Resize to the target output size.
        size = (self.output_size, self.output_size)
        if imgs.dtype == torch.float32:
            resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
        elif imgs.dtype == torch.uint8:
            # Interpolate in float, then clip and cast back to uint8.
            padded_imgs = padded_imgs.to(torch.float32)
            resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
            resized_imgs = torch.clip(resized_imgs, 0, 255)
            resized_imgs = resized_imgs.to(torch.uint8)
        else:
            raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
        return resized_imgs

    def make_square_img_batched(self, imgs, padding_val):
        # Pad the shorter side so that height equals width.
        assert imgs.ndim == 4
        h, w = imgs.shape[2:]
        if h > w:
            diff = h - w
            pad_left = diff // 2
            pad_right = diff - pad_left
            imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
        elif w > h:
            diff = w - h
            pad_top = diff // 2
            pad_bottom = diff - pad_top
            imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
        assert imgs.shape[2] == imgs.shape[3]
        return imgs

    def make_padded_img_batched(self, imgs, padding, padding_val):
        # Add a margin of `padding` (a ratio of the image size) on every side.
        if padding == 0:
            return imgs
        assert imgs.ndim == 4
        h, w = imgs.shape[2:]
        pad_h = int(h * padding)
        pad_w = int(w * padding)
        imgs = F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)
        return imgs

    def __call__(self, input, padding_ratio_override=None):
        # Accept a single CHW image or a BCHW batch of RGB images.
        if input.ndim == 3:
            assert input.shape[0] == 3
            batch_input = input.unsqueeze(0)
            return self.preprocess_batched(batch_input, padding_ratio_override=padding_ratio_override)[0]
        elif input.ndim == 4:
            assert input.shape[1] == 3
            return self.preprocess_batched(input, padding_ratio_override=padding_ratio_override)
        else:
            raise ValueError(f'Invalid input shape: {input.shape}')
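
For reference, a minimal usage sketch of the class above. The import path is inferred from the file location in this commit; the input shapes and the [-1, 1] normalization for float images are illustrative assumptions, not requirements stated in the code.

import torch
from aligners.differentiable_face_aligner.dfa.preprocessor import Preprocessor  # path inferred from the file above

preprocessor = Preprocessor(output_size=160, padding=0.1, padding_val='zero')

# Single CHW float image, assumed normalized to [-1, 1]; returns a 3 x 160 x 160 tensor.
img = torch.rand(3, 512, 384) * 2 - 1
aligned = preprocessor(img)

# Batched uint8 images (B x C x H x W) also work; values stay in [0, 255].
batch = torch.randint(0, 256, (8, 3, 448, 300), dtype=torch.uint8)
aligned_batch = preprocessor(batch)  # 8 x 3 x 160 x 160, dtype uint8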