Spaces · Runtime error
Commit 34fb220 · Parent(s): 96d9168
init
Browse files
- app.py +214 -0
- configs/GSSN.yaml +57 -0
- configs/SSN.yaml +51 -0
- model_utils.py +53 -0
- models/Attention.ipynb +509 -0
- models/Attention_SSN.py +218 -0
- models/Attention_Unet.py +165 -0
- models/GSSN.py +176 -0
- models/Loss/Loss.py +271 -0
- models/Loss/__init__.py +0 -0
- models/Loss/__pycache__/Loss.cpython-39.pyc +0 -0
- models/Loss/__pycache__/__init__.cpython-39.pyc +0 -0
- models/Loss/__pycache__/vgg19_loss.cpython-39.pyc +0 -0
- models/Loss/pytorch_ssim/__init__.py +73 -0
- models/Loss/pytorch_ssim/__pycache__/__init__.cpython-39.pyc +0 -0
- models/Loss/vgg19_loss.py +54 -0
- models/SSN.py +143 -0
- models/SSN_Model.py +333 -0
- models/SSN_v1.py +290 -0
- models/Sparse_PH.py +185 -0
- models/__init__.py +43 -0
- models/__pycache__/SSN.cpython-39.pyc +0 -0
- models/__pycache__/SSN_Model.cpython-39.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/abs_model.cpython-39.pyc +0 -0
- models/__pycache__/blocks.cpython-39.pyc +0 -0
- models/abs_model.py +73 -0
- models/attention.py +85 -0
- models/blocks.py +238 -0
- models/pvt_attention.py +240 -0
- models/template.py +114 -0
- weights/SSN/0000001760.pt +3 -0
app.py ADDED
@@ -0,0 +1,214 @@
+import torch
+from torch import nn
+import logging
+
+from pathlib import Path
+import gradio as gr
+import numpy as np
+import cv2
+
+import model_utils
+from models.SSN import SSN
+
+import matplotlib
+matplotlib.use('TkAgg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+config_file = 'configs/SSN.yaml'
+weight = 'weights/SSN/0000001760.pt'
+device = torch.device('cuda:0')
+model = model_utils.load_model(config_file, weight, SSN, device)
+
+DEFAULT_INTENSITY = 0.9
+DEFAULT_GAMMA = 2.0
+
+logging.info('Model loading succeeded')
+
+cur_rgba = None
+cur_shadow = None
+cur_intensity = DEFAULT_INTENSITY
+cur_gamma = DEFAULT_GAMMA
+
+def resize(img, size):
+    h, w = img.shape[:2]
+
+    if h > w:
+        newh = size
+        neww = int(w / h * size)
+    else:
+        neww = size
+        newh = int(h / w * size)
+
+    resized_img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
+    if len(img.shape) != len(resized_img.shape):
+        resized_img = resized_img[..., None]
+
+    return resized_img
+
+
+def ibl_normalize(ibl, energy=30.0):
+    total_energy = np.sum(ibl)
+    if total_energy < 1e-3:
+        # print('small energy: ', total_energy)
+        h, w = ibl.shape
+        return np.zeros((h, w))
+
+    return ibl * energy / total_energy
+
+
+def padding_mask(rgba_input: np.ndarray):
+    """ Pad the mask input so that it fits the training dataset view range.
+
+    If the rgba does not have enough padding area, we need to pad the area.
+
+    :param rgba_input: H x W x 4 input; the first 3 channels are RGB, the last channel is the alpha
+    :returns: H x W x 4 padded RGBA
+
+    """
+    padding = 50
+    padding_size = 256 - padding * 2
+
+    h, w = rgba_input.shape[:2]
+    rgb = rgba_input[:, :, :3]
+    alpha = rgba_input[:, :, -1:]
+
+    zeros = np.where(alpha == 0)
+    hh, ww = zeros[0], zeros[1]
+    h_min, h_max = hh.min(), hh.max()
+    w_min, w_max = ww.min(), ww.max()
+
+    # if the area already has enough padding
+    if h_max - h_min < padding_size and w_max - w_min < padding_size:
+        return rgba_input
+
+    padding_output = np.zeros((256, 256, 4))
+    padding_output[..., :3] = 1.0
+
+    padded_rgba = resize(rgba_input, padding_size)
+    new_h, new_w = padded_rgba.shape[:2]
+
+    padding_output[padding:padding+new_h, padding:padding+new_w, :] = padded_rgba
+
+    return padding_output
+
+def shadow_composite(rgba, shadow, intensity, gamma):
+    rgb = rgba[..., :3]
+    mask = rgba[..., 3:]
+
+    if len(shadow.shape) == 2:
+        shadow = shadow[..., None]
+
+    new_shadow = 1.0 - shadow ** gamma * intensity
+    ret = rgb * mask + (1.0 - mask) * new_shadow
+    return ret, new_shadow[..., 0]
+
+
+def render_btn_fn(mask, ibl):
+    global cur_rgba, cur_shadow, cur_gamma, cur_intensity
+
+    print("Button clicked!")
+
+    mask = mask / 255.0
+    ibl = ibl / 255.0
+
+    # smooth the ibl
+    ibl = cv2.GaussianBlur(ibl, (11, 11), 0)
+
+    # pad the mask
+    mask = padding_mask(mask)
+
+    cur_rgba = np.copy(mask)
+
+
+    print('mask shape: {}/{}/{}/{}, ibl shape: {}/{}/{}/{}'.format(mask.shape, mask.dtype, mask.min(), mask.max(),
+                                                                   ibl.shape, ibl.dtype, ibl.min(), ibl.max()))
+
+    # ret = np.random.randn(256, 256, 3)
+    # ret = (ret - ret.min()) / (ret.max() - ret.min() + 1e-8)
+
+    rgb, mask = mask[..., :3], mask[..., 3]
+
+    ibl = ibl_normalize(cv2.resize(ibl, (32, 16)))
+
+    # ibl = 1.0 - ibl
+
+    x = {
+        'mask': mask,
+        'ibl': ibl
+    }
+    shadow = model.inference(x)
+    cur_shadow = np.copy(shadow)
+
+    # gamma
+    # shadow = np.power(shadow, 2.2)
+    # shadow = shadow * 0.8
+    # shadow = 1.0 - shadow
+
+    # composite the shadow
+
+    # shadow = shadow[..., None]
+    # mask = mask[..., None]
+    # ret = rgb * mask + (1.0 - mask) * shadow
+    ret, shadow = shadow_composite(cur_rgba, shadow, cur_intensity, cur_gamma)
+
+    # import pdb; pdb.set_trace()
+    # ret = (1.0-mask) * shadow
+
+    print('IBL range: {}/{} Shadow range: {} {}'.format(ibl.min(), ibl.max(), shadow.min(), shadow.max()))
+
+    plt.figure(figsize=(15, 10))
+    plt.subplot(1, 3, 1)
+    plt.imshow(mask)
+    plt.subplot(1, 3, 2)
+    plt.imshow(ibl)
+    plt.subplot(1, 3, 3)
+    plt.imshow(ret)
+    plt.savefig('tmp.png')
+    plt.close()
+
+    logging.info('Finished')
+
+    return ret, shadow
+
+
+def intensity_change(x):
+    global cur_rgba, cur_shadow, cur_gamma, cur_intensity
+
+    cur_intensity = x
+    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
+    return ret, shadow
+
+
+def gamma_change(x):
+    global cur_rgba, cur_shadow, cur_gamma, cur_intensity
+
+    cur_gamma = x
+    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
+    return ret, shadow
+
+
+ibl_h = 128
+ibl_w = ibl_h * 2
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        mask_input = gr.Image(shape=(256, 256), image_mode="RGBA", label="Mask")
+        ibl_input = gr.Sketchpad(shape=(ibl_w, ibl_h), image_mode="L", label="IBL", tool='sketch', invert_colors=True)
+        output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="RGB", label="Output")
+        shadow_output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="L", label="Shadow Layer")
+
+    with gr.Row():
+        intensity_slider = gr.Slider(0.0, 1.0, value=DEFAULT_INTENSITY, step=0.1, label="Intensity", info="Choose between 0.0 and 1.0")
+        gamma_slider = gr.Slider(1.0, 4.0, value=DEFAULT_GAMMA, step=0.1, label="Gamma", info="Gamma correction for shadow")
+        render_btn = gr.Button(label="Render")
+
+    render_btn.click(render_btn_fn, inputs=[mask_input, ibl_input], outputs=[output, shadow_output])
+    intensity_slider.release(intensity_change, inputs=[intensity_slider], outputs=[output, shadow_output])
+    gamma_slider.release(gamma_change, inputs=[gamma_slider], outputs=[output, shadow_output])

+logging.info('Finished')
+
+
+demo.launch()
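A note on the compositing step: shadow_composite is plain per-pixel arithmetic, ret = rgb * mask + (1 - mask) * (1 - shadow^gamma * intensity). Below is a minimal standalone sketch of that math on synthetic arrays; the random shadow map here is a placeholder for model.inference(...), not real model output.

import numpy as np

rgba = np.zeros((256, 256, 4))
rgba[..., :3] = 1.0                  # white canvas
rgba[100:150, 100:150, 3] = 1.0      # a square "object" in the alpha channel
shadow = np.random.rand(256, 256)    # placeholder for the predicted shadow map

intensity, gamma = 0.9, 2.0
rgb, mask = rgba[..., :3], rgba[..., 3:]
new_shadow = 1.0 - shadow[..., None] ** gamma * intensity  # gamma-curved darkening
ret = rgb * mask + (1.0 - mask) * new_shadow               # object where mask=1, shadow elsewhere
print(ret.shape)  # (256, 256, 3)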
configs/GSSN.yaml ADDED
@@ -0,0 +1,57 @@
+exp_name: GSSN_ALL_Channels_2e_5
+
+# model related
+model:
+  name: 'GSSN'
+  # backbone: 'vanilla'
+  backbone: 'SSN_v1'
+  in_channels: 6
+  out_channels: 1
+  resnet: True
+
+  mid_act: "gelu"
+  out_act: "gelu"
+
+  optimizer: 'Adam'
+  weight_decay: 4e-5
+  beta1: 0.9
+
+  focal: False
+
+# dataset
+dataset:
+  name: 'GSSN_Dataset'
+  hdf5_file: 'Dataset1/more_general_scenes/train/ALL_SIZE_WALL/dataset.hdf5'
+  type: 'BC_Boundary'
+  rech_grad: True
+
+
+test_dataset:
+  name: 'GSSN_Testing_Dataset'
+  hdf5_file: 'Dataset/standalone_test_split/test/ALL_SIZE_MORE/dataset.hdf5'
+  type: 'BC_Boundary'
+  ignore_shading: True
+  rech_grad: True
+
+
+# training related
+hyper_params:
+  lr: 2e-5
+  epochs: 100000
+  workers: 52
+  batch_size: 52
+  save_epoch: 10
+
+  eval_batch: 10
+  eval_save: False
+
+  # visualization
+  vis_iter: 100 # iteration for visualization
+  save_iter: 100
+  n_cols: 5
+  gpus:
+    - 0
+  default_folder: 'weights'
+  resume: False
+  # resume: True
+  weight_file: 'latest'
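For orientation, a sketch of how a training script might consume this config. The key paths (model.optimizer, model.beta1, hyper_params.lr, ...) follow the file above, but the surrounding training code is an assumption, not part of this commit. Note that PyYAML loads bare scientific notation such as 2e-5 as a string, hence the float() casts.

import yaml
import torch

with open('configs/GSSN.yaml') as f:
    opt = yaml.safe_load(f)

# placeholder network; the real model is built from opt['model'] by the GSSN class
net = torch.nn.Conv2d(opt['model']['in_channels'], opt['model']['out_channels'], 3)
if opt['model']['optimizer'] == 'Adam':
    optimizer = torch.optim.Adam(
        net.parameters(),
        lr=float(opt['hyper_params']['lr']),
        betas=(opt['model']['beta1'], 0.999),
        weight_decay=float(opt['model']['weight_decay']),
    )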
configs/SSN.yaml ADDED
@@ -0,0 +1,51 @@
+exp_name: SSN
+
+# model related
+model:
+  name: 'SSN'
+  in_channels: 1
+  out_channels: 1
+  resnet: False
+
+  mid_act: "relu"
+  out_act: 'relu'
+
+  optimizer: 'Adam'
+  weight_decay: 4e-5
+  beta1: 0.9
+
+
+# dataset
+dataset:
+  name: 'SSN_Dataset'
+  hdf5_file: 'Dataset/SSN/ssn_shadow/shadow_base/ssn_base.hdf5'
+  shadow_per_epoch: 10
+
+
+# test_dataset:
+#   name: 'SSN_Dataset'
+#   hdf5_file: 'Dataset/SSN/ssn_shadow/shadow_base/ssn_base.hdf5'
+
+
+# training related
+hyper_params:
+  lr: 1e-3
+  epochs: 100000
+  workers: 40
+  batch_size: 10
+  save_epoch: 10
+
+  eval_batch: 10
+  eval_save: False
+
+  # visualization
+  vis_iter: 100 # iteration for visualization
+  save_iter: 100
+  n_cols: 5
+  gpus:
+    - 0
+    - 1
+
+  default_folder: 'weights'
+  resume: False
+  weight_file: 'latest'
model_utils.py ADDED
@@ -0,0 +1,53 @@
+import os
+import yaml
+import logging
+
+import torch
+
+
+def parse_configs(config: str):
+    """ Parse the config file and return a dictionary of configs
+
+    :param config: path to the config file
+    :returns: a dict of configs, or an empty dict on parse failure
+
+    """
+    if not os.path.exists(config):
+        logging.error('Cannot find the config file: {}'.format(config))
+        exit()
+
+    with open(config, 'r') as stream:
+        try:
+            configs = yaml.safe_load(stream)
+            return configs
+
+        except yaml.YAMLError as exc:
+            logging.error(exc)
+            return {}
+
+
+def load_model(config: str, weight: str, model_def, device):
+    """ Load the model from the config file and the weight file
+
+    :param config: path to the config file
+    :param weight: path to the weight file
+    :param model_def: model class definition
+    :param device: pytorch device
+    :returns: the model, with weights loaded and submodules moved to device
+
+    """
+    assert os.path.exists(weight), 'Cannot find the weight file: {}'.format(weight)
+    assert os.path.exists(config), 'Cannot find the config file: {}'.format(config)
+
+
+    opt = parse_configs(config)
+    model = model_def(opt)
+    cp = torch.load(weight)
+
+    models = model.get_models()
+    for k, m in models.items():
+        m.load_state_dict(cp[k])
+        m.to(device)
+
+    model.set_models(models)
+    return model
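Putting the two helpers together, this is how app.py above loads the model. The config and weight paths are the ones committed here; a CUDA machine is assumed, since load_model calls torch.load without a map_location.

import torch
import model_utils
from models.SSN import SSN

device = torch.device('cuda:0')
model = model_utils.load_model('configs/SSN.yaml', 'weights/SSN/0000001760.pt', SSN, device)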
models/Attention.ipynb ADDED
@@ -0,0 +1,509 @@
(Jupyter notebook JSON; the code cells are reconstructed below. The stored outputs were an interactive ipdb session stepping through AttentionBlock.forward: after flattening, x is [5, 32, 16384]; self.qkv is Conv1d(32, 96, kernel_size=(1,)); qkv is [5, 96, 16384]; h after self.proj_out (Conv1d(32, 32, kernel_size=(1,))) is [5, 32, 16384]. The session ended with a BdbQuit traceback when the debugger was quit.)

# Cell 1
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import math
import torch as th

# Cell 2
class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

# Cell 3
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        # count_flops_attn is not defined anywhere in this notebook
        return count_flops_attn(model, _x, y)

# Cell 4
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

# Cell 5
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):

        import pdb; pdb.set_trace()

        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

# Cell 6 (its output is the ipdb session summarized above)
test_input = torch.randn(5, 32, 128, 128)

model = AttentionBlock(32, 1)

y = model(test_input)

# Cell 7
y.shape  # torch.Size([5, 32, 128, 128])
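The ipdb session in the notebook is just shape-checking QKVAttentionLegacy; the same arithmetic can be reproduced without the debugger. A minimal sketch follows; a smaller 16x16 spatial grid is used here so the length-by-length attention matrix stays tiny (the notebook used 128x128).

import math
import torch

bs, heads, ch, length = 5, 1, 32, 16 * 16
qkv = torch.randn(bs, 3 * heads * ch, length)       # what the 1x1 Conv1d (32 -> 96) produces
q, k, v = qkv.reshape(bs * heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))                # split 1/sqrt(ch) between q and k for fp16 stability
weight = torch.einsum('bct,bcs->bts', q * scale, k * scale).softmax(dim=-1)
a = torch.einsum('bts,bcs->bct', weight, v).reshape(bs, -1, length)
print(a.shape)  # torch.Size([5, 32, 256]); with 128x128 input this is [5, 32, 16384], as in the ipdb trace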
models/Attention_SSN.py ADDED
@@ -0,0 +1,218 @@
+from abc import abstractmethod
+from functools import partial
+from typing import Iterable
+import math
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .SSN import Conv, Conv2DMod, Decoder, Up
+from .attention import AttentionBlock
+from .blocks import ResBlock, Res_Type, get_activation
+
+
+class Attention_Encoder(nn.Module):
+    def __init__(self, in_channels=3, mid_act='gelu', dropout=0.0, num_heads=8, resnet=True):
+        super(Attention_Encoder, self).__init__()
+
+        self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
+        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
+        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
+
+        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
+        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
+
+        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
+        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
+        self.down_256_256_1_attn = AttentionBlock(256, num_heads)
+
+        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
+        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
+        self.down_512_512_1_attn = AttentionBlock(512, num_heads)
+
+        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
+        self.down_512_512_2_attn = AttentionBlock(512, num_heads)
+
+        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
+        self.down_512_512_3_attn = AttentionBlock(512, num_heads)
+
+
+    def forward(self, x):
+        x1 = self.in_conv(x)  # 32 x 256 x 256
+        x1 = torch.cat((x, x1), dim=1)
+
+        x2 = self.down_32_64(x1)
+        x3 = self.down_64_64_1(x2)
+
+        x4 = self.down_64_128(x3)
+        x5 = self.down_128_128_1(x4)
+
+        x6 = self.down_128_256(x5)
+        x7 = self.down_256_256_1(x6)
+        x7 = self.down_256_256_1_attn(x7)
+
+        x8 = self.down_256_512(x7)
+        x9 = self.down_512_512_1(x8)
+        x9 = self.down_512_512_1_attn(x9)
+
+        x10 = self.down_512_512_2(x9)
+        x10 = self.down_512_512_2_attn(x10)
+
+        x11 = self.down_512_512_3(x10)
+        x11 = self.down_512_512_3_attn(x11)
+
+        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1
+
+
+class Attention_Decoder(nn.Module):
+    def __init__(self, out_channels=3, mid_act='gelu', out_act='sigmoid', resnet=True, num_heads=8):
+
+        super(Attention_Decoder, self).__init__()
+
+        input_channel = 512
+        fea_dim = 100
+
+        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)
+
+        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, style=True, resnet=resnet)
+        self.up_16_16_1_attn = AttentionBlock(256, num_heads=num_heads)
+
+        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
+        self.up_16_16_2_attn = AttentionBlock(512, num_heads=num_heads)
+
+        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
+        self.up_16_16_3_attn = AttentionBlock(512, num_heads=num_heads)
+
+        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
+        self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
+        self.up_32_32_1 = Conv(512, 256, activation=mid_act, style=True, resnet=resnet)
+        self.up_32_32_1_attn = AttentionBlock(256, num_heads=num_heads)
+
+        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
+        self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
+        self.up_64_64_1 = Conv(256, 128, activation=mid_act, style=True, resnet=resnet)
+
+        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
+        self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
+        self.up_128_128_1 = Conv(128, 64, activation=mid_act, style=True, resnet=resnet)
+
+        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
+        self.out_conv = Conv(64, out_channels, activation=out_act)
+        self.out_act = get_activation(out_act)
+
+
+    def forward(self, x, style):
+        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x
+
+        style1 = self.to_style1(style)
+        y = self.up_16_16_1(x11, style1)  # 256 x 16 x 16
+        y = self.up_16_16_1_attn(y)
+
+        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
+        y = self.up_16_16_2(y, y)  # 512 x 16 x 16
+        y = self.up_16_16_2_attn(y)
+
+        y = torch.cat((x9, y), dim=1)  # 1024 x 16 x 16
+        y = self.up_16_16_3(y, y)  # 512 x 16 x 16
+        y = self.up_16_16_3_attn(y)
+
+        y = torch.cat((x8, y), dim=1)  # 1024 x 16 x 16
+        y = self.up_16_32(y, y)  # 256 x 32 x 32
+
+        y = torch.cat((x7, y), dim=1)
+        style2 = self.to_style2(style)
+        y = self.up_32_32_1(y, style2)  # 256 x 32 x 32
+        y = self.up_32_32_1_attn(y)
+
+        y = torch.cat((x6, y), dim=1)
+        y = self.up_32_64(y, y)
+
+        y = torch.cat((x5, y), dim=1)
+        style3 = self.to_style3(style)
+
+        y = self.up_64_64_1(y, style3)  # 128 x 64 x 64
+
+        y = torch.cat((x4, y), dim=1)
+        y = self.up_64_128(y, y)
+
+        y = torch.cat((x3, y), dim=1)
+        style4 = self.to_style4(style)
+        y = self.up_128_128_1(y, style4)  # 64 x 128 x 128
+
+        y = torch.cat((x2, y), dim=1)
+        y = self.up_128_256(y, y)  # 32 x 256 x 256
+
+        y = torch.cat((x1, y), dim=1)
+        y = self.out_conv(y, y)  # 3 x 256 x 256
+        y = self.out_act(y)
+        return y
+
+
+class Attention_SSN(nn.Module):
+    def __init__(self, in_channels, out_channels, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu'):
+        super(Attention_SSN, self).__init__()
+        # pass by keyword so num_heads is not silently consumed by the encoder's dropout parameter
+        self.encoder = Attention_Encoder(in_channels, mid_act=mid_act, num_heads=num_heads, resnet=resnet)
+        self.decoder = Attention_Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet, num_heads=num_heads)
+
+
+    def forward(self, x, softness):
+        latent = self.encoder(x)
+        pred = self.decoder(latent, softness)
+
+        return pred
+
+
+def get_model_size(model):
+    param_size = 0
+    import pdb; pdb.set_trace()
+    for param in model.parameters():
+        param_size += param.nelement() * param.element_size()
+
+    buffer_size = 0
+    for buffer in model.buffers():
+        buffer_size += buffer.nelement() * buffer.element_size()
+
+    size_all_mb = (param_size + buffer_size) / 1024 ** 2
+    print('model size: {:.3f}MB'.format(size_all_mb))
+    # return param_size + buffer_size
+    return size_all_mb
+
+
+if __name__ == '__main__':
+    model = AttentionBlock(in_channels=256, num_heads=8)
+    x = torch.randn(5, 256, 64, 64)
|
| 189 |
+
|
| 190 |
+
y = model(x)
|
| 191 |
+
print('{}, {}'.format(x.shape, y.shape))
|
| 192 |
+
|
| 193 |
+
# ------------------------------------------------------------------ #
|
| 194 |
+
in_channels = 3
|
| 195 |
+
out_channels = 1
|
| 196 |
+
num_heads = 8
|
| 197 |
+
resnet = True
|
| 198 |
+
mid_act = 'gelu'
|
| 199 |
+
out_act = 'gelu'
|
| 200 |
+
|
| 201 |
+
model = Attention_SSN(in_channels=in_channels,
|
| 202 |
+
out_channels=out_channels,
|
| 203 |
+
num_heads=num_heads,
|
| 204 |
+
resnet=resnet,
|
| 205 |
+
mid_act=mid_act,
|
| 206 |
+
out_act=out_act)
|
| 207 |
+
|
| 208 |
+
x = torch.randn(5, 3, 256, 256)
|
| 209 |
+
softness = torch.randn(5, 100)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
y = model(x, softness)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
print('x: {}, y: {}'.format(x.shape, y.shape))
|
| 216 |
+
|
| 217 |
+
get_model_size(model)
|
| 218 |
+
# ------------------------------------------------------------------ #
|
models/Attention_Unet.py
ADDED
@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np

from .SSN import Conv, Conv2DMod, Decoder, Up
from .attention import AttentionBlock
from .blocks import ResBlock, Res_Type, get_activation

class Attention_Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='gelu', dropout=0.0, num_heads=8, resnet=True):
        super(Attention_Encoder, self).__init__()

        self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)

        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)

        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_256_1_attn = AttentionBlock(256, num_heads)

        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_1_attn = AttentionBlock(512, num_heads)

        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2_attn = AttentionBlock(512, num_heads)

        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3_attn = AttentionBlock(512, num_heads)


    def forward(self, x):
        x1 = self.in_conv(x)            # 32 x 256 x 256
        x1 = torch.cat((x, x1), dim=1)

        x2 = self.down_32_64(x1)
        x3 = self.down_64_64_1(x2)

        x4 = self.down_64_128(x3)
        x5 = self.down_128_128_1(x4)

        x6 = self.down_128_256(x5)
        x7 = self.down_256_256_1(x6)
        x7 = self.down_256_256_1_attn(x7)

        x8 = self.down_256_512(x7)
        x9 = self.down_512_512_1(x8)
        x9 = self.down_512_512_1_attn(x9)

        x10 = self.down_512_512_2(x9)
        x10 = self.down_512_512_2_attn(x10)

        x11 = self.down_512_512_3(x10)
        x11 = self.down_512_512_3_attn(x11)

        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Attention_Decoder(nn.Module):
    def __init__(self, out_channels=3, mid_act='gelu', out_act='sigmoid', resnet=True, num_heads=8):
        super(Attention_Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, style=False, resnet=resnet)
        self.up_16_16_1_attn = AttentionBlock(256, num_heads=num_heads)

        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_2_attn = AttentionBlock(512, num_heads=num_heads)

        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3_attn = AttentionBlock(512, num_heads=num_heads)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, style=False, resnet=resnet)
        self.up_32_32_1_attn = AttentionBlock(256, num_heads=num_heads)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, style=False, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, style=False, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=out_act)
        self.out_act = get_activation(out_act)


    def forward(self, x):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        y = self.up_16_16_1(x11)        # 256 x 16 x 16
        y = self.up_16_16_1_attn(y)

        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
        y = self.up_16_16_2(y, y)       # 512 x 16 x 16
        y = self.up_16_16_2_attn(y)

        y = torch.cat((x9, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_16_3(y, y)       # 512 x 16 x 16
        y = self.up_16_16_3_attn(y)

        y = torch.cat((x8, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_32(y, y)         # 256 x 32 x 32

        y = torch.cat((x7, y), dim=1)
        y = self.up_32_32_1(y)          # 256 x 32 x 32
        y = self.up_32_32_1_attn(y)

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y, y)

        y = torch.cat((x5, y), dim=1)
        y = self.up_64_64_1(y)          # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y, y)

        y = torch.cat((x3, y), dim=1)
        y = self.up_128_128_1(y)        # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y, y)       # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y, y)         # 3 x 256 x 256
        y = self.out_act(y)
        return y


class Attention_Unet(nn.Module):
    def __init__(self, in_channels, out_channels, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu'):
        super(Attention_Unet, self).__init__()
        # keyword arguments avoid num_heads silently landing in the encoder's dropout slot
        self.encoder = Attention_Encoder(in_channels, mid_act=mid_act, num_heads=num_heads, resnet=resnet)
        self.decoder = Attention_Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet, num_heads=num_heads)


    def forward(self, x):
        latent = self.encoder(x)
        pred = self.decoder(latent)
        return pred


if __name__ == '__main__':
    test_input = torch.randn(5, 1, 256, 256)

    # test the class this file actually defines; its forward takes no style vector
    model = Attention_Unet(1, 1, mid_act='gelu', out_act='gelu', resnet=True)
    test_out = model(test_input)

    print('Output shape: ', test_out.shape)
models/GSSN.py
ADDED
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict
import numpy as np
import matplotlib.cm as cm
import matplotlib as mpl

from .abs_model import abs_model
from .blocks import *
from .SSN import SSN
from .SSN_v1 import SSN_v1
from .Loss.Loss import norm_loss


class GSSN(abs_model):
    def __init__(self, opt):
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        resnet = opt['model']['resnet']
        self.ncols = opt['hyper_params']['n_cols']
        self.focal = opt['model']['focal']

        if 'backbone' not in opt['model'].keys():
            self.model = SSN(in_channels=in_channels,
                             out_channels=out_channels,
                             mid_act=mid_act,
                             out_act=out_act,
                             resnet=resnet)
        else:
            backbone = opt['model']['backbone']
            if backbone == 'vanilla':
                self.model = SSN(in_channels=in_channels,
                                 out_channels=out_channels,
                                 mid_act=mid_act,
                                 out_act=out_act,
                                 resnet=resnet)
            elif backbone == 'SSN_v1':
                self.model = SSN_v1(in_channels=in_channels,
                                    out_channels=out_channels,
                                    mid_act=mid_act,
                                    out_act=out_act,
                                    resnet=resnet)
            else:
                raise NotImplementedError('{} has not been implemented yet'.format(backbone))

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

        self.norm_loss = norm_loss()

        # inference related: quantize the softness radius into BINs bins and
        # precompute a Gaussian "soft" code per bin
        BINs = 100
        MAX_RAD = 20
        self.size_interval = MAX_RAD / BINs
        self.soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINs)] for j in np.arange(BINs)]

    def setup_input(self, x):
        return x


    def forward(self, x):
        x, softness = x
        return self.model(x, softness)


    def compute_loss(self, y, pred):
        b = y.shape[0]

        total_loss = self.norm_loss.loss(y, pred)

        if self.focal:
            total_loss = torch.pow(total_loss, 3)

        return total_loss


    def supervise(self, input_x, y, is_training: bool) -> float:
        optimizer = self.optimizer
        model = self.model

        x, softness = input_x['x'], input_x['softness']

        optimizer.zero_grad()
        pred = model(x, softness)
        loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        xc = x.shape[1]
        for i in range(xc):
            self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()

        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1, 2, 0), 0.0, 1.0)
            ret_vis[k] = self.plasma(ret_vis[k])

        return ret_vis


    def get_logs(self):
        pass


    def inference(self, x):
        x, l, device = x['x'], x['l'], x['device']

        x = torch.from_numpy(x.transpose((2, 0, 1))).unsqueeze(dim=0).to(device)
        l = torch.from_numpy(np.array(self.soft_distribution[int(l / self.size_interval)]).astype(np.float32)).unsqueeze(dim=0).to(device)

        pred = self.forward((x, l))
        pred = pred[0].detach().cpu().numpy().transpose((1, 2, 0))
        return pred


    def batch_inference(self, x):
        x, l = x['x'], x['softness']
        pred = self.forward((x, l))
        return pred


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']


    ####################
    # Personal Methods #
    ####################
    def plasma(self, x):
        norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
        mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
        bimg = mapper.to_rgba(x[:, :, 0])[:, :, :3]

        return bimg
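A standalone sketch of the softness encoding used by inference() above; softness_radius is an illustrative value, not a config default:

import numpy as np

BINs, MAX_RAD = 100, 20
size_interval = MAX_RAD / BINs                   # 0.2 radius units per bin
soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINs)] for j in np.arange(BINs)]

softness_radius = 4.5                            # hypothetical user-chosen softness
code = np.array(soft_distribution[int(softness_radius / size_interval)])
print(code.shape, code.argmax())                 # (100,) 22 -- a Gaussian bump centered at bin 22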
models/Loss/Loss.py
ADDED
@@ -0,0 +1,271 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.autograd import Variable

import numpy as np
import cv2

from .vgg19_loss import VGG19Loss
from . import pytorch_ssim
from abc import ABC, abstractmethod
from collections import OrderedDict

class abs_loss(ABC):
    def loss(self, gt_img, pred_img):
        pass


class norm_loss(abs_loss):
    def __init__(self, norm=1):
        self.norm = norm


    def loss(self, gt_img, pred_img):
        """ M * (I-I') """
        b, c, h, w = gt_img.shape
        return torch.norm(gt_img - pred_img, self.norm) / (h * w * b)


class ssim_loss(abs_loss):
    def __init__(self, window_size=11, channel=1):
        """ Mean-window SSIM loss. """
        self.channel = channel
        self.window_size = window_size
        self.window = self.create_mean_window(window_size, channel)


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape
        if c != self.channel:
            self.channel = c
            self.window = self.create_mean_window(self.window_size, self.channel)

        self.window = self.window.to(gt_img).type_as(gt_img)
        l = 1.0 - self.ssim_compute(gt_img, pred_img)
        return l


    def create_mean_window(self, window_size, channel):
        window = Variable(torch.ones(channel, 1, window_size, window_size).float())
        window = window / (window_size * window_size)
        return window


    def ssim_compute(self, gt_img, pred_img):
        window = self.window
        window_size = self.window_size
        channel = self.channel

        mu1 = F.conv2d(gt_img, window, padding=window_size//2, groups=channel)
        mu2 = F.conv2d(pred_img, window, padding=window_size//2, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(gt_img*gt_img, window, padding=window_size//2, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(pred_img*pred_img, window, padding=window_size//2, groups=channel) - mu2_sq
        sigma12 = F.conv2d(gt_img*pred_img, window, padding=window_size//2, groups=channel) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

        return ssim_map.mean()


class hierarchical_ssim_loss(abs_loss):
    def __init__(self, patch_list: list):
        self.ssim_loss_list = [pytorch_ssim.SSIM(window_size=ws) for ws in patch_list]


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape
        total_loss = 0.0
        for loss_func in self.ssim_loss_list:
            total_loss += (1.0 - loss_func(gt_img, pred_img))

        return total_loss / b


class vgg_loss(abs_loss):
    def __init__(self):
        self.vgg19_ = VGG19Loss()


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape
        v = self.vgg19_(gt_img, pred_img, pred_img.device)
        return v / b


class grad_loss(abs_loss):
    def __init__(self, k=4):
        self.k = k  # number of pyramid levels


    def loss(self, disp_img, rgb_img=None):
        """ Multi-scale gradient loss; when an RGB guidance image is given,
        gradients are weighted by an edge-aware weight. """
        b, c, h, w = disp_img.shape

        grad_loss = 0.0
        for i in range(self.k):
            div_factor = 2 ** i
            cur_transform = T.Resize([h // div_factor, w // div_factor])
            cur_disp = cur_transform(disp_img)

            cur_disp_dx, cur_disp_dy = self.img_grad(cur_disp)

            if rgb_img is not None:
                cur_rgb = cur_transform(rgb_img)
                cur_rgb_dx, cur_rgb_dy = self.img_grad(cur_rgb)

                cur_rgb_dx = torch.exp(-torch.mean(torch.abs(cur_rgb_dx), dim=1, keepdim=True))
                cur_rgb_dy = torch.exp(-torch.mean(torch.abs(cur_rgb_dy), dim=1, keepdim=True))
                grad_loss += (torch.sum(torch.abs(cur_disp_dx) * cur_rgb_dx) + torch.sum(torch.abs(cur_disp_dy) * cur_rgb_dy)) / (h * w * self.k)
            else:
                grad_loss += (torch.sum(torch.abs(cur_disp_dx)) + torch.sum(torch.abs(cur_disp_dy))) / (h * w * self.k)

        return grad_loss / b


    def gloss(self, gt, pred):
        """ Loss on the gradient domain """
        b, c, h, w = gt.shape
        gt_dx, gt_dy = self.img_grad(gt)
        pred_dx, pred_dy = self.img_grad(pred)

        loss = (gt_dx - pred_dx) ** 2 + (gt_dy - pred_dy) ** 2
        return loss.sum() / (b * h * w)


    def laploss(self, pred):
        b, c, h, w = pred.shape
        lap = self.img_laplacian(pred)

        return torch.abs(lap).sum() / (b * h * w)


    def img_laplacian(self, img):
        b, c, h, w = img.shape
        laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])

        laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)

        lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1)
        return lap


    def img_grad(self, img):
        """ Compute the image gradient by Sobel filtering.
        img: B x C x H x W
        """
        b, c, h, w = img.shape
        ysobel = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        xsobel = ysobel.transpose(0, 1)

        xsobel_kernel = xsobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
        ysobel_kernel = ysobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
        dx = F.conv2d(img, xsobel_kernel, padding=1, stride=1)
        dy = F.conv2d(img, ysobel_kernel, padding=1, stride=1)

        return dx, dy


class sharp_loss(abs_loss):
    """ Sharpness term:
    1. Laplacian
    2. image contrast
    3. image variance
    """
    def __init__(self, window_size=11, channel=1):
        self.window_size = window_size
        self.channel = channel
        self.window = self.create_mean_window(window_size, self.channel)


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape

        if c != self.channel:
            self.channel = c
            self.window = self.create_mean_window(self.window_size, self.channel)

        self.window = self.window.to(gt_img).type_as(gt_img)

        channel = self.channel
        window = self.window
        window_size = self.window_size

        mu1 = F.conv2d(gt_img, window, padding=window_size//2, groups=channel) + 1e-6
        mu2 = F.conv2d(pred_img, window, padding=window_size//2, groups=channel) + 1e-6

        contrast1 = torch.absolute((gt_img - mu1) / mu1)
        contrast2 = torch.absolute((pred_img - mu2) / mu2)

        variance1 = (gt_img - mu1) ** 2
        variance2 = (pred_img - mu2) ** 2

        laplacian1 = self.img_laplacian(gt_img)
        laplacian2 = self.img_laplacian(pred_img)

        S1 = -laplacian1 - contrast1 - variance1
        S2 = -laplacian2 - contrast2 - variance2

        total = torch.absolute(S1 - S2).mean()
        return total


    def img_laplacian(self, img):
        b, c, h, w = img.shape
        laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])

        laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)

        lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1)
        return lap


    def create_mean_window(self, window_size, channel):
        window = Variable(torch.ones(channel, 1, window_size, window_size).float())
        window = window / (window_size * window_size)
        return window


if __name__ == '__main__':
    a = torch.rand(3, 3, 128, 128)
    b = torch.rand(3, 3, 128, 128)

    ssim = ssim_loss()
    loss = ssim.loss(a, b)
    print(loss.shape, loss)

    loss = ssim.loss(a, a)
    print(loss.shape, loss)

    loss = ssim.loss(b, b)
    print(loss.shape, loss)

    grad = grad_loss()
    loss = grad.loss(a, b)   # loss() expects a tensor guidance image, not a list
    print(loss.shape, loss)

    sharp = sharp_loss()
    loss = sharp.loss(a, b)
    print(loss.shape, loss)
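A hedged usage sketch for the edge-aware gradient term above, assuming models/Loss/Loss.py is importable as a package module; the tensors are random placeholders:

import torch
from models.Loss.Loss import grad_loss

disp = torch.rand(2, 1, 128, 128)   # e.g. a predicted soft shadow map
rgb = torch.rand(2, 3, 128, 128)    # guidance image for the edge-aware weights

g = grad_loss(k=4)
print(float(g.loss(disp)))          # plain multi-scale gradient penalty
print(float(g.loss(disp, rgb)))     # gradients at strong RGB edges cost less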
models/Loss/__init__.py
ADDED
File without changes
models/Loss/__pycache__/Loss.cpython-39.pyc
ADDED
Binary file (8.36 kB)
models/Loss/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (164 Bytes)
models/Loss/__pycache__/vgg19_loss.cpython-39.pyc
ADDED
Binary file (2.08 kB)
models/Loss/pytorch_ssim/__init__.py
ADDED
@@ -0,0 +1,73 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2 / float(2*sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
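A quick usage sketch for this SSIM module (shapes are arbitrary examples):

import torch
from models.Loss import pytorch_ssim

img1 = torch.rand(1, 1, 64, 64)
print(float(pytorch_ssim.ssim(img1, img1)))    # identical images -> 1.0

ssim_module = pytorch_ssim.SSIM(window_size=11)
print(float(ssim_module(img1, img1 * 0.5)))    # module form caches the Gaussian window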
models/Loss/pytorch_ssim/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (2.65 kB)
models/Loss/vgg19_loss.py
ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torchvision

class FeatureExtractor(nn.Module):
    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def normalize(self, tensors, mean, std):
        if not torch.is_tensor(tensors):
            raise TypeError('tensor is not a torch image.')
        for tensor in tensors:
            for t, m, s in zip(tensor, mean, std):
                t.sub_(m).div_(s)
        return tensors

    def forward(self, x):
        # if the image is grayscale, expand it to 3 channels
        if x.size()[1] == 1:
            x = x.expand(-1, 3, -1, -1)

        # [-1, 1] image to [0, 1] image---------------------------------------------------(1)
        x = (x + 1) * 0.5

        # https://pytorch.org/docs/stable/torchvision/models.html
        x.data = self.normalize(x.data, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return self.features(x)

# Feature extraction using VGG19
vgg19 = torchvision.models.vgg19(pretrained=True)
feature_extractor = FeatureExtractor(vgg19, feature_layer=35)
feature_extractor.eval()

class VGG19Loss(object):
    def __init__(self):
        global feature_extractor
        self.initialized = False
        self.feature_extractor = feature_extractor
        self.MSE = nn.MSELoss()

    def __call__(self, output, target, device):
        if self.initialized == False:
            self.feature_extractor = self.feature_extractor.to(device)
            self.MSE = self.MSE.to(device)
            self.initialized = True

        # [-1, 1] image to [0, 1] image---------------------------------------------------(2)
        output = (output + 1) * 0.5
        target = (target + 1) * 0.5

        output = self.feature_extractor(output)
        target = self.feature_extractor(target).data
        return self.MSE(output, target)
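A minimal usage sketch: the class expects inputs in [-1, 1]; the tensors below are random placeholders:

import torch
from models.Loss.vgg19_loss import VGG19Loss

pred = torch.rand(1, 3, 256, 256) * 2 - 1
gt = torch.rand(1, 3, 256, 256) * 2 - 1

perceptual = VGG19Loss()
print(float(perceptual(pred, gt, torch.device('cpu'))))   # MSE between VGG19 feature maps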
models/SSN.py
ADDED
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict
import numpy as np

from .abs_model import abs_model
from .Loss.Loss import norm_loss
from .blocks import *
from .SSN_Model import SSN_Model


class SSN(abs_model):
    def __init__(self, opt):
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        self.ncols = opt['hyper_params']['n_cols']

        self.model = SSN_Model(in_channels=in_channels, out_channels=out_channels, mid_act=mid_act, out_act=out_act)
        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

        self.norm_loss_ = norm_loss(norm=1)

    def setup_input(self, x):
        return x


    def forward(self, x):
        keys = ['mask', 'ibl']

        for k in keys:
            assert k in x.keys(), '{} not in input'.format(k)

        mask = x['mask']
        ibl = x['ibl']

        return self.model(mask, ibl)


    def compute_loss(self, y, pred):
        total_loss = self.norm_loss_.loss(y, pred)
        return total_loss


    def supervise(self, input_x, y, is_training: bool) -> float:
        optimizer = self.optimizer
        model = self.model

        optimizer.zero_grad()
        pred = self.forward(input_x)
        loss = self.compute_loss(y, pred)

        # logging.info('Pred/Target: {}, {}/{}, {}'.format(pred.min().item(), pred.max().item(), y.min().item(), y.max().item()))

        if is_training:
            loss.backward()
            optimizer.step()

        self.visualization['mask'] = input_x['mask'].detach()
        self.visualization['ibl'] = input_x['ibl'].detach()
        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            plot_v = (plot_v - plot_v.min()) / (plot_v.max() - plot_v.min())
            ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1, 2, 0), 0.0, 1.0)

        return ret_vis


    def get_logs(self):
        pass


    def inference(self, x):
        keys = ['mask', 'ibl']
        for k in keys:
            assert k in x.keys(), '{} not in input'.format(k)
            assert len(x[k].shape) == 2, '{} should be 2D tensor'.format(k)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        mask = torch.tensor(x['mask'])[None, None, ...].float().to(device)
        ibl = torch.tensor(x['ibl'])[None, None, ...].float().to(device)

        input_x = {'mask': mask, 'ibl': ibl}
        pred = self.forward(input_x)

        pred = np.clip(pred[0, 0].detach().cpu().numpy() / 30.0, 0.0, 1.0)
        return pred


    def batch_inference(self, x):
        # TODO
        pass


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
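For orientation, a sketch of the inference() input contract; the mask/ibl values are made-up placeholders, and opt would be parsed from configs/SSN.yaml:

import numpy as np

mask = np.zeros((256, 256), dtype=np.float32)   # 2D binary object cutout
mask[64:192, 96:160] = 1.0
ibl = np.zeros((32, 16), dtype=np.float32)      # 2D image-based light map (32*16 = 512 samples)
ibl[10, 4] = 1.0                                # a single point light

# ssn = SSN(opt)                                        # opt loaded from configs/SSN.yaml
# shadow = ssn.inference({'mask': mask, 'ibl': ibl})    # 2D numpy array clipped to [0, 1]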
models/SSN_Model.py
ADDED
@@ -0,0 +1,333 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
def weights_init(init_type='gaussian', std=0.02):
|
| 8 |
+
def init_fun(m):
|
| 9 |
+
classname = m.__class__.__name__
|
| 10 |
+
if (classname.find('Conv') == 0 or classname.find(
|
| 11 |
+
'Linear') == 0) and hasattr(m, 'weight'):
|
| 12 |
+
if init_type == 'gaussian':
|
| 13 |
+
nn.init.normal_(m.weight, 0.0, std)
|
| 14 |
+
elif init_type == 'xavier':
|
| 15 |
+
nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
|
| 16 |
+
elif init_type == 'kaiming':
|
| 17 |
+
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
| 18 |
+
elif init_type == 'orthogonal':
|
| 19 |
+
nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
|
| 20 |
+
elif init_type == 'default':
|
| 21 |
+
pass
|
| 22 |
+
else:
|
| 23 |
+
assert 0, "Unsupported initialization: {}".format(init_type)
|
| 24 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 25 |
+
nn.init.constant_(m.bias, 0.0)
|
| 26 |
+
|
| 27 |
+
return init_fun
|
| 28 |
+
|
| 29 |
+
def freeze(module):
|
| 30 |
+
for param in module.parameters():
|
| 31 |
+
param.requires_grad = False
|
| 32 |
+
|
| 33 |
+
def unfreeze(module):
|
| 34 |
+
for param in module.parameters():
|
| 35 |
+
param.requires_grad = True
|
| 36 |
+
|
| 37 |
+
def get_optimizer(opt, model):
|
| 38 |
+
lr = float(opt['hyper_params']['lr'])
|
| 39 |
+
beta1 = float(opt['model']['beta1'])
|
| 40 |
+
weight_decay = float(opt['model']['weight_decay'])
|
| 41 |
+
opt_name = opt['model']['optimizer']
|
| 42 |
+
|
| 43 |
+
optim_params = []
|
| 44 |
+
# weight decay
|
| 45 |
+
for key, value in model.named_parameters():
|
| 46 |
+
if not value.requires_grad:
|
| 47 |
+
continue # frozen weights
|
| 48 |
+
|
| 49 |
+
if key[-4:] == 'bias':
|
| 50 |
+
optim_params += [{'params': value, 'weight_decay': 0.0}]
|
| 51 |
+
else:
|
| 52 |
+
optim_params += [{'params': value,
|
| 53 |
+
'weight_decay': weight_decay}]
|
| 54 |
+
|
| 55 |
+
if opt_name == 'Adam':
|
| 56 |
+
return optim.Adam(optim_params,
|
| 57 |
+
lr=lr,
|
| 58 |
+
betas=(beta1, 0.999),
|
| 59 |
+
eps=1e-5)
|
| 60 |
+
else:
|
| 61 |
+
err = '{} not implemented yet'.format(opt_name)
|
| 62 |
+
logging.error(err)
|
| 63 |
+
raise NotImplementedError(err)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_activation(activation):
|
| 67 |
+
if activation is None:
|
| 68 |
+
return nn.Identity()
|
| 69 |
+
|
| 70 |
+
act_func = {
|
| 71 |
+
'relu':nn.ReLU(),
|
| 72 |
+
'sigmoid':nn.Sigmoid(),
|
| 73 |
+
'tanh':nn.Tanh(),
|
| 74 |
+
'prelu':nn.PReLU(),
|
| 75 |
+
'leaky':nn.LeakyReLU(0.2),
|
| 76 |
+
'gelu':nn.GELU(),
|
| 77 |
+
}
|
| 78 |
+
if activation not in act_func.keys():
|
| 79 |
+
logging.error("activation {} is not implemented yet".format(activation))
|
| 80 |
+
assert False
|
| 81 |
+
|
| 82 |
+
return act_func[activation]
|
| 83 |
+
|
| 84 |
+
def get_norm(out_channels, norm_type='Instance'):
|
| 85 |
+
norm_set = ['Instance', 'Batch', 'Group']
|
| 86 |
+
if norm_type not in norm_set:
|
| 87 |
+
err = "Normalization {} has not been implemented yet"
|
| 88 |
+
logging.error(err)
|
| 89 |
+
raise ValueError(err)
|
| 90 |
+
|
| 91 |
+
if norm_type == 'Instance':
|
| 92 |
+
return nn.InstanceNorm2d(out_channels, affine=True)
|
| 93 |
+
|
| 94 |
+
if norm_type == 'Batch':
|
| 95 |
+
return nn.BatchNorm2d(out_channels)
|
| 96 |
+
|
| 97 |
+
if norm_type == 'Group':
|
| 98 |
+
if out_channels >= 32:
|
| 99 |
+
groups = 32
|
| 100 |
+
else:
|
| 101 |
+
groups = 1
|
| 102 |
+
|
| 103 |
+
return nn.GroupNorm(groups, out_channels)
|
| 104 |
+
|
| 105 |
+
else:
|
| 106 |
+
raise NotImplementedError('{} has not implemented yet'.format(norm_type))
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_layer_info(out_channels, activation_func='relu'):
|
| 111 |
+
activation = get_activation(activation_func)
|
| 112 |
+
norm_layer = get_norm(out_channels, 'Group')
|
| 113 |
+
return norm_layer, activation
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Conv(nn.Module):
|
| 117 |
+
""" (convolution => [BN] => ReLU) """
|
| 118 |
+
def __init__(self,
|
| 119 |
+
in_channels,
|
| 120 |
+
out_channels,
|
| 121 |
+
kernel_size=3,
|
| 122 |
+
stride=1,
|
| 123 |
+
padding=1,
|
| 124 |
+
bias=True,
|
| 125 |
+
activation='leaky',
|
| 126 |
+
resnet=True):
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
norm_layer, act_func = get_layer_info(out_channels,activation)
|
| 130 |
+
|
| 131 |
+
if resnet and in_channels == out_channels:
|
| 132 |
+
self.resnet = True
|
| 133 |
+
else:
|
| 134 |
+
self.resnet = False
|
| 135 |
+
|
| 136 |
+
self.conv = nn.Sequential(
|
| 137 |
+
nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias),
|
| 138 |
+
norm_layer,
|
| 139 |
+
act_func)
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
res = self.conv(x)
|
| 143 |
+
|
| 144 |
+
if self.resnet:
|
| 145 |
+
res = res + x
|
| 146 |
+
|
| 147 |
+
return res
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class Up(nn.Module):
|
| 152 |
+
""" Upscaling then conv """
|
| 153 |
+
|
| 154 |
+
def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
|
| 155 |
+
super().__init__()
|
| 156 |
+
|
| 157 |
+
self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 158 |
+
self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
|
| 159 |
+
|
| 160 |
+
def forward(self, x):
|
| 161 |
+
x = self.up_layer(x)
|
| 162 |
+
return self.up(x)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class DConv(nn.Module):
|
| 167 |
+
""" Double Conv Layer
|
| 168 |
+
"""
|
| 169 |
+
def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
|
| 170 |
+
super().__init__()
|
| 171 |
+
|
| 172 |
+
self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
|
| 173 |
+
self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)
|
| 174 |
+
|
| 175 |
+
def forward(self, x):
|
| 176 |
+
return self.conv2(self.conv1(x))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class Encoder(nn.Module):
|
| 180 |
+
def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
|
| 181 |
+
super(Encoder, self).__init__()
|
| 182 |
+
self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet)
|
| 183 |
+
self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
|
| 184 |
+
self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
|
| 185 |
+
self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
|
| 186 |
+
self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
|
| 187 |
+
self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
|
| 188 |
+
self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
|
| 189 |
+
self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
|
| 190 |
+
self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
|
| 191 |
+
self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
|
| 192 |
+
self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def forward(self, x):
|
| 196 |
+
x1 = self.in_conv(x) # 32 x 256 x 256
|
| 197 |
+
x1 = torch.cat((x, x1), dim=1)
|
| 198 |
+
|
| 199 |
+
x2 = self.down_32_64(x1)
|
| 200 |
+
x3 = self.down_64_64_1(x2)
|
| 201 |
+
|
| 202 |
+
x4 = self.down_64_128(x3)
|
| 203 |
+
x5 = self.down_128_128_1(x4)
|
| 204 |
+
|
| 205 |
+
x6 = self.down_128_256(x5)
|
| 206 |
+
x7 = self.down_256_256_1(x6)
|
| 207 |
+
|
| 208 |
+
x8 = self.down_256_512(x7)
|
| 209 |
+
x9 = self.down_512_512_1(x8)
|
| 210 |
+
x10 = self.down_512_512_2(x9)
|
| 211 |
+
x11 = self.down_512_512_3(x10)
|
| 212 |
+
|
| 213 |
+
        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Decoder(nn.Module):
    """ Up Stream Sequence """

    def __init__(self,
                 out_channels=3,
                 mid_act='relu',
                 out_act='sigmoid',
                 resnet=True):
        super(Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=out_act)

    def forward(self, x, ibl):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        h, w = x10.shape[2:]
        y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w)

        y = self.up_16_16_1(y)          # 256 x 16 x 16

        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
        y = self.up_16_16_2(y)          # 512 x 16 x 16

        y = torch.cat((x9, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_16_3(y)          # 512 x 16 x 16

        y = torch.cat((x8, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_32(y)            # 256 x 32 x 32

        y = torch.cat((x7, y), dim=1)
        y = self.up_32_32_1(y)          # 256 x 32 x 32

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y)

        y = torch.cat((x5, y), dim=1)
        y = self.up_64_64_1(y)          # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y)

        y = torch.cat((x3, y), dim=1)
        y = self.up_128_128_1(y)        # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y)          # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y)            # out_channels x 256 x 256

        return y


class SSN_Model(nn.Module):
    """ Implementation of Relighting Net """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_act='leaky',
                 out_act='sigmoid',
                 resnet=True):
        super(SSN_Model, self).__init__()

        self.out_act = out_act

        self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
        self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)

        # init weights
        init_func = weights_init('gaussian', std=1e-3)
        self.encoder.apply(init_func)
        self.decoder.apply(init_func)

    def forward(self, x, ibl):
        """
        Input: source image x and an ibl map carrying 512 values per sample
        Output: the predicted shadow map
        """
        latent = self.encoder(x)
        pred = self.decoder(latent, ibl)

        if self.out_act == 'sigmoid':
            pred = pred * 30.0

        return pred


if __name__ == '__main__':
    x = torch.randn(5, 1, 256, 256)
    ibl = torch.randn(5, 1, 32, 16)
    model = SSN_Model(1, 1)

    y = model(x, ibl)

    print('Output: ', y.shape)
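The ibl argument above is only constrained by the `ibl.view(-1, 512, 1, 1)` call in `Decoder.forward`: any tensor carrying 512 values per sample fits, which is why the 32 x 16 map in the smoke test works. A minimal sketch of that reshape-and-tile step in isolation; the batch size is an illustrative assumption:

import torch

batch = 4
ibl_map = torch.rand(batch, 1, 32, 16)   # 32 * 16 = 512 values per sample
cond = ibl_map.view(-1, 512, 1, 1)       # what Decoder.forward does internally
tiled = cond.repeat(1, 1, 16, 16)        # broadcast over the 16 x 16 bottleneck
print(cond.shape, tiled.shape)           # [4, 512, 1, 1], [4, 512, 16, 16]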
models/SSN_v1.py
ADDED
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np


def get_activation(activation_func):
    act_func = {
        "relu": nn.ReLU(),
        "sigmoid": nn.Sigmoid(),
        "prelu": nn.PReLU(num_parameters=1),
        "leaky": nn.LeakyReLU(negative_slope=0.2, inplace=False),  # alias: Conv defaults to 'leaky'
        "leaky_relu": nn.LeakyReLU(negative_slope=0.2, inplace=False),
        "gelu": nn.GELU()
    }

    if activation_func is None:
        return nn.Identity()

    if activation_func not in act_func.keys():
        raise ValueError("activation function({}) is not found".format(activation_func))

    activation = act_func[activation_func]
    return activation


def get_layer_info(out_channels, activation_func='relu'):
    # GroupNorm: 32 groups when the channel count allows it, otherwise 1
    if out_channels >= 32:
        groups = 32
    else:
        groups = 1

    norm_layer = nn.GroupNorm(groups, out_channels)
    activation = get_activation(activation_func)
    return norm_layer, activation


class Conv(nn.Module):
    """ (convolution => [GN] => activation), optionally style-modulated or residual """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation='leaky',
                 style=False,
                 resnet=True):
        super().__init__()

        self.style = style
        norm_layer, act_func = get_layer_info(in_channels, activation)

        if resnet and in_channels == out_channels:
            self.resnet = True
        else:
            self.resnet = False

        if style:
            self.styleconv = Conv2DMod(in_channels, out_channels, kernel_size)
            self.relu = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.norm = norm_layer
            self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias)
            self.act = act_func

    def forward(self, x, style_fea=None):
        if self.style:
            res = self.styleconv(x, style_fea)
            res = self.relu(res)
        else:
            h = self.conv(self.act(self.norm(x)))
            if self.resnet:
                res = h + x
            else:
                res = h

        return res


class Conv2DMod(nn.Module):
    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        self.eps = eps
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        b, c, h, w = x.shape

        # modulate the shared kernel with a per-sample style vector y
        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        # demodulate each output filter back toward unit norm
        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * d

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x


class Up(nn.Module):
    """ Upscaling then conv """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        x = self.up_layer(x)
        return self.up(x)


class DConv(nn.Module):
    """ Double Conv Layer """
    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()

        self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
        self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        return self.conv2(self.conv1(x))


class Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
        super(Encoder, self).__init__()
        self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)

    def forward(self, x):
        x1 = self.in_conv(x)
        x1 = torch.cat((x, x1), dim=1)  # 32 x 256 x 256

        x2 = self.down_32_64(x1)
        x3 = self.down_64_64_1(x2)

        x4 = self.down_64_128(x3)
        x5 = self.down_128_128_1(x4)

        x6 = self.down_128_256(x5)
        x7 = self.down_256_256_1(x6)

        x8 = self.down_256_512(x7)
        x9 = self.down_512_512_1(x8)
        x10 = self.down_512_512_2(x9)
        x11 = self.down_512_512_3(x10)

        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Decoder(nn.Module):
    def __init__(self,
                 out_channels=3,
                 mid_act='relu',
                 out_act='sigmoid',
                 resnet=True):
        super(Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=mid_act)

        self.out_act = get_activation(out_act)

    def forward(self, x):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        y = self.up_16_16_1(x11)

        y = torch.cat((x10, y), dim=1)
        y = self.up_16_16_2(y)

        y = torch.cat((x9, y), dim=1)
        y = self.up_16_16_3(y)

        y = torch.cat((x8, y), dim=1)
        y = self.up_16_32(y)

        y = torch.cat((x7, y), dim=1)
        y = self.up_32_32_1(y)

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y)

        y = torch.cat((x5, y), dim=1)
        y = self.up_64_64_1(y)    # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y)

        y = torch.cat((x3, y), dim=1)
        y = self.up_128_128_1(y)  # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y)    # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y)      # out_channels x 256 x 256
        y = self.out_act(y)

        return y


class SSN_v1(nn.Module):
    """ Implementation of Relighting Net """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_act='leaky',
                 out_act='sigmoid',
                 resnet=True):
        super(SSN_v1, self).__init__()
        self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
        self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)

    def forward(self, x, softness=None):
        """
        Predict an output map for input x; `softness` is accepted for
        interface compatibility but is not used by this variant.
        """
        latent = self.encoder(x)
        pred = self.decoder(latent)

        return pred


if __name__ == '__main__':
    test_input = torch.randn(5, 1, 256, 256)
    style = torch.randn(5, 100)

    model = SSN_v1(1, 1, mid_act='gelu', out_act='gelu', resnet=True)
    test_out = model(test_input, style)

    print('Output shape: ', test_out.shape)
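Conv2DMod above is a StyleGAN2-style modulated convolution: a per-sample style vector rescales the shared kernel, the kernel is demodulated back toward unit norm, and the whole batch is run as one grouped convolution. A minimal shape check, assuming the repository root is on PYTHONPATH; the sizes are illustrative:

import torch
from models.SSN_v1 import Conv2DMod

mod = Conv2DMod(in_chan=64, out_chan=128, kernel=3)
x = torch.randn(2, 64, 32, 32)   # batch of feature maps
style = torch.randn(2, 64)       # one modulation vector per sample, sized in_chan
y = mod(x, style)
print(y.shape)                   # torch.Size([2, 128, 32, 32])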
models/Sparse_PH.py
ADDED
@@ -0,0 +1,185 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from torchvision.transforms import Resize
from collections import OrderedDict
import numpy as np
import matplotlib.cm as cm
import matplotlib as mpl
from torchvision.transforms import InterpolationMode


from .abs_model import abs_model
from .blocks import *
from .SSN import SSN
from .SSN_v1 import SSN_v1
from .Loss.Loss import norm_loss, grad_loss
from .Attention_Unet import Attention_Unet

class Sparse_PH(abs_model):
    def __init__(self, opt):
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        resnet = opt['model']['resnet']
        backbone = opt['model']['backbone']

        self.ncols = opt['hyper_params']['n_cols']
        self.focal = opt['model']['focal']
        self.clip = opt['model']['clip']

        self.norm_loss_ = opt['model']['norm_loss']
        self.grad_loss_ = opt['model']['grad_loss']
        self.ggrad_loss_ = opt['model']['ggrad_loss']
        self.lap_loss = opt['model']['lap_loss']

        self.clip_range = opt['dataset']['linear_scale'] + opt['dataset']['linear_offset']

        if backbone == 'Default':
            self.model = SSN_v1(in_channels=in_channels,
                                out_channels=out_channels,
                                mid_act=mid_act,
                                out_act=out_act,
                                resnet=resnet)
        elif backbone == 'ATTN':
            self.model = Attention_Unet(in_channels, out_channels, mid_act=mid_act, out_act=out_act)

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

        self.norm_loss = norm_loss()
        self.grad_loss = grad_loss()

    def setup_input(self, x):
        return x

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, y, pred):
        b = y.shape[0]

        # total_loss = avg_norm_loss(y, pred)
        nloss = self.norm_loss.loss(y, pred) * self.norm_loss_
        gloss = self.grad_loss.loss(pred) * self.grad_loss_
        ggloss = self.grad_loss.gloss(y, pred) * self.ggrad_loss_
        laploss = self.grad_loss.laploss(pred) * self.lap_loss

        total_loss = nloss + gloss + ggloss + laploss

        self.loss_log = {
            'norm_loss': nloss.item(),
            'grad_loss': gloss.item(),
            'grad_l1_loss': ggloss.item(),
            'lap_loss': laploss.item(),
        }

        if self.focal:
            total_loss = torch.pow(total_loss, 3)

        return total_loss

    def supervise(self, input_x, y, is_training: bool) -> float:
        optimizer = self.optimizer
        model = self.model

        x = input_x['x']

        optimizer.zero_grad()
        pred = self.forward(x)
        if self.clip:
            pred = torch.clip(pred, 0.0, self.clip_range)

        loss = self.compute_loss(y, pred)
        if is_training:
            loss.backward()
            optimizer.step()

        xc = x.shape[1]
        for i in range(xc):
            self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()

        self.visualization['y_fore'] = y[:, 0:1].detach()
        self.visualization['y_back'] = y[:, 1:2].detach()
        self.visualization['pred_fore'] = pred[:, 0:1].detach()
        self.visualization['pred_back'] = pred[:, 1:2].detach()

        return loss.item()

    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1, 2, 0), 0.0, 1.0)
            ret_vis[k] = self.plasma(ret_vis[k])

        return ret_vis

    def get_logs(self):
        return self.loss_log

    def inference(self, x):
        x, device = x['x'], x['device']
        x = torch.from_numpy(x.transpose((2, 0, 1))).unsqueeze(dim=0).float().to(device)
        pred = self.forward(x)

        pred = pred[0].detach().cpu().numpy().transpose((1, 2, 0))

        return pred

    def batch_inference(self, x):
        x = x['x']
        pred = self.forward(x)
        return pred

    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}

    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}

    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']

    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
    def plasma(self, x):
        norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
        mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
        bimg = mapper.to_rgba(x[:, :, 0])[:, :, :3]

        return bimg
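Sparse_PH pulls its entire configuration from a nested option dict. The sketch below covers every key the constructor and `get_optimizer` read; the concrete values are illustrative assumptions, not the shipped configs/*.yaml:

opt = {
    'model': {
        'name': 'Sparse_PH',
        'mid_act': 'leaky', 'out_act': 'sigmoid',
        'in_channels': 1, 'out_channels': 2,
        'resnet': True, 'backbone': 'Default',
        'focal': False, 'clip': True,
        'norm_loss': 1.0, 'grad_loss': 1.0, 'ggrad_loss': 1.0, 'lap_loss': 1.0,
        'optimizer': 'Adam', 'beta1': 0.9, 'weight_decay': 4e-5,
    },
    'hyper_params': {'n_cols': 4, 'lr': 1e-4},
    'dataset': {'linear_scale': 1.0, 'linear_offset': 0.0},
}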
models/__init__.py
ADDED
@@ -0,0 +1,43 @@
# SRC: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/__init__.py
import logging
import importlib

from .abs_model import abs_model


def find_model_using_name(model_name):
    """Import the module "models/[model_name].py".
    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name
    modellib = importlib.import_module(model_filename)
    model = None

    target_model_name = model_name
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, abs_model):
            model = cls

    if model is None:
        err = "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)
        logging.error(err)
        exit(0)

    return model


def create_model(opt):
    """Create a model given the option.

    This function wraps the model class found by `find_model_using_name`.
    It is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt['model']['name'])
    instance = model(opt)
    logging.info("model [%s] was created" % type(instance).__name__)
    return instance
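Given an option dict shaped like the Sparse_PH sketch above, construction reduces to this reflection lookup: `find_model_using_name` imports models/Sparse_PH.py and returns the `abs_model` subclass whose lowercased class name matches the configured name:

from models import create_model

model = create_model(opt)    # opt as sketched after Sparse_PH above
print(type(model).__name__)  # 'Sparse_PH'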
models/__pycache__/SSN.cpython-39.pyc
ADDED
Binary file (4.11 kB)

models/__pycache__/SSN_Model.cpython-39.pyc
ADDED
Binary file (8.96 kB)

models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.42 kB)

models/__pycache__/abs_model.cpython-39.pyc
ADDED
Binary file (2.14 kB)

models/__pycache__/blocks.cpython-39.pyc
ADDED
Binary file (6.92 kB)
models/abs_model.py
ADDED
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from collections import OrderedDict

class abs_model(ABC):
    """ Training Related Interface
    """
    @abstractmethod
    def setup_input(self, x):
        pass

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def supervise(self, input_x, y, is_training: bool) -> float:
        pass

    @abstractmethod
    def get_visualize(self) -> OrderedDict:
        return {}

    """ Inference Related Interface
    """
    @abstractmethod
    def inference(self, x):
        pass

    @abstractmethod
    def batch_inference(self, x):
        pass

    """ Logging/Visualization Related Interface
    """
    @abstractmethod
    def get_logs(self):
        pass

    """ Getter & Setter
    """
    @abstractmethod
    def get_models(self) -> dict:
        """ GAN may have two models
        """
        pass

    @abstractmethod
    def get_optimizers(self) -> dict:
        """ GAN may have two optimizers
        """
        pass

    @abstractmethod
    def set_models(self, models) -> dict:
        """ GAN may have two models
        """
        pass

    @abstractmethod
    def set_optimizers(self, optimizers: dict):
        """ GAN may have two optimizers
        """
        pass
models/attention.py
ADDED
@@ -0,0 +1,85 @@
from inspect import isfunction
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F

from .blocks import get_norm, zero_module


def QKV_Attention(qkv, num_heads):
    """
    Apply QKV attention.
    :param qkv: an [N x (3 * C) x T] tensor of Qs, Ks, and Vs.
    :return: an [N x C x T] tensor after attention.
    """
    B, C, HW = qkv.shape
    if C % 3 != 0:
        raise ValueError('QKV shape is wrong: {}, {}, {}'.format(B, C, HW))

    split_size = C // (3 * num_heads)
    q, k, v = qkv.chunk(3, dim=1)
    scale = 1.0 / math.sqrt(math.sqrt(split_size))
    weight = torch.einsum('bct, bcs->bts',
                          (q * scale).view(B * num_heads, split_size, HW),
                          (k * scale).view(B * num_heads, split_size, HW))

    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    ret = torch.einsum("bts,bcs->bct", weight, v.reshape(B * num_heads, split_size, HW))

    return ret.reshape(B, -1, HW)


class AttentionBlock(nn.Module):
    """
    https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
    https://github.com/whai362/PVT/blob/a24ba02c249a510581a84f821c26322534b03a10/detection/pvt_v2.py#L57
    """

    def __init__(self, in_channels, num_heads, qkv_bias=False, sr_ratio=1, linear=True):
        super().__init__()

        self.num_heads = num_heads
        self.norm = get_norm(in_channels, 'Group')
        self.qkv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels * 3, kernel_size=1)

        self.proj = zero_module(nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=1))

    def forward(self, x):
        b, c, *spatial = x.shape
        num_heads = self.num_heads

        x = x.reshape(b, c, -1)  # B x C x HW
        x = self.norm(x)
        qkv = self.qkv(x)        # B x C x HW -> B x 3C x HW
        h = QKV_Attention(qkv, num_heads)
        h = self.proj(h)

        return (x + h).reshape(b, c, *spatial)  # additive attention, similar to ResNet


def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # return param_size + buffer_size
    return size_all_mb


if __name__ == '__main__':
    model = AttentionBlock(in_channels=256, num_heads=8)

    x = torch.randn(5, 256, 32, 32, dtype=torch.float32)
    y = model(x)
    print('{}, {}'.format(x.shape, y.shape))

    get_model_size(model)
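QKV_Attention is plain multi-head scaled dot-product attention over flattened spatial positions; note the scale 1/sqrt(sqrt(d)) is applied to both q and k, which is numerically equivalent to the usual 1/sqrt(d) on their product. A shape walk-through under assumed sizes (C=256, 8 heads, a 16 x 16 feature map):

import torch
from models.attention import QKV_Attention

qkv = torch.randn(2, 3 * 256, 16 * 16)  # [B, 3C, HW], as produced by AttentionBlock.qkv
out = QKV_Attention(qkv, num_heads=8)   # head dim = 256 / 8 = 32
print(out.shape)                        # torch.Size([2, 256, 256])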
models/blocks.py
ADDED
@@ -0,0 +1,238 @@
from enum import Enum
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging


def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # return param_size + buffer_size
    return size_all_mb


def weights_init(init_type='gaussian'):
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find(
                'Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun


def freeze(module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze(module):
    for param in module.parameters():
        param.requires_grad = True


def get_optimizer(opt, model):
    lr = float(opt['hyper_params']['lr'])
    beta1 = float(opt['model']['beta1'])
    weight_decay = float(opt['model']['weight_decay'])
    opt_name = opt['model']['optimizer']

    optim_params = []
    # no weight decay for bias terms
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # frozen weights

        if key[-4:] == 'bias':
            optim_params += [{'params': value, 'weight_decay': 0.0}]
        else:
            optim_params += [{'params': value,
                              'weight_decay': weight_decay}]

    if opt_name == 'Adam':
        return optim.Adam(optim_params,
                          lr=lr,
                          betas=(beta1, 0.999),
                          eps=1e-5)
    else:
        err = '{} not implemented yet'.format(opt_name)
        logging.error(err)
        raise NotImplementedError(err)


def get_activation(activation):
    act_func = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'prelu': nn.PReLU(),
        'leaky': nn.LeakyReLU(0.2),       # alias: ResBlock defaults to 'leaky'
        'leaky_relu': nn.LeakyReLU(0.2),
        'gelu': nn.GELU(),
    }
    if activation not in act_func.keys():
        logging.error("activation {} is not implemented yet".format(activation))
        assert False

    return act_func[activation]


def get_norm(out_channels, norm_type='Group', groups=32):
    norm_set = ['Instance', 'Batch', 'Group']
    if norm_type not in norm_set:
        err = "Normalization {} has not been implemented yet".format(norm_type)
        logging.error(err)
        raise ValueError(err)

    if norm_type == 'Instance':
        return nn.InstanceNorm2d(out_channels, affine=True)

    if norm_type == 'Batch':
        return nn.BatchNorm2d(out_channels)

    if norm_type == 'Group':
        if out_channels >= 32:
            groups = 32
        else:
            groups = max(out_channels // 2, 1)

        return nn.GroupNorm(groups, out_channels)
    else:
        raise NotImplementedError


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, norm_type='Batch', activation='relu'):
        super().__init__()

        act_func = get_activation(activation)
        norm_layer = get_norm(out_channels, norm_type)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True, padding_mode='reflect'),
            norm_layer,
            act_func)

    def forward(self, x):
        return self.conv(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class Up(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear')


class Down(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv

        if self.use_conv:
            self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.op(x)


class Res_Type(Enum):
    UP = 1
    DOWN = 2
    SAME = 3


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout=0.0, updown=Res_Type.DOWN, mid_act='leaky'):
        """ ResBlock to cover several cases:
        1. Up/Down/Same
        2. in_channels != out_channels
        """
        super().__init__()

        self.updown = updown

        # the input still has in_channels channels at this point
        self.in_norm = get_norm(in_channels, 'Group')
        self.in_act = get_activation(mid_act)
        self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # up/down-sample both the residual branch and the skip branch
        if self.updown == Res_Type.DOWN:
            self.h_updown = Down(in_channels, use_conv=True)
            self.x_updown = Down(in_channels, use_conv=True)
        elif self.updown == Res_Type.UP:
            self.h_updown = Up()
            self.x_updown = Up()
        else:
            self.h_updown = nn.Identity()
            self.x_updown = nn.Identity()

        # project the skip path when the channel count changes
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

        self.out_layer = nn.Sequential(
            get_norm(out_channels, 'Group'),
            get_activation(mid_act),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True))
        )

    def forward(self, x):
        # in layer
        h = self.in_act(self.in_norm(x))
        h = self.in_conv(self.h_updown(h))
        x = self.skip(self.x_updown(x))

        # out layer
        h = self.out_layer(h)
        return x + h


if __name__ == '__main__':
    x = torch.randn(5, 3, 256, 256)
    up = Up()
    conv_down = Down(3, True)
    pool_down = Down(3, False)

    print('Up: {}'.format(up(x).shape))
    print('Conv down: {}'.format(conv_down(x).shape))
    print('Pool down: {}'.format(pool_down(x).shape))

    up_model = ResBlock(3, 6, updown=Res_Type.UP)
    down_model = ResBlock(3, 6, updown=Res_Type.DOWN)

    print('model up: {}'.format(up_model(x).shape))
    print('model down: {}'.format(down_model(x).shape))
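zero_module zero-initializes the final convolution of each ResBlock, so at initialization every block reduces to its (resampled, possibly 1x1-projected) skip path and only gradually learns a residual. A quick check, assuming the repository root is importable:

import torch
import torch.nn as nn
from models.blocks import zero_module, ResBlock, Res_Type

conv = zero_module(nn.Conv2d(8, 8, 3, padding=1))
print(all((p == 0).all() for p in conv.parameters()))  # True

block = ResBlock(8, 8, updown=Res_Type.SAME)
x = torch.randn(2, 8, 32, 32)
print(torch.allclose(block(x), x))                     # True: the block starts as identity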
models/pvt_attention.py
ADDED
@@ -0,0 +1,240 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        if self.linear:
            x = self.relu(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear:
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        assert max(patch_size) > stride, "Set larger patch_size than stride"

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // stride, img_size[1] // stride
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W

if __name__ == '__main__':
    test = torch.randn(5, 3, 224, 224)

    embed_dim = 768
    patch_embed = OverlapPatchEmbed(embed_dim=embed_dim)
    block = Block(embed_dim, 1)

    print('x: {}'.format(test.shape))
    pe, H, W = patch_embed(test)
    print('After patch: {}'.format(pe.shape))
    y = block(pe, H, W)
    print('After block: {}'.format(y.shape))
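The non-linear branch of Attention above is PVT's spatial-reduction attention: with sr_ratio=R the keys and values are computed on an RxR-strided convolution of the token map, shrinking the attention matrix from N x N to N x N/R^2. A small sketch, assuming timm is installed (the module imports it); the sizes are illustrative:

import torch
from models.pvt_attention import Attention

attn = Attention(dim=64, num_heads=4, sr_ratio=2)
x = torch.randn(2, 32 * 32, 64)  # [B, N, C] tokens from a 32 x 32 map
y = attn(x, H=32, W=32)
print(y.shape)                   # torch.Size([2, 1024, 64])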
models/template.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict

from .abs_model import abs_model
from .blocks import *
from .Loss.Loss import avg_norm_loss

class Template(abs_model):
    """ Standard Unet Implementation
    src: https://arxiv.org/pdf/1505.04597.pdf
    """
    def __init__(self, opt):
        resunet = opt['model']['resunet']
        out_act = opt['model']['out_act']
        norm_type = opt['model']['norm_type']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        self.ncols = opt['hyper_params']['n_cols']

        self.model = Unet(in_channels=in_channels,
                          out_channels=out_channels,
                          norm_type=norm_type,
                          out_act=out_act,
                          resunet=resunet)

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

    def setup_input(self, x):
        return x

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, y, pred):
        return avg_norm_loss(y, pred)

    def supervise(self, input_x, y, is_training: bool) -> float:
        optimizer = self.optimizer
        model = self.model

        optimizer.zero_grad()
        pred = model(input_x)
        loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        self.visualization['y'] = y.detach()      # ground-truth target
        self.visualization['pred'] = pred.detach()

        return loss.item()

    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            ret_vis[k] = utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1, 2, 0)

        return ret_vis

    def inference(self, x):
        # TODO
        pass

    def batch_inference(self, x):
        # TODO
        pass

    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}

    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}

    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']

    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
weights/SSN/0000001760.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44328317fa836804554ae453fe1492a45cff724b5c13b5070211d6d860096089
size 283511041