|
import torch |
|
import torch.optim as optim |
|
import torch.nn.functional as F |
|
|
|
def gram_matrix(feature):
    """Return the (un-normalised) Gram matrix of a 4-D feature tensor.

    Args:
        feature: Tensor whose ``size()`` unpacks into four values
            ``(b, c, h, w)``.

    Returns:
        A ``(b * c, b * c)`` tensor equal to ``F @ F.T`` where ``F`` is
        ``feature`` flattened to ``(b * c, h * w)``. The result is not
        divided by the number of elements.

    Note:
        The first two axes are merged before the product, so when
        ``b > 1`` the result also contains products between rows that
        come from different entries along the first axis.
    """
    b, c, h, w = feature.size()
    # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
    # inputs (e.g. a transposed feature map), while ``reshape`` copies only
    # when it has to and is otherwise identical.
    flat = feature.reshape(b * c, h * w)
    return flat @ flat.t()
|
|
|
def compute_loss(generated, content, style, bg_masks, alpha, beta):
    """Compute the weighted content, style and total losses.

    Args:
        generated: Sized sequence of feature tensors for the image being
            optimised (``len(generated)`` is used to average the style term).
        content: Iterable of target content feature tensors, paired with
            ``generated`` by position.
        style: Iterable of target style feature tensors, paired with
            ``generated`` by position.
        bg_masks: Optional sequence of per-layer masks. When falsy, no
            masking is applied. A ``None`` entry disables masking for that
            layer only.
        alpha: Multiplier for the content loss.
        beta: Multiplier for the style loss.

    Returns:
        A 3-tuple ``(alpha * content_loss, beta * style_loss, total)`` where
        ``total`` is the sum of the first two entries.
    """

    def _gram(feat, mask):
        # Mask the features before taking the Gram matrix when a mask exists.
        return gram_matrix(feat) if mask is None else gram_matrix(feat * mask)

    # Content term: plain sum of per-layer MSE values.
    content_loss = 0
    for gen_feat, content_feat in zip(generated, content):
        content_loss = content_loss + F.mse_loss(gen_feat, content_feat)

    # Style term: per-layer MSE between Gram matrices, each divided by the
    # number of generated layers.
    masks = bg_masks or [None] * len(generated)
    style_loss = 0
    for gen_feat, style_feat, mask in zip(generated, style, masks):
        layer_loss = F.mse_loss(_gram(gen_feat, mask), _gram(style_feat, mask))
        style_loss = style_loss + layer_loss / len(generated)

    weighted_content = alpha * content_loss
    weighted_style = beta * style_loss
    return weighted_content, weighted_style, weighted_content + weighted_style
|
|
|
def inference(
    *,
    model,
    sod_model,
    content_image,
    content_image_norm,
    style_features,
    apply_to_background,
    lr=1.5e-2,
    iterations=51,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    """Optimise a copy of ``content_image`` against content and style losses.

    A clone of ``content_image`` is optimised directly (its pixels are the
    only parameters handed to the optimiser) for ``iterations`` steps. When
    ``apply_to_background`` is true, a mask derived from ``sod_model`` is used
    to mask the style term and, after every step, the region outside that
    mask is blended back towards ``content_image``.

    Args:
        model: Callable returning an iterable of feature tensors for an
            image; called on ``content_image`` and on the optimised image.
        sod_model: Callable used only when ``apply_to_background`` is true;
            element ``[0]`` of its output is passed through a sigmoid and
            thresholded at 0.7. Presumably a salient-object-detection
            network; confirm against the caller.
        content_image: Image tensor used as the starting point and content
            target. Its ``shape[2:]`` is treated as the spatial size.
        content_image_norm: Input passed to ``sod_model`` (presumably a
            normalised version of ``content_image``; confirm).
        style_features: Target style features, forwarded to ``compute_loss``.
        apply_to_background: Whether to restrict the style term with the
            mask and restore the unmasked region after each step.
        lr: Learning rate passed to ``optim_caller``.
        iterations: Number of optimiser steps.
        optim_caller: Callable ``(params, lr=...)`` returning an optimiser
            whose ``step`` accepts a closure.
        alpha: Content-loss weight, forwarded to ``compute_loss``.
        beta: Style-loss weight, forwarded to ``compute_loss``.

    Returns:
        The optimised image tensor (a leaf tensor with ``requires_grad``).
    """
    # Detach before cloning so the optimised tensor is always a leaf; a plain
    # clone of a tensor that requires grad is a non-leaf, which optimisers
    # refuse to accept.
    generated_image = content_image.detach().clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    # Targets and masks are constants of the optimisation: no graph needed.
    with torch.no_grad():
        content_features = model(content_image)
        bg_masks = None

        if apply_to_background:
            seg_output = torch.sigmoid(sod_model(content_image_norm)[0])
            # Positions with score <= 0.7 are treated as background.
            bg_mask = (seg_output <= 0.7).float()
            # One mask per feature map, resized to that map's spatial size.
            # NOTE(review): unsqueeze(1) assumes seg_output has no channel
            # axis yet; confirm against sod_model's output layout.
            bg_masks = [
                F.interpolate(
                    bg_mask.unsqueeze(1),
                    size=cf.shape[2:],
                    mode='bilinear',
                    align_corners=False,
                )
                for cf in content_features
            ]

    # The blend factors do not change between iterations, so build them once
    # instead of re-interpolating the mask after every optimiser step.
    bg_keep = None
    fg_content = None
    if apply_to_background and iterations > 0:
        with torch.no_grad():
            fg_mask = F.interpolate(
                1 - bg_masks[0], size=generated_image.shape[2:], mode='nearest'
            )
            bg_keep = 1 - fg_mask
            fg_content = content_image.data * fg_mask

    def closure():
        # Re-evaluate the loss and gradients for the optimiser.
        optimizer.zero_grad()
        generated_features = model(generated_image)
        _, _, total_loss = compute_loss(
            generated_features, content_features, style_features, bg_masks, alpha, beta
        )
        total_loss.backward()
        return total_loss

    for _ in range(iterations):
        optimizer.step(closure)
        if apply_to_background:
            with torch.no_grad():
                # Blend the foreground back to the content image in place so
                # the optimiser keeps tracking the same tensor.
                generated_image.data.mul_(bg_keep).add_(fg_content)

    return generated_image
|
|