hieupt committed
Commit 7c4166f · 1 Parent(s): e9f7887

first commit
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/** filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,74 @@
+ ### 1. Imports and setup ###
+ import gradio as gr
+ import os
+ import torch
+ import utils
+ from pathlib import Path
+ from typing import Tuple, Dict
+ from model import TransformerNet
+ from torchvision import transforms
+ from PIL import Image
+
+ # Collect the style-transfer checkpoints
+ model_dir = 'models'
+ models = sorted(Path(model_dir).glob("*/*.pth.tar"))
+
+ # Collect the style reference images (same order as the checkpoints)
+ style_dir = 'style_images'
+ style_list = sorted(Path(style_dir).glob("*"))
+
+ # Collect example content images
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ def transfer(image, model):
+     device = 'cpu'
+
+     # Downscale large inputs to keep CPU inference fast
+     width, height = image.size
+     if width > 750 or height > 500:
+         image.thumbnail((712, 474))  # thumbnail() resizes in place
+
+     # Load the selected checkpoint
+     style_model = TransformerNet()
+     state_dict = torch.load(models[int(model)], map_location=torch.device('cpu'))
+     style_model.load_state_dict(state_dict["state_dict"])
+
+     content_transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Lambda(lambda x: x.mul(255))
+     ])
+     content_image = content_transform(image)
+     content_image = content_image.unsqueeze(0).to(device)
+
+     style_model.eval()
+     with torch.no_grad():
+         style_model.to(device)
+         output = style_model(content_image).cpu()
+
+     img = utils.deprocess(output[0])
+     img = Image.fromarray(img)
+     return img, style_list[int(model)]
+
+ # Create title, description and article strings
+ title = "Image Style Transfer"
+ description = "Choose an image you want to stylize and pick one of the styles. The app transfers the chosen style onto your image and returns the result."
+ article = "The model follows the paper [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/pdf/1603.08155v1.pdf)."
+
+ image_output_1 = gr.Image(label='Transfer')     # stylized result
+ image_output_2 = gr.Image(label='Style Image')  # reference style image
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=transfer,  # mapping function from input to output
+                     inputs=[gr.Image(type="pil", label='Input'),
+                             gr.Dropdown(choices=[i.parent.name for i in models], value='rain_princess', type='index', label="Style", info="Choose a style")],
+                     outputs=[image_output_1, image_output_2],  # transfer returns two values, so two output components
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ demo.launch()
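For a quick check outside the Gradio UI, the same loading and deprocessing path used in app.py can be driven from a plain script. Below is a minimal sketch under stated assumptions: it reuses the repository's `models/` and `examples/` folders, runs on CPU, and the output filename is purely illustrative.

```python
# Minimal offline sketch of the app.py pipeline: stylize one example image
# with one checkpoint. Paths follow the repo layout; "stylized.jpg" is an
# illustrative output name.
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms

import utils
from model import TransformerNet

checkpoint = sorted(Path("models").glob("*/*.pth.tar"))[0]  # e.g. the candy checkpoint
image = Image.open("examples/amber.jpg").convert("RGB")

model = TransformerNet()
state = torch.load(checkpoint, map_location="cpu")
model.load_state_dict(state["state_dict"])
model.eval()

to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255)),  # app.py feeds [0, 255]-scaled tensors
])
x = to_tensor(image).unsqueeze(0)

with torch.no_grad():
    y = model(x).cpu()

Image.fromarray(utils.deprocess(y[0])).save("stylized.jpg")
```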
examples/amber.jpg ADDED
examples/chicago.jpg ADDED
examples/golden_gate2.jpg ADDED
examples/lion.jpg ADDED
model.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ from torch import nn
+
+ class Residual_block(nn.Module):
+     """Residual block
+     Architecture: https://arxiv.org/pdf/1610.02915.pdf
+     """
+     def __init__(self, channel):
+         super(Residual_block, self).__init__()
+         self.conv_1 = nn.Conv2d(in_channels=channel, out_channels=channel,
+                                 padding='same', kernel_size=3, stride=1)
+         self.inst1 = nn.InstanceNorm2d(channel, affine=True)
+         self.conv_2 = nn.Conv2d(in_channels=channel, out_channels=channel,
+                                 padding='same', kernel_size=3, stride=1)
+         self.inst2 = nn.InstanceNorm2d(channel, affine=True)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.inst1(self.conv_1(x)))
+         out = self.inst2(self.conv_2(out))
+         return self.relu(out + residual)
+
+ class TransformerNet(nn.Module):
+     def __init__(self):
+         super(TransformerNet, self).__init__()
+         # Downsampling
+         self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=9, stride=1, padding=9 // 2)
+         self.BN_1 = nn.InstanceNorm2d(num_features=32, affine=True)
+         self.down_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
+         self.BN_2 = nn.InstanceNorm2d(num_features=64, affine=True)
+         self.down_2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
+         self.BN_3 = nn.InstanceNorm2d(num_features=128, affine=True)
+         # Residual blocks
+         self.res_1 = Residual_block(128)
+         self.res_2 = Residual_block(128)
+         self.res_3 = Residual_block(128)
+         self.res_4 = Residual_block(128)
+         self.res_5 = Residual_block(128)
+         # Upsampling
+         self.up_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.BN_4 = nn.InstanceNorm2d(num_features=64, affine=True)
+         self.up_2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.BN_5 = nn.InstanceNorm2d(num_features=32, affine=True)
+         self.conv2 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=9, stride=1, padding=9 // 2)
+
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         # Downsample
+         y = self.relu(self.BN_1(self.conv1(x)))
+         y = self.relu(self.BN_2(self.down_1(y)))
+         y = self.relu(self.BN_3(self.down_2(y)))
+         # Residual blocks
+         y = self.res_1(y)
+         y = self.res_2(y)
+         y = self.res_3(y)
+         y = self.res_4(y)
+         y = self.res_5(y)
+         # Upsample back toward the input resolution
+         y = self.relu(self.BN_4(self.up_1(y)))
+         y = self.relu(self.BN_5(self.up_2(y)))
+         y = self.conv2(y)
+         return y
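TransformerNet mirrors its two stride-2 downsampling convolutions with two stride-2 transposed convolutions (output_padding=1), so for input sizes divisible by 4 the output is a 3-channel map at the same resolution. A small shape-check sketch (the 256x256 size is illustrative):

```python
# Shape check for TransformerNet: a 256x256 RGB batch should come back as
# a 3-channel tensor at the same resolution (256 is divisible by 4, so the
# two stride-2 downsamples are exactly undone by the two upsamples).
import torch
from model import TransformerNet

net = TransformerNet()
net.eval()

x = torch.rand(1, 3, 256, 256) * 255  # app.py feeds [0, 255]-scaled tensors
with torch.no_grad():
    y = net(x)

print(y.shape)  # expected: torch.Size([1, 3, 256, 256])
```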
models/candy/candy_Epoch_3_Batch idx_4999.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b81480cf7e7f55c17157afaf87d811d3f1dee2fe458624377c121f4c244304c
+ size 20227039
models/mosaic/mosaic_Epoch_6_Batch idx_3999.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15193c2aed79e9972214fc15ab5055a2b327a767fefaed74c6feecd29711af35
+ size 20227039
models/rain_princess/rain_princess.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5ca9cdeedf0726ba3f6d5190cf33810e389c209b5fee841ee339a06140eb49e
+ size 20227039
models/vg_la_coffe/vg_la_cafe_Epoch_6_Batch idx_3999.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50bca2c1e24d4416e2192962e7dcfb64bb7ffeda7691835323ae43858d658907
+ size 20227039
models/wave_crop/wave_crop_Epoch_4_Batch idx_2999.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8371c79a6138a85f14d388aa72fe251ff90facc91482c5e8a8a919245bf542b4
+ size 20227039
models/weeping_woman/woman_Epoch_9_Batch idx_3999.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c481ad2234c30bc9a5858574e06284ea6542828de8b124fe1308a19e6c24f48
+ size 20227039
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.12.0
+ torchvision==0.13.0
+ gradio==3.1.4
style_images/candy.jpg ADDED
style_images/mosaic.jpg ADDED
style_images/rain_princess.jpeg ADDED
style_images/vg_la_cafe.jpg ADDED
style_images/wave_crop.jpg ADDED
style_images/weeping_woman_by_pablo_picasso.jpg ADDED
utils.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ from PIL import Image
+ import numpy as np
+
+
+ mean = [0.4763, 0.4507, 0.4094]
+ std = [0.2702, 0.2652, 0.2811]
+
+ def load_image(filename, size=None):
+     img = Image.open(filename).convert('RGB')
+     if size is not None:
+         img = img.resize((size, size), Image.ANTIALIAS)
+     return img
+
+
+ class UnNormalize(object):
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, tensor):
+         """
+         Args:
+             tensor (Tensor): Tensor image of size (C, H, W) to be unnormalized.
+         Returns:
+             Tensor: Unnormalized image.
+         """
+         for t, m, s in zip(tensor, self.mean, self.std):
+             t.mul_(s).add_(m)
+             # The normalize code -> t.sub_(m).div_(s)
+         return tensor
+
+ def deprocess(image_tensor):
+     """Denormalizes and rescales an image tensor to a uint8 HWC array."""
+     unnorm = UnNormalize(mean=mean, std=std)
+     img = image_tensor
+     unnorm(img)
+     img *= 255
+     image_np = torch.clamp(img, 0, 255).numpy().astype(np.uint8)
+     image_np = image_np.transpose(1, 2, 0)
+     return image_np
+
+ def save_image(filename, data):
+     img = deprocess(data)
+     img = Image.fromarray(img)
+     img.save(filename)
+
+
+ def gram_matrix(y):
+     # Gram matrix of feature maps, normalized by the number of elements
+     (b, ch, h, w) = y.size()
+     features = y.view(b, ch, w * h)
+     features_t = features.transpose(1, 2)
+     gram = features.bmm(features_t) / (ch * h * w)
+     return gram
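`gram_matrix` computes the standard style statistic: each channel of a feature map is flattened and correlated against every other channel, with the result divided by `ch * h * w`. A short sketch of its shape behaviour (the feature-map size is illustrative):

```python
# gram_matrix maps (B, C, H, W) features to (B, C, C) channel correlations;
# the division by C*H*W keeps the scale independent of spatial size.
import torch
from utils import gram_matrix

features = torch.rand(2, 128, 64, 64)  # illustrative feature-map size
gram = gram_matrix(features)

print(gram.shape)  # torch.Size([2, 128, 128])
print(torch.allclose(gram, gram.transpose(1, 2)))  # True: Gram matrices are symmetric
```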
vgg.py ADDED
@@ -0,0 +1,38 @@
+ import torch.nn as nn
+ from collections import namedtuple
+ from torchvision.models import vgg16, VGG16_Weights
+
+ class VGG16(nn.Module):
+     def __init__(self, requires_grad=False):
+         super(VGG16, self).__init__()
+
+         weights = VGG16_Weights.DEFAULT
+         vgg_pretrained_features = vgg16(weights=weights).features
+         self.slice1 = nn.Sequential()
+         self.slice2 = nn.Sequential()
+         self.slice3 = nn.Sequential()
+         self.slice4 = nn.Sequential()
+         for x in range(4):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(4, 9):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(9, 16):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(16, 23):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h = self.slice1(X)
+         h_relu1_2 = h
+         h = self.slice2(h)
+         h_relu2_2 = h
+         h = self.slice3(h)
+         h_relu3_3 = h
+         h = self.slice4(h)
+         h_relu4_3 = h
+         vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
+         out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
+         return out
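`VGG16` returns the four ReLU activations used by the perceptual loss in the paper linked from app.py, and `utils.gram_matrix` turns them into style statistics. The sketch below shows one way these pieces could be combined into a loss; the content/style weights, the choice of relu2_2 for content, and the 224x224 inputs are illustrative assumptions, not values taken from this repository.

```python
# Hedged sketch of a perceptual loss built from vgg.VGG16 and utils.gram_matrix.
# Weights, input sizes, and the relu2_2 content layer are illustrative choices.
import torch
import torch.nn.functional as F

from utils import gram_matrix
from vgg import VGG16

vgg = VGG16(requires_grad=False).eval()

generated = torch.rand(1, 3, 224, 224)  # stand-in for a TransformerNet output
content = torch.rand(1, 3, 224, 224)    # stand-in for the content image
style = torch.rand(1, 3, 224, 224)      # stand-in for the style image

gen_feats = vgg(generated)
content_feats = vgg(content)
style_feats = vgg(style)

# Content loss: match activations at one layer for the generated and content images.
content_loss = F.mse_loss(gen_feats.relu2_2, content_feats.relu2_2)

# Style loss: match Gram matrices across all four returned slices.
style_loss = sum(
    F.mse_loss(gram_matrix(g), gram_matrix(s))
    for g, s in zip(gen_feats, style_feats)
)

loss = 1.0 * content_loss + 10.0 * style_loss  # illustrative weights
```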