from typing import *
import os
import argparse
import json
import subprocess
from multiprocessing import Process, get_context

from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    T5TokenizerFast,
)


# Model families, grouped by the text encoder(s) they use.
CLIP_MODELS = ["sd15", "sd21"]  # one CLIP text encoder
SDXL_MODELS = ["sdxl"]  # two CLIP text encoders
PIXART_MODELS = ["paa", "pas"]  # one T5 text encoder
SD3_MODELS = ["sd3m", "sd35m", "sd35l"]  # two CLIP text encoders + one T5 text encoder
POOLED_MODELS = SDXL_MODELS + SD3_MODELS  # families that also save pooled embeddings

PIXART_MAX_LENGTH = {"paa": 120, "pas": 300}  # hard-coded for PAA and PAS
SD3_T5_MAX_LENGTH = 256  # hard-coded for SD3(.5)

# Output-directory prefix for each dataset.
DATASET_DIR_NAMES = {
    "gobj265k": "GObjaverse",
    "gobj83k": "GObjaverse",
}

SPECIAL_WORDS_PREFIX = "3d asset in the sks style: "
ASSETS_DIR = "extensions/assets"


def _output_dir(dataset_name: str, model_name: str) -> str:
    """Create (if needed) and return the directory that embeddings are saved to."""
    out_dir = f"/tmp/{DATASET_DIR_NAMES[dataset_name]}_{model_name}_prompt_embeds"
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


def _build_caption(caption: str, use_special_words: bool) -> str:
    """Return `caption`, prefixed with SPECIAL_WORDS_PREFIX when `use_special_words` is set."""
    # The parentheses are essential: `a if c else "" + x` parses as
    # `a if c else ("" + x)`, which would drop the caption whenever `c` is true.
    return (SPECIAL_WORDS_PREFIX if use_special_words else "") + caption


def _ensure_asset(path: str, url: str) -> None:
    """Download `url` with wget into the directory containing `path`, unless `path` already exists."""
    if not os.path.exists(path):
        # check=True raises CalledProcessError instead of continuing after a failed download
        subprocess.run(["wget", url, "-P", os.path.dirname(path)], check=True)


def _clip_input_ids(tokenizer: CLIPTokenizer, captions: List[str], device: Union[str, torch.device]) -> torch.Tensor:
    """Tokenize `captions`, padded/truncated to the tokenizer's max length; return (B, N) ids on `device`."""
    text_inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return text_inputs.input_ids.to(device)


def _clip_penultimate(text_encoder, tokenizer, captions: List[str], device: Union[str, torch.device]):
    """Run a CLIP text encoder; return (penultimate hidden states (B, N, D), full encoder output)."""
    outputs = text_encoder(_clip_input_ids(tokenizer, captions, device), output_hidden_states=True)
    # `-2` because SDXL and SD3(.5) always index from the penultimate layer
    return outputs.hidden_states[-2], outputs


