import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.misc import normalize


class build_lut_transform(nn.Module):
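    """Predict a per-image 3D lookup table (LUT) from appearance features and apply it
    to the composite image via trilinear interpolation."""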

    def __init__(self, input_dim, lut_dim, input_resolution, opt):
        super().__init__()

        self.lut_dim = lut_dim
        self.opt = opt

        # self.compress_layer = nn.Linear(input_resolution, 1)

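        # Two-layer MLP that maps the appearance feature vector to the
        # 3 * lut_dim ** 3 RGB entries of a 3D LUT.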
        self.transform_layers = nn.Sequential(
            nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True),
            # nn.BatchNorm1d(3 * lut_dim ** 3, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True),
        )
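        # Hyper-network style initialization for the layer that emits the LUT
        # parameters (see hyper_weight_init below).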
        self.transform_layers[-1].apply(lambda m: hyper_weight_init(m))

    def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
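        # Undo the dataset normalization ('inv') so pixel values are back in [0, 1]
        # and can be used as LUT sampling coordinates in TrilinearInterpolation.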
        composite_image = normalize(composite_image, self.opt, 'inv')

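        # Only the foreground appearance features drive the LUT prediction;
        # bg_appearance_features is currently unused.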
        features = fg_appearance_features

        lut_params = self.transform_layers(features)

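        # Reshape the flat prediction into one (3, D, H, W) lookup table per batch element.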
        fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim)

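        # Apply each image's LUT to its own composite; the [0] drops the singleton
        # batch dimension added inside TrilinearInterpolation.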
        lut_transform_image = torch.stack(
            [TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0)

        return fit_3DLUT, normalize(lut_transform_image, self.opt)


def TrilinearInterpolation(LUT, img):
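    """Apply a (3, D, H, W) LUT to a (3, H, W) image with values in [0, 1].

    Returns a (1, 3, H, W) tensor of transformed pixel values.
    """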
    # Rescale pixel values from [0, 1] to the [-1, 1] coordinate range expected by grid_sample.
    img = (img - 0.5) * 2.

    # Build a (1, 1, H, W, 3) sampling grid. grid_sample reads the last grid dimension
    # as (x, y, z), which indexes the LUT in W, H, D order (the inverse of DHW), hence
    # the flip of the channel axis from RGB to BGR.
    img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)

    # Add a batch dimension to the LUT: (3, D, H, W) -> (1, 3, D, H, W).
    LUT = LUT[None]

    # For 5D inputs, mode='bilinear' in grid_sample performs trilinear interpolation,
    # looking up each pixel's RGB value in the LUT lattice.
    result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)

    # Drop the depth dimension added by the 5D sampling: (1, 3, 1, H, W) -> (1, 3, H, W).
    result = result[:, :, 0]

    return result


def hyper_weight_init(m):
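    """Initialization for the hyper-network layer that predicts the LUT parameters."""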
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
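        # Shrink the weights so the layer's initial output is dominated by the bias term.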
        m.weight.data = m.weight.data / 1.e2

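    # Biases are drawn uniformly from [0, 1], so the initial LUT values lie in the
    # image value range.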
    if hasattr(m, 'bias'):
        with torch.no_grad():
            m.bias.uniform_(0., 1.)
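

# Minimal usage sketch (illustrative, not part of the original module). Assumes
# PyTorch >= 1.10 for torch.meshgrid's `indexing` keyword. An identity LUT, where each
# lattice point stores its own normalized coordinate, should reproduce the input image.
if __name__ == '__main__':
    lut_dim = 16
    coords = torch.linspace(0., 1., lut_dim)
    # LUT layout is (3, D, H, W); D is indexed by R, H by G, W by B (see TrilinearInterpolation).
    d, h, w = torch.meshgrid(coords, coords, coords, indexing='ij')
    identity_lut = torch.stack([d, h, w], dim=0)

    image = torch.rand(3, 32, 32)  # pixel values in [0, 1]
    output = TrilinearInterpolation(identity_lut, image)  # shape (1, 3, 32, 32)
    print(torch.allclose(output[0], image, atol=1e-4))  # expected: True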