Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
·
4b8c3a8
1
Parent(s):
eb912a1
* JIT outside the loop.
Browse files
My tests yesterday were wrong: there is a noticeable performance
improvement doing it this way. Even so, JIT still runs twice; we could cut
times in half (for this test) if we could make it run only once.
encoding/vqgan-jax-encoding.ipynb
CHANGED
|
@@ -363,20 +363,21 @@
|
|
| 363 |
},
|
| 364 |
{
|
| 365 |
"cell_type": "code",
|
| 366 |
-
"execution_count":
|
| 367 |
-
"id": "
|
| 368 |
"metadata": {},
|
| 369 |
"outputs": [],
|
| 370 |
"source": [
|
| 371 |
"def encode(model, batch):\n",
|
|
|
|
| 372 |
" _, indices = model.encode(batch)\n",
|
| 373 |
" return indices"
|
| 374 |
]
|
| 375 |
},
|
| 376 |
{
|
| 377 |
"cell_type": "code",
|
| 378 |
-
"execution_count":
|
| 379 |
-
"id": "
|
| 380 |
"metadata": {},
|
| 381 |
"outputs": [],
|
| 382 |
"source": [
|
|
@@ -969,15 +970,15 @@
|
|
| 969 |
},
|
| 970 |
{
|
| 971 |
"cell_type": "markdown",
|
| 972 |
-
"id": "
|
| 973 |
"metadata": {},
|
| 974 |
"source": [
|
| 975 |
-
"It works! Let's wrap it and run the whole process on the 10k images subset."
|
| 976 |
]
|
| 977 |
},
|
| 978 |
{
|
| 979 |
"cell_type": "markdown",
|
| 980 |
-
"id": "
|
| 981 |
"metadata": {},
|
| 982 |
"source": [
|
| 983 |
"## 10k encoding"
|
|
@@ -993,8 +994,18 @@
|
|
| 993 |
},
|
| 994 |
{
|
| 995 |
"cell_type": "code",
|
| 996 |
-
"execution_count":
|
| 997 |
-
"id": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
"metadata": {},
|
| 999 |
"outputs": [],
|
| 1000 |
"source": [
|
|
@@ -1004,10 +1015,11 @@
|
|
| 1004 |
" superbatches = superbatch_generator(dataloader)\n",
|
| 1005 |
" \n",
|
| 1006 |
" # TODO: save to disk as we go, do not accumulate everything in RAM\n",
|
| 1007 |
-
"#
|
|
|
|
| 1008 |
" results = None\n",
|
| 1009 |
" for superbatch in tqdm(superbatches):\n",
|
| 1010 |
-
" encoded =
|
| 1011 |
" encoded = encoded.reshape(encoded.shape[0] * encoded.shape[1], -1)\n",
|
| 1012 |
" results = np.concatenate((results, encoded), axis=0) if results is not None else encoded\n",
|
| 1013 |
" return results"
|
|
@@ -1015,15 +1027,15 @@
|
|
| 1015 |
},
|
| 1016 |
{
|
| 1017 |
"cell_type": "code",
|
| 1018 |
-
"execution_count":
|
| 1019 |
-
"id": "
|
| 1020 |
"metadata": {},
|
| 1021 |
"outputs": [
|
| 1022 |
{
|
| 1023 |
"name": "stderr",
|
| 1024 |
"output_type": "stream",
|
| 1025 |
"text": [
|
| 1026 |
-
"16it [
|
| 1027 |
]
|
| 1028 |
}
|
| 1029 |
],
|
|
|
|
| 363 |
},
|
| 364 |
{
|
| 365 |
"cell_type": "code",
|
| 366 |
+
"execution_count": 76,
|
| 367 |
+
"id": "fd26cdce",
|
| 368 |
"metadata": {},
|
| 369 |
"outputs": [],
|
| 370 |
"source": [
|
| 371 |
"def encode(model, batch):\n",
|
| 372 |
+
"# print(\"jitting encode function\")\n",
|
| 373 |
" _, indices = model.encode(batch)\n",
|
| 374 |
" return indices"
|
| 375 |
]
|
| 376 |
},
|
| 377 |
{
|
| 378 |
"cell_type": "code",
|
| 379 |
+
"execution_count": 18,
|
| 380 |
+
"id": "c49181e1",
|
| 381 |
"metadata": {},
|
| 382 |
"outputs": [],
|
| 383 |
"source": [
|
|
|
|
| 970 |
},
|
| 971 |
{
|
| 972 |
"cell_type": "markdown",
|
| 973 |
+
"id": "48896d5f",
|
| 974 |
"metadata": {},
|
| 975 |
"source": [
|
| 976 |
+
"It works! Let's wrap it up and run the whole process on the 10k images subset."
|
| 977 |
]
|
| 978 |
},
|
| 979 |
{
|
| 980 |
"cell_type": "markdown",
|
| 981 |
+
"id": "029d35d9",
|
| 982 |
"metadata": {},
|
| 983 |
"source": [
|
| 984 |
"## 10k encoding"
|
|
|
|
| 994 |
},
|
| 995 |
{
|
| 996 |
"cell_type": "code",
|
| 997 |
+
"execution_count": 45,
|
| 998 |
+
"id": "04b1568b",
|
| 999 |
+
"metadata": {},
|
| 1000 |
+
"outputs": [],
|
| 1001 |
+
"source": [
|
| 1002 |
+
"from functools import partial"
|
| 1003 |
+
]
|
| 1004 |
+
},
|
| 1005 |
+
{
|
| 1006 |
+
"cell_type": "code",
|
| 1007 |
+
"execution_count": 78,
|
| 1008 |
+
"id": "bfa3073b",
|
| 1009 |
"metadata": {},
|
| 1010 |
"outputs": [],
|
| 1011 |
"source": [
|
|
|
|
| 1015 |
" superbatches = superbatch_generator(dataloader)\n",
|
| 1016 |
" \n",
|
| 1017 |
" # TODO: save to disk as we go, do not accumulate everything in RAM\n",
|
| 1018 |
+
"# p_encoder = pmap(partial(encode, model), in_axes=(0,), donate_argnums=(0))\n",
|
| 1019 |
+
" p_encoder = pmap(lambda batch: encode(model, batch))\n",
|
| 1020 |
" results = None\n",
|
| 1021 |
" for superbatch in tqdm(superbatches):\n",
|
| 1022 |
+
" encoded = p_encoder(superbatch.numpy())\n",
|
| 1023 |
" encoded = encoded.reshape(encoded.shape[0] * encoded.shape[1], -1)\n",
|
| 1024 |
" results = np.concatenate((results, encoded), axis=0) if results is not None else encoded\n",
|
| 1025 |
" return results"
|
|
|
|
| 1027 |
},
|
| 1028 |
{
|
| 1029 |
"cell_type": "code",
|
| 1030 |
+
"execution_count": 79,
|
| 1031 |
+
"id": "d8d4da18",
|
| 1032 |
"metadata": {},
|
| 1033 |
"outputs": [
|
| 1034 |
{
|
| 1035 |
"name": "stderr",
|
| 1036 |
"output_type": "stream",
|
| 1037 |
"text": [
|
| 1038 |
+
"16it [00:41, 2.61s/it]\n"
|
| 1039 |
]
|
| 1040 |
}
|
| 1041 |
],
|