Upload folder using huggingface_hub
Browse files
app.py
CHANGED
@@ -4,13 +4,11 @@ from datetime import datetime, timezone, timedelta
|
|
4 |
|
5 |
import spaces
|
6 |
import torch
|
7 |
-
import torch.optim as optim
|
8 |
import numpy as np
|
9 |
import gradio as gr
|
10 |
-
from safetensors.torch import load_file
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
-
from utils import preprocess_img, postprocess_img
|
14 |
from vgg.vgg19 import VGG_19
|
15 |
from u2net.model import U2Net
|
16 |
from inference import inference
|
@@ -20,22 +18,20 @@ elif torch.backends.mps.is_available(): device = 'mps'
|
|
20 |
else: device = 'cpu'
|
21 |
print('Device:', device)
|
22 |
if device == 'cuda': print('Name:', torch.cuda.get_device_name())
|
23 |
-
|
24 |
-
def load_model_without_module(model, model_path):
|
25 |
-
state_dict = load_file(model_path, device=device)
|
26 |
-
new_state_dict = {}
|
27 |
-
for k, v in state_dict.items():
|
28 |
-
name = k[7:] if k.startswith('module.') else k
|
29 |
-
new_state_dict[name] = v
|
30 |
-
model.load_state_dict(new_state_dict)
|
31 |
|
32 |
# load models
|
33 |
model = VGG_19().to(device).eval()
|
34 |
for param in model.parameters():
|
35 |
param.requires_grad = False
|
36 |
sod_model = U2Net().to(device).eval()
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
style_files = os.listdir('./style_images')
|
41 |
style_options = {
|
@@ -52,26 +48,19 @@ style_options = {
|
|
52 |
lrs = np.linspace(0.015, 0.075, 3).tolist()
|
53 |
img_size = 512
|
54 |
|
55 |
-
cached_style_features = {
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
style_features = model(style_img)
|
60 |
-
cached_style_features[style_name] = style_features
|
61 |
|
62 |
-
@spaces.GPU(duration=
|
63 |
-
def run(content_image, style_name, style_strength=len(lrs),
|
64 |
yield None
|
65 |
content_img, original_size = preprocess_img(content_image, img_size)
|
66 |
content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
|
67 |
content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
|
68 |
style_features = cached_style_features[style_name]
|
69 |
|
70 |
-
if optim_name == 'AdamW':
|
71 |
-
optim_caller = optim.AdamW
|
72 |
-
elif optim_name == 'L-BFGS':
|
73 |
-
optim_caller = optim.LBFGS
|
74 |
-
|
75 |
print('-'*30)
|
76 |
print(datetime.now(timezone.utc) - timedelta(hours=5)) # EST
|
77 |
|
@@ -84,16 +73,11 @@ def run(content_image, style_name, style_strength=len(lrs), optim_name='AdamW',
|
|
84 |
style_features=style_features,
|
85 |
lr=lrs[style_strength-1],
|
86 |
apply_to_background=apply_to_background,
|
87 |
-
optim_caller=optim_caller,
|
88 |
)
|
89 |
-
|
90 |
-
print(f'{et-st:.2f}s')
|
91 |
|
92 |
yield postprocess_img(generated_img, original_size)
|
93 |
|
94 |
-
def set_slider(value):
|
95 |
-
return gr.update(value=value)
|
96 |
-
|
97 |
css = """
|
98 |
#container {
|
99 |
margin: 0 auto;
|
@@ -111,13 +95,10 @@ with gr.Blocks(css=css) as demo:
|
|
111 |
style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
|
112 |
style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
|
113 |
apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
|
114 |
-
with gr.Accordion(label='Advanced Options', open=False):
|
115 |
-
optim_dropdown = gr.Radio(choices=['AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
|
116 |
submit_button = gr.Button('Submit', variant='primary')
|
117 |
|
118 |
examples = gr.Examples(
|
119 |
examples=[
|
120 |
-
['./content_images/Surfer.jpg', 'Starry Night'],
|
121 |
['./content_images/GoldenRetriever.jpg', 'Great Wave'],
|
122 |
['./content_images/CameraGirl.jpg', 'Bokeh']
|
123 |
],
|
@@ -140,7 +121,7 @@ with gr.Blocks(css=css) as demo:
|
|
140 |
|
141 |
submit_button.click(
|
142 |
fn=run,
|
143 |
-
inputs=[content_image, style_dropdown, style_strength_slider,
|
144 |
outputs=output_image
|
145 |
).then(
|
146 |
fn=save_image,
|
|
|
4 |
|
5 |
import spaces
|
6 |
import torch
|
|
|
7 |
import numpy as np
|
8 |
import gradio as gr
|
|
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
|
11 |
+
from utils import preprocess_img, postprocess_img, load_model_without_module
|
12 |
from vgg.vgg19 import VGG_19
|
13 |
from u2net.model import U2Net
|
14 |
from inference import inference
|
|
|
18 |
else: device = 'cpu'
|
19 |
print('Device:', device)
|
20 |
if device == 'cuda': print('Name:', torch.cuda.get_device_name())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# load models
|
23 |
model = VGG_19().to(device).eval()
|
24 |
for param in model.parameters():
|
25 |
param.requires_grad = False
|
26 |
sod_model = U2Net().to(device).eval()
|
27 |
+
load_model_without_module(
|
28 |
+
sod_model,
|
29 |
+
hf_hub_download(repo_id='jamino30/u2net-saliency', filename='u2net-duts-msra.safetensors'),
|
30 |
+
device=device
|
31 |
+
)
|
32 |
+
|
33 |
+
model = torch.jit.script(model)
|
34 |
+
sod_model = torch.jit.script(sod_model)
|
35 |
|
36 |
style_files = os.listdir('./style_images')
|
37 |
style_options = {
|
|
|
48 |
lrs = np.linspace(0.015, 0.075, 3).tolist()
|
49 |
img_size = 512
|
50 |
|
51 |
+
cached_style_features = {
|
52 |
+
style_name: model(preprocess_img(style_img_path, img_size)[0].to(device))
|
53 |
+
for style_name, style_img_path in style_options.items()
|
54 |
+
}
|
|
|
|
|
55 |
|
56 |
+
@spaces.GPU(duration=15)
|
57 |
+
def run(content_image, style_name, style_strength=len(lrs), apply_to_background=False):
|
58 |
yield None
|
59 |
content_img, original_size = preprocess_img(content_image, img_size)
|
60 |
content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
|
61 |
content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
|
62 |
style_features = cached_style_features[style_name]
|
63 |
|
|
|
|
|
|
|
|
|
|
|
64 |
print('-'*30)
|
65 |
print(datetime.now(timezone.utc) - timedelta(hours=5)) # EST
|
66 |
|
|
|
73 |
style_features=style_features,
|
74 |
lr=lrs[style_strength-1],
|
75 |
apply_to_background=apply_to_background,
|
|
|
76 |
)
|
77 |
+
print(f'{time.time()-st:.2f}s')
|
|
|
78 |
|
79 |
yield postprocess_img(generated_img, original_size)
|
80 |
|
|
|
|
|
|
|
81 |
css = """
|
82 |
#container {
|
83 |
margin: 0 auto;
|
|
|
95 |
style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
|
96 |
style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
|
97 |
apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
|
|
|
|
|
98 |
submit_button = gr.Button('Submit', variant='primary')
|
99 |
|
100 |
examples = gr.Examples(
|
101 |
examples=[
|
|
|
102 |
['./content_images/GoldenRetriever.jpg', 'Great Wave'],
|
103 |
['./content_images/CameraGirl.jpg', 'Bokeh']
|
104 |
],
|
|
|
121 |
|
122 |
submit_button.click(
|
123 |
fn=run,
|
124 |
+
inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background_checkbox],
|
125 |
outputs=output_image
|
126 |
).then(
|
127 |
fn=save_image,
|
utils.py
CHANGED
@@ -2,6 +2,7 @@ from PIL import Image
|
|
2 |
|
3 |
import torch
|
4 |
import torchvision.transforms as transforms
|
|
|
5 |
|
6 |
def preprocess_img(img, img_size, normalize=False):
|
7 |
if type(img) == str: img = Image.open(img)
|
@@ -33,4 +34,11 @@ def postprocess_img(img, original_size, normalize=False):
|
|
33 |
|
34 |
img = transforms.ToPILImage()(img)
|
35 |
img = img.resize(original_size, Image.Resampling.LANCZOS)
|
36 |
-
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import torch
|
4 |
import torchvision.transforms as transforms
|
5 |
+
from safetensors.torch import load_file
|
6 |
|
7 |
def preprocess_img(img, img_size, normalize=False):
|
8 |
if type(img) == str: img = Image.open(img)
|
|
|
34 |
|
35 |
img = transforms.ToPILImage()(img)
|
36 |
img = img.resize(original_size, Image.Resampling.LANCZOS)
|
37 |
+
return img
|
38 |
+
|
39 |
+
def load_model_without_module(model, model_path, device):
|
40 |
+
state_dict = {
|
41 |
+
k[7:] if k.startswith('module.') else k: v
|
42 |
+
for k, v in load_file(model_path, device=device).items()
|
43 |
+
}
|
44 |
+
model.load_state_dict(state_dict)
|