@torch.no_grad()
def _encode_prompts(
    captions: List[str],
    model_name: str,
    device: Union[str, torch.device],
    text_encoder,
    tokenizer,
    text_encoder_2=None,
    tokenizer_2=None,
    text_encoder_3=None,
    tokenizer_3=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Encode a batch of captions with the text encoder(s) of `model_name`.

    The encoders must already live on `device`.

    Returns:
        prompt_embeds: (B, N, D) token embeddings. For SDXL, D = D1 + D2 (both CLIP
            encoders concatenated on the feature axis). For SD3(.5), the concatenated
            CLIP embeddings are zero-padded to the T5 width D3 and concatenated with
            the T5 embeddings along the sequence axis, giving (B, N + N3, D3).
        pooled_prompt_embeds: (B, D) for SDXL / SD3(.5), otherwise None.
        attention_mask: (B, N) tokenizer attention mask for PixArt models, otherwise None.

    Raises:
        ValueError: if `model_name` is not a supported model.
    """
    if model_name in CLIP_MODELS:
        prompt_embeds = text_encoder(_clip_input_ids(tokenizer, captions, device))[0]  # (B, N, D)
        return prompt_embeds, None, None

    if model_name in SDXL_MODELS:
        embeds_1, _ = _clip_penultimate(text_encoder, tokenizer, captions, device)  # (B, N, D1)
        embeds_2, outputs_2 = _clip_penultimate(text_encoder_2, tokenizer_2, captions, device)  # (B, N, D2)
        pooled_prompt_embeds = outputs_2.text_embeds  # (B, D2)
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)  # (B, N, D1+D2)
        return prompt_embeds, pooled_prompt_embeds, None

    if model_name in PIXART_MODELS:
        captions = [t.lower().strip() for t in captions]
        text_inputs = tokenizer(
            captions,
            padding="max_length",
            max_length=PIXART_MAX_LENGTH[model_name],
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(device)  # (B, N)
        attention_mask = text_inputs.attention_mask.to(device)  # (B, N)
        prompt_embeds = text_encoder(input_ids, attention_mask=attention_mask)[0]  # (B, N, D)
        return prompt_embeds, None, attention_mask

    if model_name in SD3_MODELS:
        embeds_1, outputs_1 = _clip_penultimate(text_encoder, tokenizer, captions, device)  # (B, N, D1)
        embeds_2, outputs_2 = _clip_penultimate(text_encoder_2, tokenizer_2, captions, device)  # (B, N, D2)
        pooled_prompt_embeds = torch.cat([outputs_1.text_embeds, outputs_2.text_embeds], dim=-1)  # (B, D1+D2)
        clip_embeds = torch.cat([embeds_1, embeds_2], dim=-1)  # (B, N, D1+D2)

        t5_inputs = tokenizer_3(
            captions,
            padding="max_length",
            max_length=SD3_T5_MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )
        t5_embeds = text_encoder_3(t5_inputs.input_ids.to(device))[0]  # (B, N3, D3)

        # Zero-pad the CLIP features up to the T5 width so both can share one sequence axis
        clip_embeds = F.pad(clip_embeds, (0, t5_embeds.shape[-1] - clip_embeds.shape[-1]))  # (B, N, D3)
        prompt_embeds = torch.cat([clip_embeds, t5_embeds], dim=-2)  # (B, N+N3, D3)
        return prompt_embeds, pooled_prompt_embeds, None

    raise ValueError(f"Unsupported model name: {model_name}")


@torch.no_grad()
def text_encode(
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
    tokenizer: CLIPTokenizer,
    oids: List[str],
    gpu_id: int,
    text_encoder_2: Optional[Union[CLIPTextModelWithProjection, T5EncoderModel]] = None,
    tokenizer_2: Optional[Union[CLIPTokenizer, T5Tokenizer]] = None,
    text_encoder_3: Optional[T5EncoderModel] = None,
    tokenizer_3: Optional[T5TokenizerFast] = None,
):
    """Encode the captions of `oids` on GPU `gpu_id` and save them as .npy files.

    For every object id, writes `<oid>.npy` (token embeddings) and, depending on the
    model family, `<oid>_pooled.npy` or `<oid>_attention_mask.npy`.

    Reads module-level globals assigned in the `__main__` section (caption_dict,
    MODEL_NAME, BATCH, use_special_words, dataset_name, pretrained_model_name_or_path).
    This only works when the worker is started with the "fork" start method, which
    is how `__main__` launches it.
    """
    global caption_dict, MODEL_NAME, BATCH, use_special_words, dataset_name, pretrained_model_name_or_path

    device = f"cuda:{gpu_id}"
    text_encoder = text_encoder.to(device)
    if MODEL_NAME in POOLED_MODELS:
        assert text_encoder_2 is not None and tokenizer_2 is not None
        text_encoder_2 = text_encoder_2.to(device)
    if MODEL_NAME in SD3_MODELS:
        assert text_encoder_3 is not None and tokenizer_3 is not None
        text_encoder_3 = text_encoder_3.to(device)

    out_dir = _output_dir(dataset_name, MODEL_NAME)
    for start in tqdm(range(0, len(oids), BATCH), desc=pretrained_model_name_or_path, ncols=125):
        batch_oids = oids[start:start + BATCH]  # slicing clamps at len(oids)
        batch_captions = [_build_caption(caption_dict[oid], use_special_words) for oid in batch_oids]

        prompt_embeds, pooled_prompt_embeds, attention_mask = _encode_prompts(
            batch_captions, MODEL_NAME, device,
            text_encoder, tokenizer, text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3,
        )

        for j, oid in enumerate(batch_oids):
            np.save(f"{out_dir}/{oid}.npy", prompt_embeds[j].float().cpu().numpy())
            if pooled_prompt_embeds is not None:
                np.save(f"{out_dir}/{oid}_pooled.npy", pooled_prompt_embeds[j].float().cpu().numpy())
            if attention_mask is not None:
                np.save(f"{out_dir}/{oid}_attention_mask.npy", attention_mask[j].float().cpu().numpy())


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Encode prompt embeddings")
    parser.add_argument("model_name", type=str, choices=CLIP_MODELS + SDXL_MODELS + PIXART_MODELS + SD3_MODELS)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--dataset_name", default="gobj83k", choices=["gobj265k", "gobj83k"])
    parser.add_argument("--use_special_words", action="store_true")
    args = parser.parse_args()

    MODEL_NAME = args.model_name
    pretrained_model_name_or_path = {
        "sd15": "chenguolin/stable-diffusion-v1-5",
        "sd21": "stabilityai/stable-diffusion-2-1",
        "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
        "paa": "PixArt-alpha/PixArt-XL-2-512x512",  # "PixArt-alpha/PixArt-XL-2-1024-MS"
        "pas": "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        "sd3m": "stabilityai/stable-diffusion-3-medium-diffusers",
        "sd35m": "stabilityai/stable-diffusion-3.5-medium",
        "sd35l": "stabilityai/stable-diffusion-3.5-large",
    }[MODEL_NAME]

    NUM_GPU = torch.cuda.device_count()
    if NUM_GPU == 0:
        raise RuntimeError("No CUDA device found; this script encodes captions on GPU(s).")
    BATCH = args.batch_size
    dataset_name = args.dataset_name
    use_special_words = args.use_special_words

    # Text encoders of PAS and SD3(.5) are already stored in fp16, so no variant is requested for them
    variant = None if MODEL_NAME in ["pas"] + SD3_MODELS else "fp16"

    if MODEL_NAME in CLIP_MODELS + SDXL_MODELS:
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", variant=variant)
    elif MODEL_NAME in PIXART_MODELS:
        tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", variant=variant)
    elif MODEL_NAME in SD3_MODELS:
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", variant=variant)

    if MODEL_NAME in POOLED_MODELS:
        tokenizer_2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_2", variant=variant)
    else:
        tokenizer_2 = None
        text_encoder_2 = None

    if MODEL_NAME in SD3_MODELS:
        tokenizer_3 = T5TokenizerFast.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_3")
        text_encoder_3 = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_3", variant=variant)
    else:
        tokenizer_3 = None
        text_encoder_3 = None

    # GObjaverse Cap3D
    if "gobj" in dataset_name:
        caption_csv = f"{ASSETS_DIR}/Cap3D_automated_Objaverse_full.csv"
        _ensure_asset(caption_csv, "https://huggingface.co/datasets/tiange/Cap3D/resolve/main/Cap3D_automated_Objaverse_full.csv")
        captions = pd.read_csv(caption_csv, header=None)
        # Column 0 holds the object id and column 1 its caption; later rows win on duplicate ids
        caption_dict = dict(zip(captions[0], captions[1]))

        index_json = f"{ASSETS_DIR}/gobjaverse_280k_index_to_objaverse.json"
        _ensure_asset(index_json, "https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/gobjaverse_280k_index_to_objaverse.json")
        with open(index_json, "r") as f:
            gids_to_oids = json.load(f)

        if dataset_name == "gobj83k":  # 83k subset
            subset_json = f"{ASSETS_DIR}/gobj_merged.json"
            _ensure_asset(subset_json, "https://raw.githubusercontent.com/ashawkey/objaverse_filter/main/gobj_merged.json")
            with open(subset_json, "r") as f:
                gids = json.load(f)
            all_oids = [gids_to_oids[gid].split("/")[1].split(".")[0] for gid in gids]
        elif dataset_name == "gobj265k":  # GObjaverse all 265k
            all_oids = [oid.split("/")[1].split(".")[0] for oid in gids_to_oids.values()]

        missing_oids = [oid for oid in all_oids if oid not in caption_dict]
        if missing_oids:
            raise ValueError(f"{len(missing_oids)} object ids have no caption, e.g. {missing_oids[:5]}")

    oids_split = np.array_split(all_oids, NUM_GPU)
    # Workers read this module's globals, which are only inherited under "fork";
    # request it explicitly because Python 3.14 changed the POSIX default start method.
    mp_context = get_context("fork")
    processes = [
        mp_context.Process(
            target=text_encode,
            args=(text_encoder, tokenizer, oids_split[gpu_id], gpu_id, text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3),
        )
        for gpu_id in range(NUM_GPU)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    failed_gpus = [gpu_id for gpu_id, p in enumerate(processes) if p.exitcode != 0]
    if failed_gpus:
        raise RuntimeError(f"Prompt encoding failed on GPU(s) {failed_gpus}")

    # Null (empty-prompt) embedding, computed in the parent process where the encoders are still on CPU
    with torch.no_grad():
        null_prompt_embeds, null_pooled_prompt_embeds, null_attention_mask = _encode_prompts(
            [""], MODEL_NAME, "cpu",
            text_encoder, tokenizer, text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3,
        )

    out_dir = _output_dir(dataset_name, MODEL_NAME)
    np.save(f"{out_dir}/null.npy", null_prompt_embeds[0].float().cpu().numpy())
    if null_pooled_prompt_embeds is not None:
        np.save(f"{out_dir}/null_pooled.npy", null_pooled_prompt_embeds[0].float().cpu().numpy())
    if null_attention_mask is not None:
        np.save(f"{out_dir}/null_attention_mask.npy", null_attention_mask[0].float().cpu().numpy())