jamino30 committed
Commit a1732e3 · verified · 1 Parent(s): a3814f8

Upload folder using huggingface_hub

Files changed (2):
  1. app.py +17 -36
  2. utils.py +9 -1
app.py CHANGED
@@ -4,13 +4,11 @@ from datetime import datetime, timezone, timedelta
 
 import spaces
 import torch
-import torch.optim as optim
 import numpy as np
 import gradio as gr
-from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 
-from utils import preprocess_img, postprocess_img
+from utils import preprocess_img, postprocess_img, load_model_without_module
 from vgg.vgg19 import VGG_19
 from u2net.model import U2Net
 from inference import inference
@@ -20,22 +18,20 @@ elif torch.backends.mps.is_available(): device = 'mps'
 else: device = 'cpu'
 print('Device:', device)
 if device == 'cuda': print('Name:', torch.cuda.get_device_name())
-
-def load_model_without_module(model, model_path):
-    state_dict = load_file(model_path, device=device)
-    new_state_dict = {}
-    for k, v in state_dict.items():
-        name = k[7:] if k.startswith('module.') else k
-        new_state_dict[name] = v
-    model.load_state_dict(new_state_dict)
 
 # load models
 model = VGG_19().to(device).eval()
 for param in model.parameters():
     param.requires_grad = False
 sod_model = U2Net().to(device).eval()
-local_model_path = hf_hub_download(repo_id='jamino30/u2net-saliency', filename='u2net-duts-msra.safetensors')
-load_model_without_module(sod_model, local_model_path)
+load_model_without_module(
+    sod_model,
+    hf_hub_download(repo_id='jamino30/u2net-saliency', filename='u2net-duts-msra.safetensors'),
+    device=device
+)
+
+model = torch.jit.script(model)
+sod_model = torch.jit.script(sod_model)
 
 style_files = os.listdir('./style_images')
 style_options = {
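Note on the new torch.jit.script calls in the hunk above: scripting compiles an eval-mode nn.Module to TorchScript once at startup, so repeated forward passes skip Python-interpreter overhead. A minimal sketch of the idea with a toy module (TinyNet is illustrative, not this repo's VGG_19 or U2Net):

# Minimal sketch; TinyNet is an assumed toy module, not part of this repo.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2.0

net = TinyNet().eval()
scripted = torch.jit.script(net)             # compile once at startup
with torch.no_grad():
    out = scripted(torch.randn(1, 3, 8, 8))  # runs the compiled graph
print(out.shape)                             # torch.Size([1, 3, 8, 8])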
@@ -52,26 +48,19 @@ style_options = {
 lrs = np.linspace(0.015, 0.075, 3).tolist()
 img_size = 512
 
-cached_style_features = {}
-for style_name, style_img_path in style_options.items():
-    style_img = preprocess_img(style_img_path, img_size)[0].to(device)
-    with torch.no_grad():
-        style_features = model(style_img)
-    cached_style_features[style_name] = style_features
+cached_style_features = {
+    style_name: model(preprocess_img(style_img_path, img_size)[0].to(device))
+    for style_name, style_img_path in style_options.items()
+}
 
-@spaces.GPU(duration=30)
-def run(content_image, style_name, style_strength=len(lrs), optim_name='AdamW', apply_to_background=False):
+@spaces.GPU(duration=15)
+def run(content_image, style_name, style_strength=len(lrs), apply_to_background=False):
     yield None
     content_img, original_size = preprocess_img(content_image, img_size)
     content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
     content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
     style_features = cached_style_features[style_name]
 
-    if optim_name == 'AdamW':
-        optim_caller = optim.AdamW
-    elif optim_name == 'L-BFGS':
-        optim_caller = optim.LBFGS
-
     print('-'*30)
     print(datetime.now(timezone.utc) - timedelta(hours=5)) # EST
 
@@ -84,16 +73,11 @@ def run(content_image, style_name, style_strength=len(lrs), optim_name='AdamW',
         style_features=style_features,
         lr=lrs[style_strength-1],
         apply_to_background=apply_to_background,
-        optim_caller=optim_caller,
     )
-    et = time.time()
-    print(f'{et-st:.2f}s')
+    print(f'{time.time()-st:.2f}s')
 
     yield postprocess_img(generated_img, original_size)
 
-def set_slider(value):
-    return gr.update(value=value)
-
 css = """
 #container {
     margin: 0 auto;
@@ -111,13 +95,10 @@ with gr.Blocks(css=css) as demo:
         style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
         style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
        apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
-        with gr.Accordion(label='Advanced Options', open=False):
-            optim_dropdown = gr.Radio(choices=['AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
         submit_button = gr.Button('Submit', variant='primary')
 
     examples = gr.Examples(
         examples=[
-            ['./content_images/Surfer.jpg', 'Starry Night'],
             ['./content_images/GoldenRetriever.jpg', 'Great Wave'],
             ['./content_images/CameraGirl.jpg', 'Bokeh']
         ],
@@ -140,7 +121,7 @@ with gr.Blocks(css=css) as demo:
 
     submit_button.click(
         fn=run,
-        inputs=[content_image, style_dropdown, style_strength_slider, optim_dropdown, apply_to_background_checkbox],
+        inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background_checkbox],
         outputs=output_image
     ).then(
         fn=save_image,
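Note on run() itself: it is a generator (yield None first, the finished image last), which Gradio treats as a streaming output, so the first yield blanks the previous result while the GPU job runs. A minimal sketch of that pattern, where slow_identity is a hypothetical stand-in handler, not code from this repo:

# Minimal sketch of the generator-output pattern; slow_identity is assumed.
import time
import gradio as gr

def slow_identity(img):
    yield None      # clear the previous output immediately
    time.sleep(2)   # stand-in for the style-transfer optimization loop
    yield img       # final image appears when ready

with gr.Blocks() as demo:
    inp = gr.Image(type='pil')
    out = gr.Image()
    gr.Button('Run').click(fn=slow_identity, inputs=inp, outputs=out)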
utils.py CHANGED
@@ -2,6 +2,7 @@ from PIL import Image
 
 import torch
 import torchvision.transforms as transforms
+from safetensors.torch import load_file
 
 def preprocess_img(img, img_size, normalize=False):
     if type(img) == str: img = Image.open(img)
@@ -33,4 +34,11 @@ def postprocess_img(img, original_size, normalize=False):
 
     img = transforms.ToPILImage()(img)
     img = img.resize(original_size, Image.Resampling.LANCZOS)
-    return img
+    return img
+
+def load_model_without_module(model, model_path, device):
+    state_dict = {
+        k[7:] if k.startswith('module.') else k: v
+        for k, v in load_file(model_path, device=device).items()
+    }
+    model.load_state_dict(state_dict)
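Note on load_model_without_module: checkpoints saved from a torch.nn.DataParallel-wrapped model carry a 'module.' prefix on every state-dict key, which a bare model rejects at load time; the helper strips it. A minimal sketch of the round trip (the nn.Linear toy and the ckpt.safetensors path are assumptions for illustration):

# Minimal sketch; the nn.Linear toy and ckpt.safetensors path are assumed.
import torch.nn as nn
from safetensors.torch import save_file, load_file

wrapped = nn.DataParallel(nn.Linear(4, 2))
print(list(wrapped.state_dict())[0])        # 'module.weight'
save_file(wrapped.state_dict(), 'ckpt.safetensors')

fresh = nn.Linear(4, 2)                     # bare model expects 'weight'/'bias'
state_dict = {
    k[7:] if k.startswith('module.') else k: v
    for k, v in load_file('ckpt.safetensors', device='cpu').items()
}
fresh.load_state_dict(state_dict)           # keys match after stripping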