Spaces:
Running
Running
fix(inference): use float32 + flatten logits
Browse files
tools/inference/inference_pipeline.ipynb
CHANGED
|
@@ -70,7 +70,7 @@
|
|
| 70 |
"# Model references\n",
|
| 71 |
"\n",
|
| 72 |
"# dalle-mini\n",
|
| 73 |
-
"DALLE_MODEL = 'dalle-mini/dalle-mini/model-
|
| 74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
| 75 |
"\n",
|
| 76 |
"# VQGAN model\n",
|
|
@@ -92,7 +92,13 @@
|
|
| 92 |
"import jax.numpy as jnp\n",
|
| 93 |
"\n",
|
| 94 |
"# type used for computation - use bfloat16 on TPU's\n",
|
| 95 |
-
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
]
|
| 97 |
},
|
| 98 |
{
|
|
@@ -281,7 +287,7 @@
|
|
| 281 |
},
|
| 282 |
"outputs": [],
|
| 283 |
"source": [
|
| 284 |
-
"prompt = '
|
| 285 |
]
|
| 286 |
},
|
| 287 |
{
|
|
@@ -292,7 +298,8 @@
|
|
| 292 |
},
|
| 293 |
"outputs": [],
|
| 294 |
"source": [
|
| 295 |
-
"processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt"
|
|
|
|
| 296 |
]
|
| 297 |
},
|
| 298 |
{
|
|
@@ -375,7 +382,7 @@
|
|
| 375 |
"outputs": [],
|
| 376 |
"source": [
|
| 377 |
"# number of predictions\n",
|
| 378 |
-
"n_predictions =
|
| 379 |
"\n",
|
| 380 |
"# We can customize top_k/top_p used for generating samples\n",
|
| 381 |
"gen_top_k = None\n",
|
|
@@ -431,7 +438,7 @@
|
|
| 431 |
"# get clip scores\n",
|
| 432 |
"clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
|
| 433 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
| 434 |
-
"logits = logits.squeeze()"
|
| 435 |
]
|
| 436 |
},
|
| 437 |
{
|
|
|
|
| 70 |
"# Model references\n",
|
| 71 |
"\n",
|
| 72 |
"# dalle-mini\n",
|
| 73 |
+
"DALLE_MODEL = 'dalle-mini/dalle-mini/model-3bqwu04f:latest' # can be wandb artifact or 🤗 Hub or local folder\n",
|
| 74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
| 75 |
"\n",
|
| 76 |
"# VQGAN model\n",
|
|
|
|
| 92 |
"import jax.numpy as jnp\n",
|
| 93 |
"\n",
|
| 94 |
"# type used for computation - use bfloat16 on TPU's\n",
|
| 95 |
+
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"# TODO:\n",
|
| 98 |
+
"# - we currently have an issue with model.generate() in bfloat16\n",
|
| 99 |
+
"# - https://github.com/google/jax/pull/9089 should fix it\n",
|
| 100 |
+
"# - remove below line and test on TPU with next release of JAX\n",
|
| 101 |
+
"dtype = jnp.float32"
|
| 102 |
]
|
| 103 |
},
|
| 104 |
{
|
|
|
|
| 287 |
},
|
| 288 |
"outputs": [],
|
| 289 |
"source": [
|
| 290 |
+
"prompt = 'a red T-shirt'"
|
| 291 |
]
|
| 292 |
},
|
| 293 |
{
|
|
|
|
| 298 |
},
|
| 299 |
"outputs": [],
|
| 300 |
"source": [
|
| 301 |
+
"processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
|
| 302 |
+
"processed_prompt"
|
| 303 |
]
|
| 304 |
},
|
| 305 |
{
|
|
|
|
| 382 |
"outputs": [],
|
| 383 |
"source": [
|
| 384 |
"# number of predictions\n",
|
| 385 |
+
"n_predictions = 32\n",
|
| 386 |
"\n",
|
| 387 |
"# We can customize top_k/top_p used for generating samples\n",
|
| 388 |
"gen_top_k = None\n",
|
|
|
|
| 438 |
"# get clip scores\n",
|
| 439 |
"clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
|
| 440 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
| 441 |
+
"logits = logits.squeeze().flatten()"
|
| 442 |
]
|
| 443 |
},
|
| 444 |
{
|
tools/inference/log_inference_samples.ipynb
DELETED
|
@@ -1,434 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": null,
|
| 6 |
-
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [],
|
| 9 |
-
"source": [
|
| 10 |
-
"import tempfile\n",
|
| 11 |
-
"from functools import partial\n",
|
| 12 |
-
"import random\n",
|
| 13 |
-
"import numpy as np\n",
|
| 14 |
-
"from PIL import Image\n",
|
| 15 |
-
"from tqdm.notebook import tqdm\n",
|
| 16 |
-
"import jax\n",
|
| 17 |
-
"import jax.numpy as jnp\n",
|
| 18 |
-
"from flax.training.common_utils import shard, shard_prng_key\n",
|
| 19 |
-
"from flax.jax_utils import replicate\n",
|
| 20 |
-
"import wandb\n",
|
| 21 |
-
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
| 22 |
-
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
| 23 |
-
"from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n",
|
| 24 |
-
"from dalle_mini.text import TextNormalizer"
|
| 25 |
-
]
|
| 26 |
-
},
|
| 27 |
-
{
|
| 28 |
-
"cell_type": "code",
|
| 29 |
-
"execution_count": null,
|
| 30 |
-
"id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
|
| 31 |
-
"metadata": {},
|
| 32 |
-
"outputs": [],
|
| 33 |
-
"source": [
|
| 34 |
-
"run_ids = [\"63otg87g\"]\n",
|
| 35 |
-
"ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
|
| 36 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
| 37 |
-
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
| 38 |
-
" \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
|
| 39 |
-
")\n",
|
| 40 |
-
"latest_only = True # log only latest or all versions\n",
|
| 41 |
-
"suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
|
| 42 |
-
"add_clip_32 = False"
|
| 43 |
-
]
|
| 44 |
-
},
|
| 45 |
-
{
|
| 46 |
-
"cell_type": "code",
|
| 47 |
-
"execution_count": null,
|
| 48 |
-
"id": "71f27b96-7e6c-4472-a2e4-e99a8fb67a72",
|
| 49 |
-
"metadata": {},
|
| 50 |
-
"outputs": [],
|
| 51 |
-
"source": [
|
| 52 |
-
"# model.generate parameters - Not used yet\n",
|
| 53 |
-
"gen_top_k = None\n",
|
| 54 |
-
"gen_top_p = None\n",
|
| 55 |
-
"temperature = None"
|
| 56 |
-
]
|
| 57 |
-
},
|
| 58 |
-
{
|
| 59 |
-
"cell_type": "code",
|
| 60 |
-
"execution_count": null,
|
| 61 |
-
"id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
|
| 62 |
-
"metadata": {},
|
| 63 |
-
"outputs": [],
|
| 64 |
-
"source": [
|
| 65 |
-
"batch_size = 8\n",
|
| 66 |
-
"num_images = 128\n",
|
| 67 |
-
"top_k = 8\n",
|
| 68 |
-
"text_normalizer = TextNormalizer()\n",
|
| 69 |
-
"padding_item = \"NONE\"\n",
|
| 70 |
-
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
| 71 |
-
"key = jax.random.PRNGKey(seed)\n",
|
| 72 |
-
"api = wandb.Api()"
|
| 73 |
-
]
|
| 74 |
-
},
|
| 75 |
-
{
|
| 76 |
-
"cell_type": "code",
|
| 77 |
-
"execution_count": null,
|
| 78 |
-
"id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
|
| 79 |
-
"metadata": {},
|
| 80 |
-
"outputs": [],
|
| 81 |
-
"source": [
|
| 82 |
-
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
|
| 83 |
-
"vqgan_params = replicate(vqgan.params)\n",
|
| 84 |
-
"\n",
|
| 85 |
-
"clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
| 86 |
-
"processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
| 87 |
-
"clip16_params = replicate(clip16.params)\n",
|
| 88 |
-
"\n",
|
| 89 |
-
"if add_clip_32:\n",
|
| 90 |
-
" clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
| 91 |
-
" processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
| 92 |
-
" clip32_params = replicate(clip32.params)"
|
| 93 |
-
]
|
| 94 |
-
},
|
| 95 |
-
{
|
| 96 |
-
"cell_type": "code",
|
| 97 |
-
"execution_count": null,
|
| 98 |
-
"id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
|
| 99 |
-
"metadata": {},
|
| 100 |
-
"outputs": [],
|
| 101 |
-
"source": [
|
| 102 |
-
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 103 |
-
"def p_decode(indices, params):\n",
|
| 104 |
-
" return vqgan.decode_code(indices, params=params)\n",
|
| 105 |
-
"\n",
|
| 106 |
-
"\n",
|
| 107 |
-
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 108 |
-
"def p_clip16(inputs, params):\n",
|
| 109 |
-
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
| 110 |
-
" return logits\n",
|
| 111 |
-
"\n",
|
| 112 |
-
"\n",
|
| 113 |
-
"if add_clip_32:\n",
|
| 114 |
-
"\n",
|
| 115 |
-
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 116 |
-
" def p_clip32(inputs, params):\n",
|
| 117 |
-
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
| 118 |
-
" return logits"
|
| 119 |
-
]
|
| 120 |
-
},
|
| 121 |
-
{
|
| 122 |
-
"cell_type": "code",
|
| 123 |
-
"execution_count": null,
|
| 124 |
-
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
| 125 |
-
"metadata": {},
|
| 126 |
-
"outputs": [],
|
| 127 |
-
"source": [
|
| 128 |
-
"with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
|
| 129 |
-
" samples = [l.strip() for l in f.readlines()]\n",
|
| 130 |
-
" # make list multiple of batch_size by adding elements\n",
|
| 131 |
-
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
| 132 |
-
" samples.extend(samples_to_add)\n",
|
| 133 |
-
" # reshape\n",
|
| 134 |
-
" samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
|
| 135 |
-
]
|
| 136 |
-
},
|
| 137 |
-
{
|
| 138 |
-
"cell_type": "code",
|
| 139 |
-
"execution_count": null,
|
| 140 |
-
"id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
|
| 141 |
-
"metadata": {},
|
| 142 |
-
"outputs": [],
|
| 143 |
-
"source": [
|
| 144 |
-
"def get_artifact_versions(run_id, latest_only=False):\n",
|
| 145 |
-
" try:\n",
|
| 146 |
-
" if latest_only:\n",
|
| 147 |
-
" return [\n",
|
| 148 |
-
" api.artifact(\n",
|
| 149 |
-
" type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
|
| 150 |
-
" )\n",
|
| 151 |
-
" ]\n",
|
| 152 |
-
" else:\n",
|
| 153 |
-
" return api.artifact_versions(\n",
|
| 154 |
-
" type_name=\"bart_model\",\n",
|
| 155 |
-
" name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
|
| 156 |
-
" per_page=10000,\n",
|
| 157 |
-
" )\n",
|
| 158 |
-
" except:\n",
|
| 159 |
-
" return []"
|
| 160 |
-
]
|
| 161 |
-
},
|
| 162 |
-
{
|
| 163 |
-
"cell_type": "code",
|
| 164 |
-
"execution_count": null,
|
| 165 |
-
"id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
|
| 166 |
-
"metadata": {},
|
| 167 |
-
"outputs": [],
|
| 168 |
-
"source": [
|
| 169 |
-
"def get_training_config(run_id):\n",
|
| 170 |
-
" training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
|
| 171 |
-
" config = training_run.config\n",
|
| 172 |
-
" return config"
|
| 173 |
-
]
|
| 174 |
-
},
|
| 175 |
-
{
|
| 176 |
-
"cell_type": "code",
|
| 177 |
-
"execution_count": null,
|
| 178 |
-
"id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
|
| 179 |
-
"metadata": {},
|
| 180 |
-
"outputs": [],
|
| 181 |
-
"source": [
|
| 182 |
-
"# retrieve inference run details\n",
|
| 183 |
-
"def get_last_inference_version(run_id):\n",
|
| 184 |
-
" try:\n",
|
| 185 |
-
" inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
|
| 186 |
-
" return inference_run.summary.get(\"version\", None)\n",
|
| 187 |
-
" except:\n",
|
| 188 |
-
" return None"
|
| 189 |
-
]
|
| 190 |
-
},
|
| 191 |
-
{
|
| 192 |
-
"cell_type": "code",
|
| 193 |
-
"execution_count": null,
|
| 194 |
-
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
| 195 |
-
"metadata": {},
|
| 196 |
-
"outputs": [],
|
| 197 |
-
"source": [
|
| 198 |
-
"# compile functions - needed only once per run\n",
|
| 199 |
-
"def pmap_model_function(model):\n",
|
| 200 |
-
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 201 |
-
" def _generate(tokenized_prompt, key, params):\n",
|
| 202 |
-
" return model.generate(\n",
|
| 203 |
-
" **tokenized_prompt,\n",
|
| 204 |
-
" do_sample=True,\n",
|
| 205 |
-
" num_beams=1,\n",
|
| 206 |
-
" prng_key=key,\n",
|
| 207 |
-
" params=params,\n",
|
| 208 |
-
" top_k=gen_top_k,\n",
|
| 209 |
-
" top_p=gen_top_p\n",
|
| 210 |
-
" )\n",
|
| 211 |
-
"\n",
|
| 212 |
-
" return _generate"
|
| 213 |
-
]
|
| 214 |
-
},
|
| 215 |
-
{
|
| 216 |
-
"cell_type": "code",
|
| 217 |
-
"execution_count": null,
|
| 218 |
-
"id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
|
| 219 |
-
"metadata": {},
|
| 220 |
-
"outputs": [],
|
| 221 |
-
"source": [
|
| 222 |
-
"run_id = run_ids[0]\n",
|
| 223 |
-
"# TODO: loop over runs"
|
| 224 |
-
]
|
| 225 |
-
},
|
| 226 |
-
{
|
| 227 |
-
"cell_type": "code",
|
| 228 |
-
"execution_count": null,
|
| 229 |
-
"id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
|
| 230 |
-
"metadata": {},
|
| 231 |
-
"outputs": [],
|
| 232 |
-
"source": [
|
| 233 |
-
"artifact_versions = get_artifact_versions(run_id, latest_only)\n",
|
| 234 |
-
"last_inference_version = get_last_inference_version(run_id)\n",
|
| 235 |
-
"training_config = get_training_config(run_id)\n",
|
| 236 |
-
"run = None\n",
|
| 237 |
-
"p_generate = None\n",
|
| 238 |
-
"model_files = [\n",
|
| 239 |
-
" \"config.json\",\n",
|
| 240 |
-
" \"flax_model.msgpack\",\n",
|
| 241 |
-
" \"merges.txt\",\n",
|
| 242 |
-
" \"special_tokens_map.json\",\n",
|
| 243 |
-
" \"tokenizer.json\",\n",
|
| 244 |
-
" \"tokenizer_config.json\",\n",
|
| 245 |
-
" \"vocab.json\",\n",
|
| 246 |
-
"]\n",
|
| 247 |
-
"for artifact in artifact_versions:\n",
|
| 248 |
-
" print(f\"Processing artifact: {artifact.name}\")\n",
|
| 249 |
-
" version = int(artifact.version[1:])\n",
|
| 250 |
-
" results16, results32 = [], []\n",
|
| 251 |
-
" columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
|
| 252 |
-
"\n",
|
| 253 |
-
" if latest_only:\n",
|
| 254 |
-
" assert last_inference_version is None or version > last_inference_version\n",
|
| 255 |
-
" else:\n",
|
| 256 |
-
" if last_inference_version is None:\n",
|
| 257 |
-
" # we should start from v0\n",
|
| 258 |
-
" assert version == 0\n",
|
| 259 |
-
" elif version <= last_inference_version:\n",
|
| 260 |
-
" print(\n",
|
| 261 |
-
" f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
|
| 262 |
-
" )\n",
|
| 263 |
-
" else:\n",
|
| 264 |
-
" # check we are logging the correct version\n",
|
| 265 |
-
" assert version == last_inference_version + 1\n",
|
| 266 |
-
"\n",
|
| 267 |
-
" # start/resume corresponding run\n",
|
| 268 |
-
" if run is None:\n",
|
| 269 |
-
" run = wandb.init(\n",
|
| 270 |
-
" job_type=\"inference\",\n",
|
| 271 |
-
" entity=\"dalle-mini\",\n",
|
| 272 |
-
" project=\"dalle-mini\",\n",
|
| 273 |
-
" config=training_config,\n",
|
| 274 |
-
" id=f\"{run_id}-clip16{suffix}\",\n",
|
| 275 |
-
" resume=\"allow\",\n",
|
| 276 |
-
" )\n",
|
| 277 |
-
"\n",
|
| 278 |
-
" # work in temporary directory\n",
|
| 279 |
-
" with tempfile.TemporaryDirectory() as tmp:\n",
|
| 280 |
-
"\n",
|
| 281 |
-
" # download model files\n",
|
| 282 |
-
" artifact = run.use_artifact(artifact)\n",
|
| 283 |
-
" for f in model_files:\n",
|
| 284 |
-
" artifact.get_path(f).download(tmp)\n",
|
| 285 |
-
"\n",
|
| 286 |
-
" # load tokenizer and model\n",
|
| 287 |
-
" tokenizer = BartTokenizer.from_pretrained(tmp)\n",
|
| 288 |
-
" model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
|
| 289 |
-
" model_params = replicate(model.params)\n",
|
| 290 |
-
"\n",
|
| 291 |
-
" # pmap model function needs to happen only once per model config\n",
|
| 292 |
-
" if p_generate is None:\n",
|
| 293 |
-
" p_generate = pmap_model_function(model)\n",
|
| 294 |
-
"\n",
|
| 295 |
-
" # process one batch of captions\n",
|
| 296 |
-
" for batch in tqdm(samples):\n",
|
| 297 |
-
" processed_prompts = (\n",
|
| 298 |
-
" [text_normalizer(x) for x in batch]\n",
|
| 299 |
-
" if model.config.normalize_text\n",
|
| 300 |
-
" else list(batch)\n",
|
| 301 |
-
" )\n",
|
| 302 |
-
"\n",
|
| 303 |
-
" # repeat the prompts to distribute over each device and tokenize\n",
|
| 304 |
-
" processed_prompts = processed_prompts * jax.device_count()\n",
|
| 305 |
-
" tokenized_prompt = tokenizer(\n",
|
| 306 |
-
" processed_prompts,\n",
|
| 307 |
-
" return_tensors=\"jax\",\n",
|
| 308 |
-
" padding=\"max_length\",\n",
|
| 309 |
-
" truncation=True,\n",
|
| 310 |
-
" max_length=128,\n",
|
| 311 |
-
" ).data\n",
|
| 312 |
-
" tokenized_prompt = shard(tokenized_prompt)\n",
|
| 313 |
-
"\n",
|
| 314 |
-
" # generate images\n",
|
| 315 |
-
" images = []\n",
|
| 316 |
-
" pbar = tqdm(\n",
|
| 317 |
-
" range(num_images // jax.device_count()),\n",
|
| 318 |
-
" desc=\"Generating Images\",\n",
|
| 319 |
-
" leave=True,\n",
|
| 320 |
-
" )\n",
|
| 321 |
-
" for i in pbar:\n",
|
| 322 |
-
" key, subkey = jax.random.split(key)\n",
|
| 323 |
-
" encoded_images = p_generate(\n",
|
| 324 |
-
" tokenized_prompt, shard_prng_key(subkey), model_params\n",
|
| 325 |
-
" )\n",
|
| 326 |
-
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 327 |
-
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
| 328 |
-
" decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
|
| 329 |
-
" (-1, 256, 256, 3)\n",
|
| 330 |
-
" )\n",
|
| 331 |
-
" for img in decoded_images:\n",
|
| 332 |
-
" images.append(\n",
|
| 333 |
-
" Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
|
| 334 |
-
" )\n",
|
| 335 |
-
"\n",
|
| 336 |
-
" def add_clip_results(results, processor, p_clip, clip_params):\n",
|
| 337 |
-
" clip_inputs = processor(\n",
|
| 338 |
-
" text=batch,\n",
|
| 339 |
-
" images=images,\n",
|
| 340 |
-
" return_tensors=\"np\",\n",
|
| 341 |
-
" padding=\"max_length\",\n",
|
| 342 |
-
" max_length=77,\n",
|
| 343 |
-
" truncation=True,\n",
|
| 344 |
-
" ).data\n",
|
| 345 |
-
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
| 346 |
-
" images_per_prompt_indices = np.asarray(\n",
|
| 347 |
-
" range(0, len(images), batch_size)\n",
|
| 348 |
-
" )\n",
|
| 349 |
-
" clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
|
| 350 |
-
" list(\n",
|
| 351 |
-
" clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
|
| 352 |
-
" for i in range(batch_size)\n",
|
| 353 |
-
" )\n",
|
| 354 |
-
" )\n",
|
| 355 |
-
" clip_inputs = shard(clip_inputs)\n",
|
| 356 |
-
" logits = p_clip(clip_inputs, clip_params)\n",
|
| 357 |
-
" logits = logits.reshape(-1, num_images)\n",
|
| 358 |
-
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
| 359 |
-
" logits = jax.device_get(logits)\n",
|
| 360 |
-
" # add to results table\n",
|
| 361 |
-
" for i, (idx, scores, sample) in enumerate(\n",
|
| 362 |
-
" zip(top_scores, logits, batch)\n",
|
| 363 |
-
" ):\n",
|
| 364 |
-
" if sample == padding_item:\n",
|
| 365 |
-
" continue\n",
|
| 366 |
-
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
| 367 |
-
" top_images = [\n",
|
| 368 |
-
" wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
|
| 369 |
-
" for x in idx\n",
|
| 370 |
-
" ]\n",
|
| 371 |
-
" results.append([sample] + top_images)\n",
|
| 372 |
-
"\n",
|
| 373 |
-
" # get clip scores\n",
|
| 374 |
-
" pbar.set_description(\"Calculating CLIP 16 scores\")\n",
|
| 375 |
-
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
| 376 |
-
"\n",
|
| 377 |
-
" # get clip 32 scores\n",
|
| 378 |
-
" if add_clip_32:\n",
|
| 379 |
-
" pbar.set_description(\"Calculating CLIP 32 scores\")\n",
|
| 380 |
-
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
| 381 |
-
"\n",
|
| 382 |
-
" pbar.close()\n",
|
| 383 |
-
"\n",
|
| 384 |
-
" # log results\n",
|
| 385 |
-
" table = wandb.Table(columns=columns, data=results16)\n",
|
| 386 |
-
" run.log({\"Samples\": table, \"version\": version})\n",
|
| 387 |
-
" wandb.finish()\n",
|
| 388 |
-
"\n",
|
| 389 |
-
" if add_clip_32:\n",
|
| 390 |
-
" run = wandb.init(\n",
|
| 391 |
-
" job_type=\"inference\",\n",
|
| 392 |
-
" entity=\"dalle-mini\",\n",
|
| 393 |
-
" project=\"dalle-mini\",\n",
|
| 394 |
-
" config=training_config,\n",
|
| 395 |
-
" id=f\"{run_id}-clip32{suffix}\",\n",
|
| 396 |
-
" resume=\"allow\",\n",
|
| 397 |
-
" )\n",
|
| 398 |
-
" table = wandb.Table(columns=columns, data=results32)\n",
|
| 399 |
-
" run.log({\"Samples\": table, \"version\": version})\n",
|
| 400 |
-
" wandb.finish()\n",
|
| 401 |
-
" run = None # ensure we don't log on this run"
|
| 402 |
-
]
|
| 403 |
-
},
|
| 404 |
-
{
|
| 405 |
-
"cell_type": "code",
|
| 406 |
-
"execution_count": null,
|
| 407 |
-
"id": "415d3f54-7226-43de-9eea-4283a948dc93",
|
| 408 |
-
"metadata": {},
|
| 409 |
-
"outputs": [],
|
| 410 |
-
"source": []
|
| 411 |
-
}
|
| 412 |
-
],
|
| 413 |
-
"metadata": {
|
| 414 |
-
"kernelspec": {
|
| 415 |
-
"display_name": "Python 3 (ipykernel)",
|
| 416 |
-
"language": "python",
|
| 417 |
-
"name": "python3"
|
| 418 |
-
},
|
| 419 |
-
"language_info": {
|
| 420 |
-
"codemirror_mode": {
|
| 421 |
-
"name": "ipython",
|
| 422 |
-
"version": 3
|
| 423 |
-
},
|
| 424 |
-
"file_extension": ".py",
|
| 425 |
-
"mimetype": "text/x-python",
|
| 426 |
-
"name": "python",
|
| 427 |
-
"nbconvert_exporter": "python",
|
| 428 |
-
"pygments_lexer": "ipython3",
|
| 429 |
-
"version": "3.9.7"
|
| 430 |
-
}
|
| 431 |
-
},
|
| 432 |
-
"nbformat": 4,
|
| 433 |
-
"nbformat_minor": 5
|
| 434 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/inference/samples.txt
DELETED
|
@@ -1,124 +0,0 @@
|
|
| 1 |
-
t-shirt, size M
|
| 2 |
-
flower dress, size M
|
| 3 |
-
white snow covered mountain under blue sky during daytime
|
| 4 |
-
aerial view of the beach during daytime
|
| 5 |
-
aerial view of the beach at night
|
| 6 |
-
a beautiful sunset at a beach with a shell on the shore
|
| 7 |
-
a farmhouse surrounded by beautiful flowers
|
| 8 |
-
sunset over green mountains
|
| 9 |
-
a photo of san francisco golden gate bridge
|
| 10 |
-
painting of an oniric forest glade surrounded by tall trees
|
| 11 |
-
a graphite sketch of a gothic cathedral
|
| 12 |
-
a graphite sketch of Elon Musk
|
| 13 |
-
still life in the style of Kandinsky
|
| 14 |
-
still life in the style of Picasso
|
| 15 |
-
a colorful stairway to heaven
|
| 16 |
-
a background consisting of colors blue, green, and red
|
| 17 |
-
Mohammed Ali and Mike Tyson in a match
|
| 18 |
-
Pele and Maradona in a match
|
| 19 |
-
view of Mars from space
|
| 20 |
-
a picture of the Eiffel tower on the moon
|
| 21 |
-
a picture of the Eiffel tower on the moon, Earth is in the background
|
| 22 |
-
watercolor of the Eiffel tower on the moon
|
| 23 |
-
the moon is a skull
|
| 24 |
-
epic sword fight
|
| 25 |
-
underwater cathedral
|
| 26 |
-
a photo of a fantasy version of New York City
|
| 27 |
-
a picture of fantasy kingdoms
|
| 28 |
-
a volcano erupting next to San Francisco golden gate bridge
|
| 29 |
-
Paris in a far future, futuristic Paris
|
| 30 |
-
real painting of an alien from Monet
|
| 31 |
-
the communist statue of liberty
|
| 32 |
-
robots taking control over humans
|
| 33 |
-
illustration of an astronaut in a space suit playing guitar
|
| 34 |
-
a clown wearing a spacesuit floating in space
|
| 35 |
-
a dog playing with a ball
|
| 36 |
-
a cat sits on top of an alligator
|
| 37 |
-
a very cute cat laying by a big bike
|
| 38 |
-
a rat holding a red lightsaber in a white background
|
| 39 |
-
a very cute giraffe making a funny face
|
| 40 |
-
A unicorn is passing by a rainbow in a field of flowers
|
| 41 |
-
an elephant made of carrots
|
| 42 |
-
an elephant on a unicycle during a circus
|
| 43 |
-
photography of a penguin watching television
|
| 44 |
-
a penguin is walking on the Moon, Earth is in the background
|
| 45 |
-
a penguin standing on a tower of books holds onto a rope from a helicopter
|
| 46 |
-
rat wearing a crown
|
| 47 |
-
looking into the sky, 10 airplanes are seen overhead
|
| 48 |
-
shelves filled with books and alchemy potion bottles
|
| 49 |
-
this is a detailed high-resolution scan of a human brain
|
| 50 |
-
a restaurant menu
|
| 51 |
-
a bottle of coca-cola on a table
|
| 52 |
-
a peanut
|
| 53 |
-
a cross-section view of a walnut
|
| 54 |
-
a living room with two white armchairs and a painting of the collosseum. The painting is mounted above a modern fireplace.
|
| 55 |
-
a long line of alternating green and red blocks
|
| 56 |
-
a long line of green blocks on a beach at subset
|
| 57 |
-
a long line of peaches on a beach at sunset
|
| 58 |
-
a picture of a castle from minecraft
|
| 59 |
-
a cute pikachu teapot
|
| 60 |
-
an illustration of pikachu sitting on a bench eating an ice cream
|
| 61 |
-
mario is jumping over a zebra
|
| 62 |
-
famous anime hero
|
| 63 |
-
star wars concept art
|
| 64 |
-
Cartoon of a carrot with big eyes
|
| 65 |
-
a cartoon of a superhero bear
|
| 66 |
-
an illustration of a cute skeleton wearing a blue hoodie
|
| 67 |
-
illustration of a baby shark swimming around corals
|
| 68 |
-
an illustration of an avocado in a beanie riding a motorcycle
|
| 69 |
-
logo of a robot wearing glasses and reading a book
|
| 70 |
-
illustration of a cactus lifting weigths
|
| 71 |
-
logo of a cactus lifting weights
|
| 72 |
-
a photo of a camera from the future
|
| 73 |
-
a skeleton with the shape of a spider
|
| 74 |
-
a collection of glasses is sitting on a table
|
| 75 |
-
a painting of a capybara sitting on a mountain during fall in surrealist style
|
| 76 |
-
a pentagonal green clock
|
| 77 |
-
a small red block sitting on a large green block
|
| 78 |
-
a storefront that has the word 'openai' written on it
|
| 79 |
-
a tatoo of a black broccoli
|
| 80 |
-
a variety of clocks is sitting on a table
|
| 81 |
-
a table has a train model on it with other cars and things
|
| 82 |
-
a pixel art illustration of an eagle sitting in a field in the afternoon
|
| 83 |
-
an emoji of a baby fox wearing a blue hat, green gloves, red shirt, and yellow pants
|
| 84 |
-
an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
|
| 85 |
-
an extreme close-up view of a capybara sitting in a field
|
| 86 |
-
an illustration of a baby cucumber with a mustache playing chess
|
| 87 |
-
an illustration of a baby daikon radish in a tutu walking a dog
|
| 88 |
-
an illustration of a baby hedgehog in a cape staring at its reflection in a mirror
|
| 89 |
-
an illustration of a baby panda with headphones holding an umbrella in the rain
|
| 90 |
-
urinals are lined up in a jungle
|
| 91 |
-
a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
|
| 92 |
-
a human face
|
| 93 |
-
a person is holding a phone and a waterbottle, running a marathon
|
| 94 |
-
a child eating a birthday cake near some balloons
|
| 95 |
-
Young woman riding her bike through the forest
|
| 96 |
-
the best soccer team of the world
|
| 97 |
-
the best football team of the world
|
| 98 |
-
the best basketball team of the world
|
| 99 |
-
happy, happiness
|
| 100 |
-
sad, sadness
|
| 101 |
-
the representation of infinity
|
| 102 |
-
the end of the world
|
| 103 |
-
the last sunrise on earth
|
| 104 |
-
a portrait of a nightmare creature watching at you
|
| 105 |
-
an avocado armchair
|
| 106 |
-
an armchair in the shape of an avocado
|
| 107 |
-
illustration of an avocado armchair
|
| 108 |
-
illustration of an armchair in the shape of an avocado
|
| 109 |
-
logo of an avocado armchair
|
| 110 |
-
an avocado armchair flying into space
|
| 111 |
-
a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
|
| 112 |
-
an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
|
| 113 |
-
illustration of an avocado armchair getting married to a pineapple
|
| 114 |
-
half human half cat
|
| 115 |
-
half human half dog
|
| 116 |
-
half human half pen
|
| 117 |
-
half human half garbage
|
| 118 |
-
half human half avocado
|
| 119 |
-
half human half Eiffel tower
|
| 120 |
-
a propaganda poster for transhumanism
|
| 121 |
-
a propaganda poster for building a space elevator
|
| 122 |
-
a beautiful epic fantasy painting of a space elevator
|
| 123 |
-
a transformer architecture
|
| 124 |
-
a transformer in real life
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|