import torch
import torch.optim as optim
import torch.nn.functional as F

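# Optimization-based neural style transfer utilities: the generated image is
# optimized directly against feature-based content and Gram-matrix style
# losses; an optional salient-object mask confines stylization to the
# background of the content image.
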
def gram_matrix(feature):
    # Channel-correlation (Gram) matrix of a feature map: flatten each channel
    # to a row vector and take inner products between rows.
    b, c, h, w = feature.size()
    feature = feature.view(b * c, h * w)
    return feature @ feature.t()

def compute_loss(generated, content, style, bg_masks, alpha, beta):
    # Content loss: per-layer MSE between generated and content features.
    content_loss = sum(F.mse_loss(gf, cf) for gf, cf in zip(generated, content))
    # Style loss: per-layer MSE between Gram matrices, averaged over layers.
    # When background masks are provided, features are masked first so only
    # background regions contribute to the style statistics.
    style_loss = sum(
        F.mse_loss(
            gram_matrix(gf * bg) if bg is not None else gram_matrix(gf),
            gram_matrix(sf * bg) if bg is not None else gram_matrix(sf),
        ) / len(generated)
        for gf, sf, bg in zip(generated, style, bg_masks or [None] * len(generated))
    )
    return alpha * content_loss, beta * style_loss, alpha * content_loss + beta * style_loss

def inference(
    *,
    model,
    sod_model,
    content_image,
    content_image_norm,
    style_features,
    apply_to_background,
    lr=1.5e-2,
    iterations=51,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    # Start from a copy of the content image and optimize its pixels directly.
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    with torch.no_grad():
        # Precompute the content features once; they stay fixed throughout.
        content_features = model(content_image)
        bg_masks = None

        if apply_to_background:
            # Salient-object detection: pixels with saliency <= 0.7 count as
            # background; the mask is resized to match each feature layer.
            seg_output = torch.sigmoid(sod_model(content_image_norm)[0])
            bg_mask = (seg_output <= 0.7).float()
            bg_masks = [
                F.interpolate(bg_mask.unsqueeze(1), size=cf.shape[2:], mode='bilinear', align_corners=False)
                for cf in content_features
            ]
        
    def closure():
        # Recompute features of the current image, form the weighted loss,
        # and backpropagate into the image pixels.
        optimizer.zero_grad()
        generated_features = model(generated_image)
        content_loss, style_loss, total_loss = compute_loss(
            generated_features, content_features, style_features, bg_masks, alpha, beta
        )
        total_loss.backward()
        return total_loss
    
    for _ in range(iterations):
        optimizer.step(closure)
        if apply_to_background:
            with torch.no_grad():
                # Re-paste the original foreground after each step so that only
                # the background region ends up stylized.
                fg_mask = F.interpolate(1 - bg_masks[0], size=generated_image.shape[2:], mode='nearest')
                generated_image.data.mul_(1 - fg_mask).add_(content_image.data * fg_mask)
                
    return generated_image
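

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file), assuming only
    # the API above: `model` returns a list of feature maps and
    # `sod_model(x)[0]` yields a saliency map. The two toy modules below are
    # hypothetical stand-ins for a pretrained extractor (e.g. VGG slices) and
    # a salient-object-detection network.
    import torch.nn as nn

    class _ToyFeatures(nn.Module):
        # Stand-in for a pretrained multi-layer feature extractor.
        def __init__(self):
            super().__init__()
            self.convs = nn.ModuleList([
                nn.Conv2d(3, 8, 3, padding=1),
                nn.Conv2d(8, 8, 3, stride=2, padding=1),
            ])

        def forward(self, x):
            feats = []
            for conv in self.convs:
                x = conv(x)
                feats.append(x)
            return feats

    class _ToySOD(nn.Module):
        # Stand-in for a salient-object-detection model; returns a tuple whose
        # first element is a (B, H, W) saliency map.
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 1, 3, padding=1)

        def forward(self, x):
            return (self.conv(x).squeeze(1),)

    content = torch.rand(1, 3, 64, 64)
    style = torch.rand(1, 3, 64, 64)
    extractor, sod = _ToyFeatures().eval(), _ToySOD().eval()
    with torch.no_grad():
        style_features = extractor(style)

    result = inference(
        model=extractor,
        sod_model=sod,
        content_image=content,
        content_image_norm=content,
        style_features=style_features,
        apply_to_background=True,
        iterations=3,
    )
    print(result.shape)  # torch.Size([1, 3, 64, 64])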