fix: pmap clip32
dev/inference/wandb-backend.ipynb
CHANGED
@@ -36,7 +36,8 @@
 "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
 "normalize_text = True\n",
 "latest_only = False  # log only latest or all versions\n",
-"suffix = '_1'  # mainly for duplicate inference runs with a deleted version"
+"suffix = '_1'  # mainly for duplicate inference runs with a deleted version\n",
+"add_clip_32 = False"
 ]
 },
 {
@@ -51,7 +52,8 @@
 "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
 "normalize_text = False\n",
 "latest_only = True  # log only latest or all versions\n",
-"suffix = '_2'  # mainly for duplicate inference runs with a deleted version"
+"suffix = '_2'  # mainly for duplicate inference runs with a deleted version\n",
+"add_clip_32 = True"
 ]
 },
 {
@@ -82,7 +84,12 @@
 "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
 "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
 "clip_params = replicate(clip.params)\n",
-"vqgan_params = replicate(vqgan.params)"
+"vqgan_params = replicate(vqgan.params)\n",
+"\n",
+"if add_clip_32:\n",
+"    clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+"    processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+"    clip32_params = replicate(clip32.params)"
 ]
 },
 {
@@ -98,8 +105,14 @@
 "\n",
 "@partial(jax.pmap, axis_name=\"batch\")\n",
 "def p_clip(inputs):\n",
-"    logits = clip(**inputs).logits_per_image\n",
-"    return logits"
+"    logits = clip(params=clip_params, **inputs).logits_per_image\n",
+"    return logits\n",
+"\n",
+"if add_clip_32:\n",
+"    @partial(jax.pmap, axis_name=\"batch\")\n",
+"    def p_clip32(inputs):\n",
+"        logits = clip32(params=clip32_params, **inputs).logits_per_image\n",
+"        return logits"
 ]
 },
 {
@@ -158,7 +171,7 @@
 "# retrieve inference run details\n",
 "def get_last_inference_version(run_id):\n",
 "    try:\n",
-"        inference_run = api.run(f'dalle-mini/dalle-mini/
+"        inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')\n",
 "        return inference_run.summary.get('version', None)\n",
 "    except:\n",
 "        return None"
@@ -215,6 +228,8 @@
 "    print(f'Processing artifact: {artifact.name}')\n",
 "    version = int(artifact.version[1:])\n",
 "    results = []\n",
+"    if add_clip_32:\n",
+"        results32 = []\n",
 "    columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
 "    \n",
 "    if latest_only:\n",
@@ -232,7 +247,7 @@
 "\n",
 "    # start/resume corresponding run\n",
 "    if run is None:\n",
-"        run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'
+"        run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')\n",
 "\n",
 "    # work in temporary directory\n",
 "    with tempfile.TemporaryDirectory() as tmp:\n",
@@ -283,7 +298,6 @@
 "        logits = logits.reshape(-1, num_images)\n",
 "        top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
 "        logits = jax.device_get(logits)\n",
-"\n",
 "        # add to results table\n",
 "        for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
 "            if sample == padding_item: continue\n",
@@ -291,11 +305,68 @@
 "            top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
 "            top_scores = [scores[x] for x in idx]\n",
 "            results.append([sample] + top_images + top_scores)\n",
+"        \n",
+"        # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
+"        if add_clip_32:\n",
+"            print('Calculating CLIP 32 scores')\n",
+"            clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+"            # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
+"            images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
+"            clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
+"            clip_inputs = shard(clip_inputs)\n",
+"            logits = p_clip32(clip_inputs)\n",
+"            logits = logits.reshape(-1, num_images)\n",
+"            top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
+"            logits = jax.device_get(logits)\n",
+"            # add to results table\n",
+"            for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
+"                if sample == padding_item: continue\n",
+"                cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
+"                top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
+"                top_scores = [scores[x] for x in idx]\n",
+"                results32.append([sample] + top_images + top_scores)\n",
 "\n",
 "    # log results\n",
 "    table = wandb.Table(columns=columns, data=results)\n",
 "    run.log({'Samples': table, 'version': version})\n",
-"    wandb.finish()"
+"    wandb.finish()\n",
+"    \n",
+"    if add_clip_32:\n",
+"        run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')\n",
+"        table = wandb.Table(columns=columns, data=results32)\n",
+"        run.log({'Samples': table, 'version': version})\n",
+"        wandb.finish()\n",
+"        run = None  # ensure we don't log on this run"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "fdcd09d6-079c-461a-a81a-d9e650d3b099",
+"metadata": {},
+"outputs": [],
+"source": [
+"p_clip32"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "7d86ceee-c9ac-4860-abad-410cadd16c3c",
+"metadata": {},
+"outputs": [],
+"source": [
+"clip_inputs['attention_mask'].shape, clip_inputs['pixel_values'].shape"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "fbba4858-da2d-4dd5-97b7-ce3ab4746f96",
+"metadata": {},
+"outputs": [],
+"source": [
+"clip_inputs['input_ids'].shape"
 ]
 },
 {
@@ -314,12 +385,10 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "
+"id": "a7a5fdf5-3c6e-421b-96a8-5115f730328c",
 "metadata": {},
 "outputs": [],
-"source": [
-"wandb.finish()"
-]
+"source": []
 }
 ],
 "metadata": {
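As context for the `params=clip_params` fix in `p_clip` above: scoring under `jax.pmap` expects replicated parameters and sharded inputs. Below is a minimal standalone sketch of that replicate/shard/pmap pattern. Everything in it (`score_fn`, the toy shapes) is illustrative rather than taken from the notebook, and it passes params as an explicit pmap argument, which is one common way to wire this up.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# Toy parameters standing in for clip.params; replicate() adds a leading
# device axis so each device holds its own copy.
params = {"w": jnp.ones((4,))}
p_params = replicate(params)

@partial(jax.pmap, axis_name="batch")
def score_fn(params, x):
    # Runs once per device; params and x are that device's slices.
    return x @ params["w"]

n_dev = jax.local_device_count()
batch = np.ones((n_dev * 2, 4), dtype=np.float32)
scores = score_fn(p_params, shard(batch))    # shard() reshapes to (n_dev, per_device, ...)
scores = jax.device_get(scores).reshape(-1)  # flatten back, as the notebook does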
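The densest new line in the clip32 block is the `pixel_values` reshuffle. Generated images arrive pass-major (one image per prompt per generation pass), while each pmap shard must score all images of one prompt, so `range(0, len(images), batch_size)` picks out one prompt's images across passes. A toy reconstruction of that indexing, with assumed values for `batch_size` and `num_images`:

import numpy as np

batch_size = 2   # prompts scored per pass (assumed value)
num_images = 3   # images generated per prompt (assumed value)

# Stand-in "images" in pass-major order: [p0_a, p1_a, p0_b, p1_b, p0_c, p1_c]
images = np.array(["p0_a", "p1_a", "p0_b", "p1_b", "p0_c", "p1_c"])

# Indices of prompt 0's images: one per pass, batch_size apart.
images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))

# Shifting by i selects prompt i's images; concatenating yields prompt-major order.
regrouped = np.concatenate(
    [images[images_per_prompt_indices + i] for i in range(batch_size)]
)
print(regrouped)  # ['p0_a' 'p0_b' 'p0_c' 'p1_a' 'p1_b' 'p1_c']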