Johannes Kolbe
commited on
Commit
·
ed6b6d6
1
Parent(s):
dd2f594
enable model loading from hf hub
Browse files- .gitignore +1 -0
- .ipynb_checkpoints/model_to_hf_hub-checkpoint.ipynb +255 -0
- app.py +2 -2
- interface.py +4 -5
- model_to_hf_hub.ipynb +297 -0
- models/model_zoo.py +5 -8
- models/pggan_generator.py +48 -2
- models/stylegan2_generator.py +44 -2
- models/stylegan_generator.py +49 -2
- utils.py +19 -13
.gitignore
CHANGED
|
@@ -20,6 +20,7 @@ __pycache__/
|
|
| 20 |
*.zip
|
| 21 |
events.*
|
| 22 |
|
|
|
|
| 23 |
*.pkl
|
| 24 |
*.h5
|
| 25 |
*.dat
|
|
|
|
| 20 |
*.zip
|
| 21 |
events.*
|
| 22 |
|
| 23 |
+
/checkpoints/
|
| 24 |
*.pkl
|
| 25 |
*.h5
|
| 26 |
*.dat
|
.ipynb_checkpoints/model_to_hf_hub-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 15,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"pycharm": {
|
| 8 |
+
"name": "#%%\n"
|
| 9 |
+
}
|
| 10 |
+
},
|
| 11 |
+
"outputs": [],
|
| 12 |
+
"source": [
|
| 13 |
+
"import huggingface_hub\n",
|
| 14 |
+
"import utils"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 16,
|
| 20 |
+
"metadata": {
|
| 21 |
+
"pycharm": {
|
| 22 |
+
"name": "#%%\n"
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
"outputs": [
|
| 26 |
+
{
|
| 27 |
+
"data": {
|
| 28 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 29 |
+
"model_id": "525a0eaa021f4fdebd9138f4e7c5ab65",
|
| 30 |
+
"version_major": 2,
|
| 31 |
+
"version_minor": 0
|
| 32 |
+
},
|
| 33 |
+
"text/plain": [
|
| 34 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"output_type": "display_data"
|
| 39 |
+
}
|
| 40 |
+
],
|
| 41 |
+
"source": [
|
| 42 |
+
"huggingface_hub.notebook_login()"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 13,
|
| 48 |
+
"metadata": {
|
| 49 |
+
"pycharm": {
|
| 50 |
+
"name": "#%%\n"
|
| 51 |
+
}
|
| 52 |
+
},
|
| 53 |
+
"outputs": [
|
| 54 |
+
{
|
| 55 |
+
"name": "stdout",
|
| 56 |
+
"output_type": "stream",
|
| 57 |
+
"text": [
|
| 58 |
+
"Building generator for model `stylegan_animeface512` ...\n",
|
| 59 |
+
"Finish building generator.\n",
|
| 60 |
+
"Loading checkpoint from `checkpoints/stylegan_animeface512.pth` ...\n",
|
| 61 |
+
"Finish loading checkpoint.\n"
|
| 62 |
+
]
|
| 63 |
+
}
|
| 64 |
+
],
|
| 65 |
+
"source": [
|
| 66 |
+
"animeface_model = utils.load_generator('stylegan_animeface512')"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": 5,
|
| 72 |
+
"metadata": {
|
| 73 |
+
"pycharm": {
|
| 74 |
+
"name": "#%%\n"
|
| 75 |
+
}
|
| 76 |
+
},
|
| 77 |
+
"outputs": [
|
| 78 |
+
{
|
| 79 |
+
"name": "stderr",
|
| 80 |
+
"output_type": "stream",
|
| 81 |
+
"text": [
|
| 82 |
+
"Cloning https://huggingface.co/johko/stylegan_animeface512 into local empty directory.\n"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"data": {
|
| 87 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 88 |
+
"model_id": "6e51c5ae4a504617aa0f1c1ac798ed15",
|
| 89 |
+
"version_major": 2,
|
| 90 |
+
"version_minor": 0
|
| 91 |
+
},
|
| 92 |
+
"text/plain": [
|
| 93 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/103M [00:00<?, ?B/s]"
|
| 94 |
+
]
|
| 95 |
+
},
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"output_type": "display_data"
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"name": "stderr",
|
| 101 |
+
"output_type": "stream",
|
| 102 |
+
"text": [
|
| 103 |
+
"To https://huggingface.co/johko/stylegan_animeface512\n",
|
| 104 |
+
" 750cd03..2841156 main -> main\n",
|
| 105 |
+
"\n"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"data": {
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"'https://huggingface.co/johko/stylegan_animeface512/commit/2841156bad3c5a5f47f3edbf4a41880ea8fd3ad3'"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"execution_count": 5,
|
| 115 |
+
"metadata": {},
|
| 116 |
+
"output_type": "execute_result"
|
| 117 |
+
}
|
| 118 |
+
],
|
| 119 |
+
"source": [
|
| 120 |
+
"animeface_model.push_to_hub(\"johko/stylegan_animeface512\")"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "code",
|
| 125 |
+
"execution_count": 11,
|
| 126 |
+
"metadata": {
|
| 127 |
+
"pycharm": {
|
| 128 |
+
"name": "#%%\n"
|
| 129 |
+
}
|
| 130 |
+
},
|
| 131 |
+
"outputs": [
|
| 132 |
+
{
|
| 133 |
+
"name": "stdout",
|
| 134 |
+
"output_type": "stream",
|
| 135 |
+
"text": [
|
| 136 |
+
"Building generator for model `pggan_celebahq1024` ...\n",
|
| 137 |
+
"Finish building generator.\n",
|
| 138 |
+
"Loading checkpoint from `checkpoints/pggan_celebahq1024.pth` ...\n",
|
| 139 |
+
"Finish loading checkpoint.\n"
|
| 140 |
+
]
|
| 141 |
+
}
|
| 142 |
+
],
|
| 143 |
+
"source": [
|
| 144 |
+
"celebhq_model = utils.load_generator(\"pggan_celebahq1024\")"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "code",
|
| 149 |
+
"execution_count": 7,
|
| 150 |
+
"metadata": {
|
| 151 |
+
"pycharm": {
|
| 152 |
+
"name": "#%%\n"
|
| 153 |
+
}
|
| 154 |
+
},
|
| 155 |
+
"outputs": [
|
| 156 |
+
{
|
| 157 |
+
"name": "stderr",
|
| 158 |
+
"output_type": "stream",
|
| 159 |
+
"text": [
|
| 160 |
+
"Cloning https://huggingface.co/johko/pggan-celebahq-1024 into local empty directory.\n"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"data": {
|
| 165 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 166 |
+
"model_id": "ef4086b23a654b079bd6a3678140c50d",
|
| 167 |
+
"version_major": 2,
|
| 168 |
+
"version_minor": 0
|
| 169 |
+
},
|
| 170 |
+
"text/plain": [
|
| 171 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/88.1M [00:00<?, ?B/s]"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"output_type": "display_data"
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"name": "stderr",
|
| 179 |
+
"output_type": "stream",
|
| 180 |
+
"text": [
|
| 181 |
+
"To https://huggingface.co/johko/pggan-celebahq-1024\n",
|
| 182 |
+
" 780695e..278449f main -> main\n",
|
| 183 |
+
"\n"
|
| 184 |
+
]
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"data": {
|
| 188 |
+
"text/plain": [
|
| 189 |
+
"'https://huggingface.co/johko/pggan-celebahq-1024/commit/278449f8416d38a0233c980774528d32c4eee99c'"
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
"execution_count": 7,
|
| 193 |
+
"metadata": {},
|
| 194 |
+
"output_type": "execute_result"
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"source": [
|
| 198 |
+
"celebhq_model.push_to_hub(\"johko/pggan-celebahq-1024\")"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": 17,
|
| 204 |
+
"metadata": {},
|
| 205 |
+
"outputs": [
|
| 206 |
+
{
|
| 207 |
+
"name": "stdout",
|
| 208 |
+
"output_type": "stream",
|
| 209 |
+
"text": [
|
| 210 |
+
"Building generator for model `stylegan_car512` ...\n",
|
| 211 |
+
"Finish building generator.\n",
|
| 212 |
+
"Loading checkpoint from `checkpoints/stylegan_car512.pth` ...\n",
|
| 213 |
+
"Finish loading checkpoint.\n"
|
| 214 |
+
]
|
| 215 |
+
}
|
| 216 |
+
],
|
| 217 |
+
"source": [
|
| 218 |
+
"cars_model = utils.load_generator(\"stylegan_car512\")"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "code",
|
| 223 |
+
"execution_count": null,
|
| 224 |
+
"metadata": {},
|
| 225 |
+
"outputs": [],
|
| 226 |
+
"source": [
|
| 227 |
+
"cars_model.push_to_hub(\"johko/stylegan_car512\")"
|
| 228 |
+
]
|
| 229 |
+
}
|
| 230 |
+
],
|
| 231 |
+
"metadata": {
|
| 232 |
+
"interpreter": {
|
| 233 |
+
"hash": "a8d699d01f596cc27ac2722fbc0550b939d217978c7e1ca888dca7ba146ee4bf"
|
| 234 |
+
},
|
| 235 |
+
"kernelspec": {
|
| 236 |
+
"display_name": "Python 3",
|
| 237 |
+
"language": "python",
|
| 238 |
+
"name": "python3"
|
| 239 |
+
},
|
| 240 |
+
"language_info": {
|
| 241 |
+
"codemirror_mode": {
|
| 242 |
+
"name": "ipython",
|
| 243 |
+
"version": 3
|
| 244 |
+
},
|
| 245 |
+
"file_extension": ".py",
|
| 246 |
+
"mimetype": "text/x-python",
|
| 247 |
+
"name": "python",
|
| 248 |
+
"nbconvert_exporter": "python",
|
| 249 |
+
"pygments_lexer": "ipython3",
|
| 250 |
+
"version": "3.9.9"
|
| 251 |
+
}
|
| 252 |
+
},
|
| 253 |
+
"nbformat": 4,
|
| 254 |
+
"nbformat_minor": 2
|
| 255 |
+
}
|
app.py
CHANGED
|
@@ -16,7 +16,7 @@ from utils import factorize_weight
|
|
| 16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
| 17 |
def get_model(model_name):
|
| 18 |
"""Gets model by name."""
|
| 19 |
-
return load_generator(model_name)
|
| 20 |
|
| 21 |
|
| 22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
|
@@ -72,7 +72,7 @@ layer_idx = st.sidebar.selectbox(
|
|
| 72 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 73 |
|
| 74 |
num_semantics = st.sidebar.number_input(
|
| 75 |
-
'Number of semantics', value=
|
| 76 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 77 |
if gan_type == 'pggan':
|
| 78 |
max_step = 5.0
|
|
|
|
| 16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
| 17 |
def get_model(model_name):
|
| 18 |
"""Gets model by name."""
|
| 19 |
+
return load_generator(model_name, from_hf_hub=True)
|
| 20 |
|
| 21 |
|
| 22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
|
|
|
| 72 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 73 |
|
| 74 |
num_semantics = st.sidebar.number_input(
|
| 75 |
+
'Number of semantics', value=5, min_value=0, max_value=None, step=1)
|
| 76 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 77 |
if gan_type == 'pggan':
|
| 78 |
max_step = 5.0
|
interface.py
CHANGED
|
@@ -16,7 +16,7 @@ from utils import factorize_weight
|
|
| 16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
| 17 |
def get_model(model_name):
|
| 18 |
"""Gets model by name."""
|
| 19 |
-
return load_generator(model_name)
|
| 20 |
|
| 21 |
|
| 22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
|
@@ -27,7 +27,7 @@ def factorize_model(model, layer_idx):
|
|
| 27 |
|
| 28 |
def sample(model, gan_type, num=1):
|
| 29 |
"""Samples latent codes."""
|
| 30 |
-
codes = torch.randn(num, model.z_space_dim)
|
| 31 |
if gan_type == 'pggan':
|
| 32 |
codes = model.layer0.pixel_norm(codes)
|
| 33 |
elif gan_type == 'stylegan':
|
|
@@ -63,8 +63,7 @@ def main():
|
|
| 63 |
|
| 64 |
model_name = st.sidebar.selectbox(
|
| 65 |
'Model to Interpret',
|
| 66 |
-
['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256'
|
| 67 |
-
])
|
| 68 |
|
| 69 |
model = get_model(model_name)
|
| 70 |
gan_type = parse_gan_type(model)
|
|
@@ -74,7 +73,7 @@ def main():
|
|
| 74 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 75 |
|
| 76 |
num_semantics = st.sidebar.number_input(
|
| 77 |
-
'Number of semantics', value=
|
| 78 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 79 |
if gan_type == 'pggan':
|
| 80 |
max_step = 5.0
|
|
|
|
| 16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
| 17 |
def get_model(model_name):
|
| 18 |
"""Gets model by name."""
|
| 19 |
+
return load_generator(model_name, from_hf_hub=True)
|
| 20 |
|
| 21 |
|
| 22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
|
|
|
| 27 |
|
| 28 |
def sample(model, gan_type, num=1):
|
| 29 |
"""Samples latent codes."""
|
| 30 |
+
codes = torch.randn(num, model.z_space_dim)
|
| 31 |
if gan_type == 'pggan':
|
| 32 |
codes = model.layer0.pixel_norm(codes)
|
| 33 |
elif gan_type == 'stylegan':
|
|
|
|
| 63 |
|
| 64 |
model_name = st.sidebar.selectbox(
|
| 65 |
'Model to Interpret',
|
| 66 |
+
['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',])
|
|
|
|
| 67 |
|
| 68 |
model = get_model(model_name)
|
| 69 |
gan_type = parse_gan_type(model)
|
|
|
|
| 73 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 74 |
|
| 75 |
num_semantics = st.sidebar.number_input(
|
| 76 |
+
'Number of semantics', value=5, min_value=0, max_value=None, step=1)
|
| 77 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 78 |
if gan_type == 'pggan':
|
| 79 |
max_step = 5.0
|
model_to_hf_hub.ipynb
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 15,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"pycharm": {
|
| 8 |
+
"name": "#%%\n"
|
| 9 |
+
}
|
| 10 |
+
},
|
| 11 |
+
"outputs": [],
|
| 12 |
+
"source": [
|
| 13 |
+
"import huggingface_hub\n",
|
| 14 |
+
"import utils"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 16,
|
| 20 |
+
"metadata": {
|
| 21 |
+
"pycharm": {
|
| 22 |
+
"name": "#%%\n"
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
"outputs": [
|
| 26 |
+
{
|
| 27 |
+
"data": {
|
| 28 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 29 |
+
"model_id": "525a0eaa021f4fdebd9138f4e7c5ab65",
|
| 30 |
+
"version_major": 2,
|
| 31 |
+
"version_minor": 0
|
| 32 |
+
},
|
| 33 |
+
"text/plain": [
|
| 34 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"output_type": "display_data"
|
| 39 |
+
}
|
| 40 |
+
],
|
| 41 |
+
"source": [
|
| 42 |
+
"huggingface_hub.notebook_login()"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 13,
|
| 48 |
+
"metadata": {
|
| 49 |
+
"pycharm": {
|
| 50 |
+
"name": "#%%\n"
|
| 51 |
+
}
|
| 52 |
+
},
|
| 53 |
+
"outputs": [
|
| 54 |
+
{
|
| 55 |
+
"name": "stdout",
|
| 56 |
+
"output_type": "stream",
|
| 57 |
+
"text": [
|
| 58 |
+
"Building generator for model `stylegan_animeface512` ...\n",
|
| 59 |
+
"Finish building generator.\n",
|
| 60 |
+
"Loading checkpoint from `checkpoints/stylegan_animeface512.pth` ...\n",
|
| 61 |
+
"Finish loading checkpoint.\n"
|
| 62 |
+
]
|
| 63 |
+
}
|
| 64 |
+
],
|
| 65 |
+
"source": [
|
| 66 |
+
"animeface_model = utils.load_generator('stylegan_animeface512')"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": 5,
|
| 72 |
+
"metadata": {
|
| 73 |
+
"pycharm": {
|
| 74 |
+
"name": "#%%\n"
|
| 75 |
+
}
|
| 76 |
+
},
|
| 77 |
+
"outputs": [
|
| 78 |
+
{
|
| 79 |
+
"name": "stderr",
|
| 80 |
+
"output_type": "stream",
|
| 81 |
+
"text": [
|
| 82 |
+
"Cloning https://huggingface.co/johko/stylegan_animeface512 into local empty directory.\n"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"data": {
|
| 87 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 88 |
+
"model_id": "6e51c5ae4a504617aa0f1c1ac798ed15",
|
| 89 |
+
"version_major": 2,
|
| 90 |
+
"version_minor": 0
|
| 91 |
+
},
|
| 92 |
+
"text/plain": [
|
| 93 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/103M [00:00<?, ?B/s]"
|
| 94 |
+
]
|
| 95 |
+
},
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"output_type": "display_data"
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"name": "stderr",
|
| 101 |
+
"output_type": "stream",
|
| 102 |
+
"text": [
|
| 103 |
+
"To https://huggingface.co/johko/stylegan_animeface512\n",
|
| 104 |
+
" 750cd03..2841156 main -> main\n",
|
| 105 |
+
"\n"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"data": {
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"'https://huggingface.co/johko/stylegan_animeface512/commit/2841156bad3c5a5f47f3edbf4a41880ea8fd3ad3'"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"execution_count": 5,
|
| 115 |
+
"metadata": {},
|
| 116 |
+
"output_type": "execute_result"
|
| 117 |
+
}
|
| 118 |
+
],
|
| 119 |
+
"source": [
|
| 120 |
+
"animeface_model.push_to_hub(\"johko/stylegan_animeface512\")"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "code",
|
| 125 |
+
"execution_count": 11,
|
| 126 |
+
"metadata": {
|
| 127 |
+
"pycharm": {
|
| 128 |
+
"name": "#%%\n"
|
| 129 |
+
}
|
| 130 |
+
},
|
| 131 |
+
"outputs": [
|
| 132 |
+
{
|
| 133 |
+
"name": "stdout",
|
| 134 |
+
"output_type": "stream",
|
| 135 |
+
"text": [
|
| 136 |
+
"Building generator for model `pggan_celebahq1024` ...\n",
|
| 137 |
+
"Finish building generator.\n",
|
| 138 |
+
"Loading checkpoint from `checkpoints/pggan_celebahq1024.pth` ...\n",
|
| 139 |
+
"Finish loading checkpoint.\n"
|
| 140 |
+
]
|
| 141 |
+
}
|
| 142 |
+
],
|
| 143 |
+
"source": [
|
| 144 |
+
"celebhq_model = utils.load_generator(\"pggan_celebahq1024\")"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "code",
|
| 149 |
+
"execution_count": 7,
|
| 150 |
+
"metadata": {
|
| 151 |
+
"pycharm": {
|
| 152 |
+
"name": "#%%\n"
|
| 153 |
+
}
|
| 154 |
+
},
|
| 155 |
+
"outputs": [
|
| 156 |
+
{
|
| 157 |
+
"name": "stderr",
|
| 158 |
+
"output_type": "stream",
|
| 159 |
+
"text": [
|
| 160 |
+
"Cloning https://huggingface.co/johko/pggan-celebahq-1024 into local empty directory.\n"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"data": {
|
| 165 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 166 |
+
"model_id": "ef4086b23a654b079bd6a3678140c50d",
|
| 167 |
+
"version_major": 2,
|
| 168 |
+
"version_minor": 0
|
| 169 |
+
},
|
| 170 |
+
"text/plain": [
|
| 171 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/88.1M [00:00<?, ?B/s]"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"output_type": "display_data"
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"name": "stderr",
|
| 179 |
+
"output_type": "stream",
|
| 180 |
+
"text": [
|
| 181 |
+
"To https://huggingface.co/johko/pggan-celebahq-1024\n",
|
| 182 |
+
" 780695e..278449f main -> main\n",
|
| 183 |
+
"\n"
|
| 184 |
+
]
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"data": {
|
| 188 |
+
"text/plain": [
|
| 189 |
+
"'https://huggingface.co/johko/pggan-celebahq-1024/commit/278449f8416d38a0233c980774528d32c4eee99c'"
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
"execution_count": 7,
|
| 193 |
+
"metadata": {},
|
| 194 |
+
"output_type": "execute_result"
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"source": [
|
| 198 |
+
"celebhq_model.push_to_hub(\"johko/pggan-celebahq-1024\")"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": 17,
|
| 204 |
+
"metadata": {},
|
| 205 |
+
"outputs": [
|
| 206 |
+
{
|
| 207 |
+
"name": "stdout",
|
| 208 |
+
"output_type": "stream",
|
| 209 |
+
"text": [
|
| 210 |
+
"Building generator for model `stylegan_car512` ...\n",
|
| 211 |
+
"Finish building generator.\n",
|
| 212 |
+
"Loading checkpoint from `checkpoints/stylegan_car512.pth` ...\n",
|
| 213 |
+
"Finish loading checkpoint.\n"
|
| 214 |
+
]
|
| 215 |
+
}
|
| 216 |
+
],
|
| 217 |
+
"source": [
|
| 218 |
+
"cars_model = utils.load_generator(\"stylegan_car512\")"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "code",
|
| 223 |
+
"execution_count": 21,
|
| 224 |
+
"metadata": {},
|
| 225 |
+
"outputs": [
|
| 226 |
+
{
|
| 227 |
+
"name": "stdout",
|
| 228 |
+
"output_type": "stream",
|
| 229 |
+
"text": [
|
| 230 |
+
"Building generator for model `stylegan_cat256` ...\n",
|
| 231 |
+
"Finish building generator.\n",
|
| 232 |
+
"Loading checkpoint from `checkpoints/stylegan_cat256.pth` ...\n",
|
| 233 |
+
"Finish loading checkpoint.\n"
|
| 234 |
+
]
|
| 235 |
+
}
|
| 236 |
+
],
|
| 237 |
+
"source": [
|
| 238 |
+
"cats_model = utils.load_generator(\"stylegan_cat256\")"
|
| 239 |
+
]
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"cell_type": "code",
|
| 243 |
+
"execution_count": null,
|
| 244 |
+
"metadata": {},
|
| 245 |
+
"outputs": [
|
| 246 |
+
{
|
| 247 |
+
"name": "stderr",
|
| 248 |
+
"output_type": "stream",
|
| 249 |
+
"text": [
|
| 250 |
+
"Cloning https://huggingface.co/johko/stylegan_cat256 into local empty directory.\n"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"data": {
|
| 255 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 256 |
+
"model_id": "651e9bff9c9f4555814171195e36d4d3",
|
| 257 |
+
"version_major": 2,
|
| 258 |
+
"version_minor": 0
|
| 259 |
+
},
|
| 260 |
+
"text/plain": [
|
| 261 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/100M [00:00<?, ?B/s]"
|
| 262 |
+
]
|
| 263 |
+
},
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"output_type": "display_data"
|
| 266 |
+
}
|
| 267 |
+
],
|
| 268 |
+
"source": [
|
| 269 |
+
"cats_model.push_to_hub(\"johko/stylegan_cat256\")"
|
| 270 |
+
]
|
| 271 |
+
}
|
| 272 |
+
],
|
| 273 |
+
"metadata": {
|
| 274 |
+
"interpreter": {
|
| 275 |
+
"hash": "a8d699d01f596cc27ac2722fbc0550b939d217978c7e1ca888dca7ba146ee4bf"
|
| 276 |
+
},
|
| 277 |
+
"kernelspec": {
|
| 278 |
+
"display_name": "Python 3",
|
| 279 |
+
"language": "python",
|
| 280 |
+
"name": "python3"
|
| 281 |
+
},
|
| 282 |
+
"language_info": {
|
| 283 |
+
"codemirror_mode": {
|
| 284 |
+
"name": "ipython",
|
| 285 |
+
"version": 3
|
| 286 |
+
},
|
| 287 |
+
"file_extension": ".py",
|
| 288 |
+
"mimetype": "text/x-python",
|
| 289 |
+
"name": "python",
|
| 290 |
+
"nbconvert_exporter": "python",
|
| 291 |
+
"pygments_lexer": "ipython3",
|
| 292 |
+
"version": "3.9.9"
|
| 293 |
+
}
|
| 294 |
+
},
|
| 295 |
+
"nbformat": 4,
|
| 296 |
+
"nbformat_minor": 2
|
| 297 |
+
}
|
models/model_zoo.py
CHANGED
|
@@ -9,6 +9,7 @@ MODEL_ZOO = {
|
|
| 9 |
gan_type='pggan',
|
| 10 |
resolution=1024,
|
| 11 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1',
|
|
|
|
| 12 |
),
|
| 13 |
'pggan_bedroom256': dict(
|
| 14 |
gan_type='pggan',
|
|
@@ -181,11 +182,13 @@ MODEL_ZOO = {
|
|
| 181 |
gan_type='stylegan',
|
| 182 |
resolution=256,
|
| 183 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1',
|
|
|
|
| 184 |
),
|
| 185 |
'stylegan_car512': dict(
|
| 186 |
gan_type='stylegan',
|
| 187 |
resolution=512,
|
| 188 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1',
|
|
|
|
| 189 |
),
|
| 190 |
|
| 191 |
# StyleGAN ours.
|
|
@@ -260,6 +263,7 @@ MODEL_ZOO = {
|
|
| 260 |
gan_type='stylegan',
|
| 261 |
resolution=512,
|
| 262 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1',
|
|
|
|
| 263 |
),
|
| 264 |
'stylegan_animeportrait512': dict(
|
| 265 |
gan_type='stylegan',
|
|
@@ -296,15 +300,8 @@ MODEL_ZOO = {
|
|
| 296 |
'stylegan2_car512': dict(
|
| 297 |
gan_type='stylegan2',
|
| 298 |
resolution=512,
|
| 299 |
-
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1'
|
| 300 |
),
|
| 301 |
-
|
| 302 |
-
#huggingface models
|
| 303 |
-
'akhaliq/OneshotCLIP-stylegan2-ffhq' : dict(
|
| 304 |
-
gan_type='stylegan2',
|
| 305 |
-
resolution=512,
|
| 306 |
-
url='akhaliq/OneshotCLIP-stylegan2-ffhq',
|
| 307 |
-
)
|
| 308 |
}
|
| 309 |
|
| 310 |
# pylint: enable=line-too-long
|
|
|
|
| 9 |
gan_type='pggan',
|
| 10 |
resolution=1024,
|
| 11 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1',
|
| 12 |
+
hf_hub_repo='huggan/pggan-celebahq-1024'
|
| 13 |
),
|
| 14 |
'pggan_bedroom256': dict(
|
| 15 |
gan_type='pggan',
|
|
|
|
| 182 |
gan_type='stylegan',
|
| 183 |
resolution=256,
|
| 184 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1',
|
| 185 |
+
hf_hub_repo="huggan/stylegan_cat256"
|
| 186 |
),
|
| 187 |
'stylegan_car512': dict(
|
| 188 |
gan_type='stylegan',
|
| 189 |
resolution=512,
|
| 190 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1',
|
| 191 |
+
hf_hub_repo="huggan/stylegan_car512"
|
| 192 |
),
|
| 193 |
|
| 194 |
# StyleGAN ours.
|
|
|
|
| 263 |
gan_type='stylegan',
|
| 264 |
resolution=512,
|
| 265 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1',
|
| 266 |
+
hf_hub_repo='huggan/stylegan_animeface512'
|
| 267 |
),
|
| 268 |
'stylegan_animeportrait512': dict(
|
| 269 |
gan_type='stylegan',
|
|
|
|
| 300 |
'stylegan2_car512': dict(
|
| 301 |
gan_type='stylegan2',
|
| 302 |
resolution=512,
|
| 303 |
+
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1'
|
| 304 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
}
|
| 306 |
|
| 307 |
# pylint: enable=line-too-long
|
models/pggan_generator.py
CHANGED
|
@@ -6,6 +6,7 @@ Paper: https://arxiv.org/pdf/1710.10196.pdf
|
|
| 6 |
Official TensorFlow implementation:
|
| 7 |
https://github.com/tkarras/progressive_growing_of_gans
|
| 8 |
"""
|
|
|
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
|
|
@@ -13,6 +14,8 @@ import torch
|
|
| 13 |
import torch.nn as nn
|
| 14 |
import torch.nn.functional as F
|
| 15 |
|
|
|
|
|
|
|
| 16 |
__all__ = ['PGGANGenerator']
|
| 17 |
|
| 18 |
# Resolutions allowed.
|
|
@@ -25,7 +28,7 @@ _INIT_RES = 4
|
|
| 25 |
_WSCALE_GAIN = np.sqrt(2.0)
|
| 26 |
|
| 27 |
|
| 28 |
-
class PGGANGenerator(nn.Module):
|
| 29 |
"""Defines the generator network in PGGAN.
|
| 30 |
|
| 31 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
@@ -57,7 +60,8 @@ class PGGANGenerator(nn.Module):
|
|
| 57 |
fused_scale=False,
|
| 58 |
use_wscale=True,
|
| 59 |
fmaps_base=16 << 10,
|
| 60 |
-
fmaps_max=512
|
|
|
|
| 61 |
"""Initializes with basic settings.
|
| 62 |
|
| 63 |
Raises:
|
|
@@ -81,6 +85,8 @@ class PGGANGenerator(nn.Module):
|
|
| 81 |
self.use_wscale = use_wscale
|
| 82 |
self.fmaps_base = fmaps_base
|
| 83 |
self.fmaps_max = fmaps_max
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# Number of convolutional layers.
|
| 86 |
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
|
@@ -202,6 +208,46 @@ class PGGANGenerator(nn.Module):
|
|
| 202 |
}
|
| 203 |
return results
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
class PixelNormLayer(nn.Module):
|
| 207 |
"""Implements pixel-wise feature vector normalization layer."""
|
|
|
|
| 6 |
Official TensorFlow implementation:
|
| 7 |
https://github.com/tkarras/progressive_growing_of_gans
|
| 8 |
"""
|
| 9 |
+
import os
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
|
|
|
|
| 14 |
import torch.nn as nn
|
| 15 |
import torch.nn.functional as F
|
| 16 |
|
| 17 |
+
from huggingface_hub import PyTorchModelHubMixin, PYTORCH_WEIGHTS_NAME, hf_hub_download
|
| 18 |
+
|
| 19 |
__all__ = ['PGGANGenerator']
|
| 20 |
|
| 21 |
# Resolutions allowed.
|
|
|
|
| 28 |
_WSCALE_GAIN = np.sqrt(2.0)
|
| 29 |
|
| 30 |
|
| 31 |
+
class PGGANGenerator(nn.Module, PyTorchModelHubMixin):
|
| 32 |
"""Defines the generator network in PGGAN.
|
| 33 |
|
| 34 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
|
|
| 60 |
fused_scale=False,
|
| 61 |
use_wscale=True,
|
| 62 |
fmaps_base=16 << 10,
|
| 63 |
+
fmaps_max=512,
|
| 64 |
+
**kwargs):
|
| 65 |
"""Initializes with basic settings.
|
| 66 |
|
| 67 |
Raises:
|
|
|
|
| 85 |
self.use_wscale = use_wscale
|
| 86 |
self.fmaps_base = fmaps_base
|
| 87 |
self.fmaps_max = fmaps_max
|
| 88 |
+
|
| 89 |
+
self.config = kwargs.pop("config", None)
|
| 90 |
|
| 91 |
# Number of convolutional layers.
|
| 92 |
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
|
|
|
| 208 |
}
|
| 209 |
return results
|
| 210 |
|
| 211 |
+
@classmethod
|
| 212 |
+
def _from_pretrained(
|
| 213 |
+
cls,
|
| 214 |
+
model_id,
|
| 215 |
+
revision,
|
| 216 |
+
cache_dir,
|
| 217 |
+
force_download,
|
| 218 |
+
proxies,
|
| 219 |
+
resume_download,
|
| 220 |
+
local_files_only,
|
| 221 |
+
use_auth_token,
|
| 222 |
+
map_location="cpu",
|
| 223 |
+
strict=False,
|
| 224 |
+
**model_kwargs,
|
| 225 |
+
):
|
| 226 |
+
"""
|
| 227 |
+
Overwrite this method in case you wish to initialize your model in a
|
| 228 |
+
different way.
|
| 229 |
+
"""
|
| 230 |
+
map_location = torch.device(map_location)
|
| 231 |
+
|
| 232 |
+
if os.path.isdir(model_id):
|
| 233 |
+
print("Loading weights from local directory")
|
| 234 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
| 235 |
+
else:
|
| 236 |
+
model_file = hf_hub_download(
|
| 237 |
+
repo_id=model_id,
|
| 238 |
+
filename=PYTORCH_WEIGHTS_NAME,
|
| 239 |
+
revision=revision,
|
| 240 |
+
cache_dir=cache_dir,
|
| 241 |
+
force_download=force_download,
|
| 242 |
+
proxies=proxies,
|
| 243 |
+
resume_download=resume_download,
|
| 244 |
+
use_auth_token=use_auth_token,
|
| 245 |
+
local_files_only=local_files_only,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
pretrained = torch.load(model_file, map_location=map_location)
|
| 249 |
+
return pretrained
|
| 250 |
+
|
| 251 |
|
| 252 |
class PixelNormLayer(nn.Module):
|
| 253 |
"""Implements pixel-wise feature vector normalization layer."""
|
models/stylegan2_generator.py
CHANGED
|
@@ -9,12 +9,14 @@ Paper: https://arxiv.org/pdf/1912.04958.pdf
|
|
| 9 |
|
| 10 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
| 11 |
"""
|
|
|
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
|
| 15 |
import torch
|
| 16 |
import torch.nn as nn
|
| 17 |
import torch.nn.functional as F
|
|
|
|
| 18 |
|
| 19 |
from .sync_op import all_gather
|
| 20 |
|
|
@@ -33,7 +35,7 @@ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
|
| 33 |
_WSCALE_GAIN = 1.0
|
| 34 |
|
| 35 |
|
| 36 |
-
class StyleGAN2Generator(nn.Module):
|
| 37 |
"""Defines the generator network in StyleGAN2.
|
| 38 |
|
| 39 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
@@ -88,7 +90,8 @@ class StyleGAN2Generator(nn.Module):
|
|
| 88 |
demodulate=True,
|
| 89 |
use_wscale=True,
|
| 90 |
fmaps_base=32 << 10,
|
| 91 |
-
fmaps_max=512
|
|
|
|
| 92 |
"""Initializes with basic settings.
|
| 93 |
|
| 94 |
Raises:
|
|
@@ -195,6 +198,45 @@ class StyleGAN2Generator(nn.Module):
|
|
| 195 |
|
| 196 |
return {**mapping_results, **synthesis_results}
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
class MappingModule(nn.Module):
|
| 200 |
"""Implements the latent space mapping module.
|
|
|
|
| 9 |
|
| 10 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
| 11 |
"""
|
| 12 |
+
import os
|
| 13 |
|
| 14 |
import numpy as np
|
| 15 |
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
| 18 |
import torch.nn.functional as F
|
| 19 |
+
from huggingface_hub import PYTORCH_WEIGHTS_NAME, hf_hub_download, PyTorchModelHubMixin
|
| 20 |
|
| 21 |
from .sync_op import all_gather
|
| 22 |
|
|
|
|
| 35 |
_WSCALE_GAIN = 1.0
|
| 36 |
|
| 37 |
|
| 38 |
+
class StyleGAN2Generator(nn.Module, PyTorchModelHubMixin):
|
| 39 |
"""Defines the generator network in StyleGAN2.
|
| 40 |
|
| 41 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
|
|
| 90 |
demodulate=True,
|
| 91 |
use_wscale=True,
|
| 92 |
fmaps_base=32 << 10,
|
| 93 |
+
fmaps_max=512,
|
| 94 |
+
**kwargs):
|
| 95 |
"""Initializes with basic settings.
|
| 96 |
|
| 97 |
Raises:
|
|
|
|
| 198 |
|
| 199 |
return {**mapping_results, **synthesis_results}
|
| 200 |
|
| 201 |
+
@classmethod
|
| 202 |
+
def _from_pretrained(
|
| 203 |
+
cls,
|
| 204 |
+
model_id,
|
| 205 |
+
revision,
|
| 206 |
+
cache_dir,
|
| 207 |
+
force_download,
|
| 208 |
+
proxies,
|
| 209 |
+
resume_download,
|
| 210 |
+
local_files_only,
|
| 211 |
+
use_auth_token,
|
| 212 |
+
map_location="cpu",
|
| 213 |
+
strict=False,
|
| 214 |
+
**model_kwargs,
|
| 215 |
+
):
|
| 216 |
+
"""
|
| 217 |
+
Overwrite this method in case you wish to initialize your model in a
|
| 218 |
+
different way.
|
| 219 |
+
"""
|
| 220 |
+
map_location = torch.device(map_location)
|
| 221 |
+
|
| 222 |
+
if os.path.isdir(model_id):
|
| 223 |
+
print("Loading weights from local directory")
|
| 224 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
| 225 |
+
else:
|
| 226 |
+
model_file = hf_hub_download(
|
| 227 |
+
repo_id=model_id,
|
| 228 |
+
filename="stylegan2-ffhq-config-f.pt",
|
| 229 |
+
revision=revision,
|
| 230 |
+
cache_dir=cache_dir,
|
| 231 |
+
force_download=force_download,
|
| 232 |
+
proxies=proxies,
|
| 233 |
+
resume_download=resume_download,
|
| 234 |
+
use_auth_token=use_auth_token,
|
| 235 |
+
local_files_only=local_files_only,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
pretrained = torch.load(model_file, map_location=map_location)
|
| 239 |
+
return pretrained
|
| 240 |
|
| 241 |
class MappingModule(nn.Module):
|
| 242 |
"""Implements the latent space mapping module.
|
models/stylegan_generator.py
CHANGED
|
@@ -5,6 +5,7 @@ Paper: https://arxiv.org/pdf/1812.04948.pdf
|
|
| 5 |
|
| 6 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
| 7 |
"""
|
|
|
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
|
|
@@ -14,6 +15,8 @@ import torch.nn.functional as F
|
|
| 14 |
|
| 15 |
from .sync_op import all_gather
|
| 16 |
|
|
|
|
|
|
|
| 17 |
__all__ = ['StyleGANGenerator']
|
| 18 |
|
| 19 |
# Resolutions allowed.
|
|
@@ -33,7 +36,7 @@ _WSCALE_GAIN = np.sqrt(2.0)
|
|
| 33 |
_STYLEMOD_WSCALE_GAIN = 1.0
|
| 34 |
|
| 35 |
|
| 36 |
-
class StyleGANGenerator(nn.Module):
|
| 37 |
"""Defines the generator network in StyleGAN.
|
| 38 |
|
| 39 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
@@ -83,7 +86,8 @@ class StyleGANGenerator(nn.Module):
|
|
| 83 |
fused_scale='auto',
|
| 84 |
use_wscale=True,
|
| 85 |
fmaps_base=16 << 10,
|
| 86 |
-
fmaps_max=512
|
|
|
|
| 87 |
"""Initializes with basic settings.
|
| 88 |
|
| 89 |
Raises:
|
|
@@ -115,6 +119,9 @@ class StyleGANGenerator(nn.Module):
|
|
| 115 |
self.use_wscale = use_wscale
|
| 116 |
self.fmaps_base = fmaps_base
|
| 117 |
self.fmaps_max = fmaps_max
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
| 120 |
|
|
@@ -188,6 +195,46 @@ class StyleGANGenerator(nn.Module):
|
|
| 188 |
|
| 189 |
return {**mapping_results, **synthesis_results}
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
class MappingModule(nn.Module):
|
| 193 |
"""Implements the latent space mapping module.
|
|
|
|
| 5 |
|
| 6 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
| 7 |
"""
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
|
|
|
|
| 15 |
|
| 16 |
from .sync_op import all_gather
|
| 17 |
|
| 18 |
+
from huggingface_hub import PyTorchModelHubMixin, PYTORCH_WEIGHTS_NAME, hf_hub_download
|
| 19 |
+
|
| 20 |
__all__ = ['StyleGANGenerator']
|
| 21 |
|
| 22 |
# Resolutions allowed.
|
|
|
|
| 36 |
_STYLEMOD_WSCALE_GAIN = 1.0
|
| 37 |
|
| 38 |
|
| 39 |
+
class StyleGANGenerator(nn.Module, PyTorchModelHubMixin):
|
| 40 |
"""Defines the generator network in StyleGAN.
|
| 41 |
|
| 42 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
|
|
| 86 |
fused_scale='auto',
|
| 87 |
use_wscale=True,
|
| 88 |
fmaps_base=16 << 10,
|
| 89 |
+
fmaps_max=512,
|
| 90 |
+
**kwargs):
|
| 91 |
"""Initializes with basic settings.
|
| 92 |
|
| 93 |
Raises:
|
|
|
|
| 119 |
self.use_wscale = use_wscale
|
| 120 |
self.fmaps_base = fmaps_base
|
| 121 |
self.fmaps_max = fmaps_max
|
| 122 |
+
|
| 123 |
+
self.config = kwargs.pop("config", None)
|
| 124 |
+
|
| 125 |
|
| 126 |
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
| 127 |
|
|
|
|
| 195 |
|
| 196 |
return {**mapping_results, **synthesis_results}
|
| 197 |
|
| 198 |
+
@classmethod
|
| 199 |
+
def _from_pretrained(
|
| 200 |
+
cls,
|
| 201 |
+
model_id,
|
| 202 |
+
revision,
|
| 203 |
+
cache_dir,
|
| 204 |
+
force_download,
|
| 205 |
+
proxies,
|
| 206 |
+
resume_download,
|
| 207 |
+
local_files_only,
|
| 208 |
+
use_auth_token,
|
| 209 |
+
map_location="cpu",
|
| 210 |
+
strict=False,
|
| 211 |
+
**model_kwargs,
|
| 212 |
+
):
|
| 213 |
+
"""
|
| 214 |
+
Overwrite this method in case you wish to initialize your model in a
|
| 215 |
+
different way.
|
| 216 |
+
"""
|
| 217 |
+
map_location = torch.device(map_location)
|
| 218 |
+
|
| 219 |
+
if os.path.isdir(model_id):
|
| 220 |
+
print("Loading weights from local directory")
|
| 221 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
| 222 |
+
else:
|
| 223 |
+
model_file = hf_hub_download(
|
| 224 |
+
repo_id=model_id,
|
| 225 |
+
filename=PYTORCH_WEIGHTS_NAME,
|
| 226 |
+
revision=revision,
|
| 227 |
+
cache_dir=cache_dir,
|
| 228 |
+
force_download=force_download,
|
| 229 |
+
proxies=proxies,
|
| 230 |
+
resume_download=resume_download,
|
| 231 |
+
use_auth_token=use_auth_token,
|
| 232 |
+
local_files_only=local_files_only,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
pretrained = torch.load(model_file, map_location=map_location)
|
| 236 |
+
return pretrained
|
| 237 |
+
|
| 238 |
|
| 239 |
class MappingModule(nn.Module):
|
| 240 |
"""Implements the latent space mapping module.
|
utils.py
CHANGED
|
@@ -50,7 +50,7 @@ def postprocess(images, min_val=-1.0, max_val=1.0):
|
|
| 50 |
return images
|
| 51 |
|
| 52 |
|
| 53 |
-
def load_generator(model_name):
|
| 54 |
"""Loads pre-trained generator.
|
| 55 |
|
| 56 |
Args:
|
|
@@ -74,19 +74,25 @@ def load_generator(model_name):
|
|
| 74 |
generator = build_generator(**model_config)
|
| 75 |
print(f'Finish building generator.')
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
if not os.path.exists(checkpoint_path):
|
| 82 |
-
print(f' Downloading checkpoint from `{url}` ...')
|
| 83 |
-
subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
|
| 84 |
-
print(f' Finish downloading checkpoint.')
|
| 85 |
-
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 86 |
-
if 'generator_smooth' in checkpoint:
|
| 87 |
-
generator.load_state_dict(checkpoint['generator_smooth'])
|
| 88 |
else:
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
#generator = generator.cuda()
|
| 91 |
generator.eval()
|
| 92 |
print(f'Finish loading checkpoint.')
|
|
|
|
| 50 |
return images
|
| 51 |
|
| 52 |
|
| 53 |
+
def load_generator(model_name, from_hf_hub=False):
|
| 54 |
"""Loads pre-trained generator.
|
| 55 |
|
| 56 |
Args:
|
|
|
|
| 74 |
generator = build_generator(**model_config)
|
| 75 |
print(f'Finish building generator.')
|
| 76 |
|
| 77 |
+
if from_hf_hub and "hf_hub_repo" in model_config.keys():
|
| 78 |
+
checkpoint = generator.from_pretrained(model_config["hf_hub_repo"])
|
| 79 |
+
generator.load_state_dict(checkpoint)
|
| 80 |
+
print("loaded from hf_hub")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
+
# Load pre-trained weights.
|
| 83 |
+
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 84 |
+
checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
|
| 85 |
+
print(f'Loading checkpoint from `{checkpoint_path}` ...')
|
| 86 |
+
if not os.path.exists(checkpoint_path):
|
| 87 |
+
print(f' Downloading checkpoint from `{url}` ...')
|
| 88 |
+
subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
|
| 89 |
+
print(f' Finish downloading checkpoint.')
|
| 90 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 91 |
+
|
| 92 |
+
if 'generator_smooth' in checkpoint:
|
| 93 |
+
generator.load_state_dict(checkpoint['generator_smooth'])
|
| 94 |
+
else:
|
| 95 |
+
generator.load_state_dict(checkpoint['generator'])
|
| 96 |
#generator = generator.cuda()
|
| 97 |
generator.eval()
|
| 98 |
print(f'Finish loading checkpoint.')
|