Spaces:
Runtime error
Runtime error
first commit
Browse files- .gitattributes +1 -0
- app.py +74 -0
- examples/amber.jpg +0 -0
- examples/chicago.jpg +0 -0
- examples/golden_gate2.jpg +0 -0
- examples/lion.jpg +0 -0
- model.py +78 -0
- models/candy/candy_Epoch_3_Batch idx_4999.pth.tar +3 -0
- models/mosaic/mosaic_Epoch_6_Batch idx_3999.pth.tar +3 -0
- models/rain_princess/rain_princess.pth.tar +3 -0
- models/vg_la_coffe/vg_la_cafe_Epoch_6_Batch idx_3999.pth.tar +3 -0
- models/wave_crop/wave_crop_Epoch_4_Batch idx_2999.pth.tar +3 -0
- models/weeping_woman/woman_Epoch_9_Batch idx_3999.pth.tar +3 -0
- requirements.txt +3 -0
- style_images/candy.jpg +0 -0
- style_images/mosaic.jpg +0 -0
- style_images/rain_princess.jpeg +0 -0
- style_images/vg_la_cafe.jpg +0 -0
- style_images/wave_crop.jpg +0 -0
- style_images/weeping_woman_by_pablo_picasso.jpg +0 -0
- utils.py +54 -0
- vgg.py +38 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
models/** filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 1. Imports and class names setup ###
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import utils
|
6 |
+
from typing import Tuple, Dict
|
7 |
+
from model import TransformerNet
|
8 |
+
from torchvision import transforms
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
# Get model
|
12 |
+
model_dir = '/models'
|
13 |
+
models = list(Path(model_dir).glob("*/*.pth.tar"))
|
14 |
+
models = sorted(models)
|
15 |
+
|
16 |
+
# Get style image
|
17 |
+
style_dir = '/style-images'
|
18 |
+
style_list = list(Path(style_dir).glob("*"))
|
19 |
+
style_list = sorted(style_list)
|
20 |
+
|
21 |
+
# Get examples
|
22 |
+
example_list = [["examples/" + example] for example in os.listdir("examples")]
|
23 |
+
|
24 |
+
def transfer(image, model):
|
25 |
+
device = 'cpu'
|
26 |
+
|
27 |
+
width = image.size[0]
|
28 |
+
height = image.size[1]
|
29 |
+
|
30 |
+
if width > 750 or height > 500:
|
31 |
+
iamge = image.thumbnail((712, 474))
|
32 |
+
|
33 |
+
# load model
|
34 |
+
style_model = TransformerNet()
|
35 |
+
state_dict = torch.load(models[int(model)], map_location=torch.device('cpu'))
|
36 |
+
style_model.load_state_dict(state_dict["state_dict"])
|
37 |
+
|
38 |
+
content_transform = transforms.Compose([
|
39 |
+
transforms.ToTensor(),
|
40 |
+
transforms.Lambda(lambda x: x.mul(255))
|
41 |
+
])
|
42 |
+
content_image = content_transform(image)
|
43 |
+
content_image = content_image.unsqueeze(0).to(device)
|
44 |
+
|
45 |
+
style_model.eval()
|
46 |
+
with torch.no_grad():
|
47 |
+
style_model.to(device)
|
48 |
+
output = style_model(content_image).cpu()
|
49 |
+
|
50 |
+
img = utils.deprocess(output[0])
|
51 |
+
img = Image.fromarray(img)
|
52 |
+
return img, style_list[int(model)]
|
53 |
+
|
54 |
+
# Create title, description and article strings
|
55 |
+
title = "Image Style Transfer"
|
56 |
+
description = "Choose a image that you want to transfer and the corresponding style. The app will be transfer your image. You will have received new image."
|
57 |
+
article = "Model have created base on paper [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/pdf/1603.08155v1.pdf)."
|
58 |
+
|
59 |
+
image_output_1 = gr.Image(label='Tranfer') # output result
|
60 |
+
image_output_2 = gr.Image(label='Style Image') # Show style image
|
61 |
+
|
62 |
+
# Create the Gradio demo
|
63 |
+
demo = gr.Interface(fn=transfer, # mapping function from input to output
|
64 |
+
inputs=[gr.Image(type="pil", label='Input'),
|
65 |
+
gr.Dropdown(choices=[i.parent.name for i in models], value='rain_princess', type='index', label="Style", info="Chooses kind of style image")], # what are the inputs?
|
66 |
+
outputs=[image_output_1, image_output_2], # our fn has two outputs, therefore we have two outputs
|
67 |
+
label = ['One', "Two"],
|
68 |
+
examples=example_list,
|
69 |
+
title=title,
|
70 |
+
description=description,
|
71 |
+
article=article)
|
72 |
+
|
73 |
+
# Launch the demo!
|
74 |
+
demo.launch()
|
examples/amber.jpg
ADDED
![]() |
examples/chicago.jpg
ADDED
![]() |
examples/golden_gate2.jpg
ADDED
![]() |
examples/lion.jpg
ADDED
![]() |
model.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
class Residual_block(nn.Module):
|
5 |
+
"""Residual block
|
6 |
+
Architecture: https://arxiv.org/pdf/1610.02915.pdf
|
7 |
+
"""
|
8 |
+
def __init__(self, channel):
|
9 |
+
super(Residual_block, self).__init__()
|
10 |
+
self.conv_1 = nn.Conv2d(in_channels=channel, out_channels=channel,
|
11 |
+
padding='same', kernel_size=3, stride=1)
|
12 |
+
self.inst1 = nn.InstanceNorm2d(channel, affine=True)
|
13 |
+
self.conv_2 = nn.Conv2d(in_channels=channel, out_channels=channel,
|
14 |
+
padding='same', kernel_size=3, stride=1)
|
15 |
+
self.inst2 = nn.InstanceNorm2d(channel, affine=True)
|
16 |
+
self.relu = nn.ReLU()
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
residual = x
|
20 |
+
out = self.relu(self.inst1(self.conv_1(x)))
|
21 |
+
out = self.inst2(self.conv_2(out))
|
22 |
+
return self.relu(out + residual)
|
23 |
+
|
24 |
+
class TransformerNet(nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super(TransformerNet, self).__init__()
|
27 |
+
# Downsampling
|
28 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=9, stride=1, padding = 9//2)
|
29 |
+
self.BN_1 = nn.InstanceNorm2d(num_features=32, affine=True)
|
30 |
+
self.down_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding = 1)
|
31 |
+
self.BN_2 = nn.InstanceNorm2d(num_features=64, affine=True)
|
32 |
+
self.down_2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding = 1)
|
33 |
+
self.BN_3 = nn.InstanceNorm2d(num_features=128, affine=True)
|
34 |
+
# Residual connect
|
35 |
+
self.res_1 = Residual_block(128)
|
36 |
+
self.res_2 = Residual_block(128)
|
37 |
+
self.res_3 = Residual_block(128)
|
38 |
+
self.res_4 = Residual_block(128)
|
39 |
+
self.res_5 = Residual_block(128)
|
40 |
+
# Upsampling
|
41 |
+
self.up_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding= 1)
|
42 |
+
self.BN_4 = nn.InstanceNorm2d(num_features=64, affine=True)
|
43 |
+
self.up_2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding = 1, output_padding= 1)
|
44 |
+
self.BN_5 = nn.InstanceNorm2d(num_features=32, affine=True)
|
45 |
+
self.conv2 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=9, stride=1, padding = 9//2)
|
46 |
+
|
47 |
+
self.relu = nn.ReLU()
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
y = self.relu(self.BN_1(self.conv1(x)))
|
53 |
+
# print(y.shape)
|
54 |
+
y = self.relu(self.BN_2(self.down_1(y)))
|
55 |
+
# print(y.shape)
|
56 |
+
y = self.relu(self.BN_3(self.down_2(y)))
|
57 |
+
# print(y.shape)
|
58 |
+
|
59 |
+
# print()
|
60 |
+
y = self.res_1(y)
|
61 |
+
# print(y.shape)
|
62 |
+
y = self.res_2(y)
|
63 |
+
# print(y.shape)
|
64 |
+
y = self.res_3(y)
|
65 |
+
# print(y.shape)
|
66 |
+
y = self.res_4(y)
|
67 |
+
# print(y.shape)
|
68 |
+
y = self.res_5(y)
|
69 |
+
# print(y.shape)
|
70 |
+
|
71 |
+
# print()
|
72 |
+
y = self.relu(self.BN_4(self.up_1(y)))
|
73 |
+
# print(y.shape)
|
74 |
+
y = self.relu(self.BN_5(self.up_2(y)))
|
75 |
+
# print(y.shape)
|
76 |
+
y = self.conv2(y)
|
77 |
+
# print(y.shape)
|
78 |
+
return y
|
models/candy/candy_Epoch_3_Batch idx_4999.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8b81480cf7e7f55c17157afaf87d811d3f1dee2fe458624377c121f4c244304c
|
3 |
+
size 20227039
|
models/mosaic/mosaic_Epoch_6_Batch idx_3999.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:15193c2aed79e9972214fc15ab5055a2b327a767fefaed74c6feecd29711af35
|
3 |
+
size 20227039
|
models/rain_princess/rain_princess.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a5ca9cdeedf0726ba3f6d5190cf33810e389c209b5fee841ee339a06140eb49e
|
3 |
+
size 20227039
|
models/vg_la_coffe/vg_la_cafe_Epoch_6_Batch idx_3999.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50bca2c1e24d4416e2192962e7dcfb64bb7ffeda7691835323ae43858d658907
|
3 |
+
size 20227039
|
models/wave_crop/wave_crop_Epoch_4_Batch idx_2999.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8371c79a6138a85f14d388aa72fe251ff90facc91482c5e8a8a919245bf542b4
|
3 |
+
size 20227039
|
models/weeping_woman/woman_Epoch_9_Batch idx_3999.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c481ad2234c30bc9a5858574e06284ea6542828de8b124fe1308a19e6c24f48
|
3 |
+
size 20227039
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.12.0
|
2 |
+
torchvision==0.13.0
|
3 |
+
gradio==3.1.4
|
style_images/candy.jpg
ADDED
![]() |
style_images/mosaic.jpg
ADDED
![]() |
style_images/rain_princess.jpeg
ADDED
![]() |
style_images/vg_la_cafe.jpg
ADDED
![]() |
style_images/wave_crop.jpg
ADDED
![]() |
style_images/weeping_woman_by_pablo_picasso.jpg
ADDED
![]() |
utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
mean = [0.4763, 0.4507, 0.4094]
|
7 |
+
std = [0.2702, 0.2652, 0.2811]
|
8 |
+
|
9 |
+
def load_image(filename, size=None):
|
10 |
+
img = Image.open(filename).convert('RGB')
|
11 |
+
if size is not None:
|
12 |
+
img = img.resize((size, size), Image.ANTIALIAS)
|
13 |
+
return img
|
14 |
+
|
15 |
+
|
16 |
+
class UnNormalize(object):
|
17 |
+
def __init__(self, mean, std):
|
18 |
+
self.mean = mean
|
19 |
+
self.std = std
|
20 |
+
|
21 |
+
def __call__(self, tensor):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
25 |
+
Returns:
|
26 |
+
Tensor: Normalized image.
|
27 |
+
"""
|
28 |
+
for t, m, s in zip(tensor, self.mean, self.std):
|
29 |
+
t.mul_(s).add_(m)
|
30 |
+
# The normalize code -> t.sub_(m).div_(s)
|
31 |
+
return tensor
|
32 |
+
|
33 |
+
def deprocess(image_tensor):
|
34 |
+
""" Denormalizes and rescales image tensor """
|
35 |
+
unnorm = UnNormalize(mean=mean, std=std)
|
36 |
+
img = image_tensor
|
37 |
+
unnorm(img)
|
38 |
+
img *= 255
|
39 |
+
image_np = torch.clamp(img, 0, 255).numpy().astype(np.uint8)
|
40 |
+
image_np = image_np.transpose(1, 2, 0)
|
41 |
+
return image_np
|
42 |
+
|
43 |
+
def save_image(filename, data):
|
44 |
+
img = deprocess(data)
|
45 |
+
img = Image.fromarray(img)
|
46 |
+
img.save(filename)
|
47 |
+
|
48 |
+
|
49 |
+
def gram_matrix(y):
|
50 |
+
(b, ch, h, w) = y.size()
|
51 |
+
features = y.view(b, ch, w * h)
|
52 |
+
features_t = features.transpose(1, 2)
|
53 |
+
gram = features.bmm(features_t) / (ch * h * w)
|
54 |
+
return gram
|
vgg.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from collections import namedtuple
|
3 |
+
from torchvision.models import vgg16, VGG16_Weights
|
4 |
+
|
5 |
+
class VGG16(nn.Module):
|
6 |
+
def __init__(self, requires_grad=False):
|
7 |
+
super(VGG16, self).__init__()
|
8 |
+
|
9 |
+
weights = VGG16_Weights.DEFAULT
|
10 |
+
vgg_pretrained_features = vgg16(weights=weights).features
|
11 |
+
self.slice1 = nn.Sequential()
|
12 |
+
self.slice2 = nn.Sequential()
|
13 |
+
self.slice3 = nn.Sequential()
|
14 |
+
self.slice4 = nn.Sequential()
|
15 |
+
for x in range(4):
|
16 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
17 |
+
for x in range(4, 9):
|
18 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
19 |
+
for x in range(9, 16):
|
20 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
21 |
+
for x in range(16, 23):
|
22 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
23 |
+
if not requires_grad:
|
24 |
+
for param in self.parameters():
|
25 |
+
param.requires_grad = False
|
26 |
+
|
27 |
+
def forward(self, X):
|
28 |
+
h = self.slice1(X)
|
29 |
+
h_relu1_2 = h
|
30 |
+
h = self.slice2(h)
|
31 |
+
h_relu2_2 = h
|
32 |
+
h = self.slice3(h)
|
33 |
+
h_relu3_3 = h
|
34 |
+
h = self.slice4(h)
|
35 |
+
h_relu4_3 = h
|
36 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
|
37 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
|
38 |
+
return out
|