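# NOTE(review): the original header here read "Spaces: Running on Zero" — presumably a web-page banner captured with the file, not part of the script.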
from typing import *
import argparse
import json
import os
import subprocess
from multiprocessing import Process

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
    CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer,
    T5EncoderModel, T5Tokenizer, T5TokenizerFast,
)
def _tokenize_ids(tokenizer, captions, max_length, device):
    """Tokenize `captions`, padded/truncated to `max_length`, and return the input ids on `device`."""
    text_inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    return text_inputs.input_ids.to(device)  # (B, N)


@torch.no_grad()  # inference only: without this the outputs require grad and `.numpy()` below raises
def text_encode(
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], tokenizer: CLIPTokenizer,
    oids: List[str], gpu_id: int,
    text_encoder_2: Optional[Union[CLIPTextModelWithProjection, T5EncoderModel]] = None,
    tokenizer_2: Optional[Union[CLIPTokenizer, T5Tokenizer]] = None,
    text_encoder_3: Optional[T5EncoderModel] = None,
    tokenizer_3: Optional[T5TokenizerFast] = None,
):
    """Encode the captions of `oids` on GPU `gpu_id` and save the embeddings as `.npy` files.

    Reads the module-level globals `caption_dict`, `MODEL_NAME`, `BATCH`,
    `use_special_words`, `dataset_name` and `pretrained_model_name_or_path`.

    Files are written to `/tmp/{DATASET_NAME}_{MODEL_NAME}_prompt_embeds/`:
      - `{oid}.npy`: per-token prompt embeddings (all models);
      - `{oid}_pooled.npy`: pooled prompt embeddings (sdxl, sd3m, sd35m, sd35l);
      - `{oid}_attention_mask.npy`: tokenizer attention mask (paa, pas).

    Args:
        text_encoder, tokenizer: primary text encoder and its tokenizer.
        oids: object ids whose captions are looked up in `caption_dict`.
        gpu_id: index of the CUDA device to run on.
        text_encoder_2, tokenizer_2: second encoder/tokenizer; required for sdxl and sd3(.5).
        text_encoder_3, tokenizer_3: third encoder/tokenizer; required for sd3(.5).

    Raises:
        AssertionError: if an encoder/tokenizer required by `MODEL_NAME` is missing.
        ValueError: if `MODEL_NAME` is not one of the supported models.
    """
    global caption_dict, MODEL_NAME, BATCH, use_special_words, dataset_name

    device = f"cuda:{gpu_id}"

    text_encoder = text_encoder.to(device)
    if MODEL_NAME in ["sdxl", "sd3m", "sd35m", "sd35l"]:
        assert text_encoder_2 is not None and tokenizer_2 is not None
        text_encoder_2 = text_encoder_2.to(device)
    if MODEL_NAME in ["sd3m", "sd35m", "sd35l"]:
        assert text_encoder_3 is not None and tokenizer_3 is not None
        text_encoder_3 = text_encoder_3.to(device)

    # Loop-invariant: resolve and create the output directory once
    dataset_dir_name = {
        "gobj265k": "GObjaverse",
        "gobj83k": "GObjaverse",
    }[dataset_name]
    out_dir = f"/tmp/{dataset_dir_name}_{MODEL_NAME}_prompt_embeds"
    os.makedirs(out_dir, exist_ok=True)

    # Computed separately from the caption: a conditional expression binds looser than `+`,
    # so `prefix if flag else "" + caption` would drop the caption whenever the flag is set
    caption_prefix = "3d asset in the sks style: " if use_special_words else ""

    for i in tqdm(range(0, len(oids), BATCH), desc=pretrained_model_name_or_path, ncols=125):
        batch_oids = oids[i:i + BATCH]
        batch_captions = [caption_prefix + caption_dict[oid] for oid in batch_oids]

        batch_pooled_prompt_embeds = None
        batch_prompt_attention_mask = None

        if MODEL_NAME in ["sd15", "sd21"]:
            batch_text_input_ids = _tokenize_ids(tokenizer, batch_captions, tokenizer.model_max_length, device)
            batch_prompt_embeds = text_encoder(batch_text_input_ids)[0]  # (B, N, D)

        elif MODEL_NAME in ["sdxl"]:
            batch_text_input_ids = _tokenize_ids(tokenizer, batch_captions, tokenizer.model_max_length, device)
            encoder_outputs = text_encoder(batch_text_input_ids, output_hidden_states=True)
            batch_prompt_embeds_1 = encoder_outputs.hidden_states[-2]  # (B, N, D1); `-2` because SDXL always indexes from the penultimate layer

            # Text encoder 2
            batch_text_input_ids = _tokenize_ids(tokenizer_2, batch_captions, tokenizer_2.model_max_length, device)
            encoder_outputs = text_encoder_2(batch_text_input_ids, output_hidden_states=True)
            batch_pooled_prompt_embeds = encoder_outputs.text_embeds  # (B, D2)
            batch_prompt_embeds_2 = encoder_outputs.hidden_states[-2]  # (B, N, D2); `-2` because SDXL always indexes from the penultimate layer

            batch_prompt_embeds = torch.cat([batch_prompt_embeds_1, batch_prompt_embeds_2], dim=-1)  # (B, N, D1+D2)

        elif MODEL_NAME in ["paa", "pas"]:
            max_length = {"paa": 120, "pas": 300}  # hard-coded for PAA and PAS
            batch_captions = [t.lower().strip() for t in batch_captions]
            batch_text_inputs = tokenizer(
                batch_captions,
                padding="max_length",
                max_length=max_length[MODEL_NAME],
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            batch_text_input_ids = batch_text_inputs.input_ids.to(device)  # (B, N)
            batch_prompt_attention_mask = batch_text_inputs.attention_mask.to(device)  # (B, N)
            batch_prompt_embeds = text_encoder(batch_text_input_ids, attention_mask=batch_prompt_attention_mask)[0]  # (B, N, D)

        elif MODEL_NAME in ["sd3m", "sd35m", "sd35l"]:
            batch_text_input_ids = _tokenize_ids(tokenizer, batch_captions, tokenizer.model_max_length, device)
            encoder_outputs = text_encoder(batch_text_input_ids, output_hidden_states=True)
            batch_pooled_prompt_embeds_1 = encoder_outputs.text_embeds  # (B, D1)
            batch_prompt_embeds_1 = encoder_outputs.hidden_states[-2]  # (B, N, D1); `-2` because SD3(.5) always indexes from the penultimate layer

            # Text encoder 2
            batch_text_input_ids = _tokenize_ids(tokenizer_2, batch_captions, tokenizer_2.model_max_length, device)
            encoder_outputs = text_encoder_2(batch_text_input_ids, output_hidden_states=True)
            batch_pooled_prompt_embeds_2 = encoder_outputs.text_embeds  # (B, D2)
            batch_prompt_embeds_2 = encoder_outputs.hidden_states[-2]  # (B, N, D2); `-2` because SD3(.5) always indexes from the penultimate layer

            batch_pooled_prompt_embeds = torch.cat([batch_pooled_prompt_embeds_1, batch_pooled_prompt_embeds_2], dim=-1)  # (B, D1+D2)
            batch_clip_prompt_embeds = torch.cat([batch_prompt_embeds_1, batch_prompt_embeds_2], dim=-1)  # (B, N, D1+D2)

            # Text encoder 3
            batch_text_input_ids = _tokenize_ids(tokenizer_3, batch_captions, 256, device)  # (B, N3); 256 is hard-coded for SD3(.5)
            batch_prompt_embeds_3 = text_encoder_3(batch_text_input_ids)[0]  # (B, N3, D3)

            # Zero-pad the CLIP feature dimension up to D3 so both can be stacked along the token axis
            batch_clip_prompt_embeds = F.pad(
                batch_clip_prompt_embeds,
                (0, batch_prompt_embeds_3.shape[-1] - batch_clip_prompt_embeds.shape[-1]),
            )  # (B, N, D3)
            batch_prompt_embeds = torch.cat([batch_clip_prompt_embeds, batch_prompt_embeds_3], dim=-2)  # (B, N+N3, D3)

        else:
            raise ValueError(f"Unsupported model name: {MODEL_NAME}")

        # One device-to-host copy per batch instead of one per sample
        prompt_embeds_np = batch_prompt_embeds.float().cpu().numpy()
        pooled_np = (
            batch_pooled_prompt_embeds.float().cpu().numpy()
            if batch_pooled_prompt_embeds is not None else None
        )
        attention_mask_np = (
            batch_prompt_attention_mask.float().cpu().numpy()
            if batch_prompt_attention_mask is not None else None
        )

        for j, oid in enumerate(batch_oids):
            np.save(f"{out_dir}/{oid}.npy", prompt_embeds_np[j])
            if pooled_np is not None:  # sdxl, sd3m, sd35m, sd35l
                np.save(f"{out_dir}/{oid}_pooled.npy", pooled_np[j])
            if attention_mask_np is not None:  # paa, pas
                np.save(f"{out_dir}/{oid}_attention_mask.npy", attention_mask_np[j])
def _download_if_missing(url, dest_dir="extensions/assets/"):
    """Download `url` into `dest_dir` with `wget` unless the file already exists; return its path.

    Raises `subprocess.CalledProcessError` when `wget` exits with a non-zero status,
    so a failed download is reported instead of being silently ignored.
    """
    path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(path):
        # Argument list (no shell) and `check=True` so download failures surface immediately
        subprocess.run(["wget", url, "-P", dest_dir], check=True)
    return path


if __name__ == "__main__":
    # Script entry point: load the text encoder(s) for the chosen model, encode all captions
    # of the chosen dataset in parallel (one process per GPU), then save the embedding of the
    # empty ("null") prompt next to them.
    args = argparse.ArgumentParser("Encode prompt embeddings")
    args.add_argument("model_name", type=str, choices=["sd15", "sd21", "sdxl", "paa", "pas", "sd3m", "sd35m", "sd35l"])
    args.add_argument("--batch_size", type=int, default=128)
    args.add_argument("--dataset_name", default="gobj83k", choices=["gobj265k", "gobj83k"])
    args.add_argument("--use_special_words", action="store_true")
    args = args.parse_args()

    MODEL_NAME = args.model_name
    pretrained_model_name_or_path = {
        "sd15": "chenguolin/stable-diffusion-v1-5",
        "sd21": "stabilityai/stable-diffusion-2-1",
        "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
        "paa": "PixArt-alpha/PixArt-XL-2-512x512",  # "PixArt-alpha/PixArt-XL-2-1024-MS"
        "pas": "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        "sd3m": "stabilityai/stable-diffusion-3-medium-diffusers",
        "sd35m": "stabilityai/stable-diffusion-3.5-medium",
        "sd35l": "stabilityai/stable-diffusion-3.5-large",
    }[MODEL_NAME]

    NUM_GPU = torch.cuda.device_count()
    if NUM_GPU < 1:
        # Fail early with a clear message; `np.array_split(..., 0)` below would raise an obscure error
        raise RuntimeError("At least one CUDA device is required to encode prompt embeddings")
    BATCH = args.batch_size
    dataset_name = args.dataset_name
    use_special_words = args.use_special_words

    variant = "fp16" if MODEL_NAME not in ["pas", "sd3m", "sd35m", "sd35l"] else None  # text encoders of PAS and SD3(.5) are already in fp16

    # "sd21" is an accepted `model_name` and is encoded like "sd15", so it must be loaded here too
    if MODEL_NAME in ["sd15", "sd21", "sdxl"]:
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", variant=variant)
    elif MODEL_NAME in ["paa", "pas"]:
        tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", variant=variant)
    elif MODEL_NAME in ["sd3m", "sd35m", "sd35l"]:
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", variant=variant)

    if MODEL_NAME in ["sdxl", "sd3m", "sd35m", "sd35l"]:
        tokenizer_2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_2", variant=variant)
    else:
        tokenizer_2 = None
        text_encoder_2 = None

    if MODEL_NAME in ["sd3m", "sd35m", "sd35l"]:
        tokenizer_3 = T5TokenizerFast.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_3")
        text_encoder_3 = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_3", variant=variant)
    else:
        tokenizer_3 = None
        text_encoder_3 = None

    # GObjaverse Cap3D
    if "gobj" in dataset_name:
        captions_path = _download_if_missing(
            "https://huggingface.co/datasets/tiange/Cap3D/resolve/main/Cap3D_automated_Objaverse_full.csv"
        )
        captions = pd.read_csv(captions_path, header=None)
        # Column 0 holds the object id and column 1 its caption; build the mapping in one pass
        caption_dict = dict(zip(captions[0], captions[1]))

        index_path = _download_if_missing(
            "https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/gobjaverse_280k_index_to_objaverse.json"
        )
        with open(index_path, "r") as f:
            gids_to_oids = json.load(f)

        if dataset_name == "gobj83k":  # 83k subset
            subset_path = _download_if_missing(
                "https://raw.githubusercontent.com/ashawkey/objaverse_filter/main/gobj_merged.json"
            )
            with open(subset_path, "r") as f:
                gids = json.load(f)
            all_oids = [gids_to_oids[gid].split("/")[1].split(".")[0] for gid in gids]
        elif dataset_name == "gobj265k":  # GObjaverse all 265k
            all_oids = [oid.split("/")[1].split(".")[0] for oid in gids_to_oids.values()]

        assert all(oid in caption_dict for oid in all_oids)

    # One worker process per GPU, each encoding its own shard of object ids
    oids_split = np.array_split(all_oids, NUM_GPU)
    processes = [
        Process(
            target=text_encode,
            args=(text_encoder, tokenizer, oids_split[i], i, text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3),
        )
        for i in range(NUM_GPU)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()

    # Embedding of the empty prompt, computed in this process with the models as loaded above
    with torch.no_grad():
        if MODEL_NAME in ["sd15", "sd21"]:
            null_text_inputs = tokenizer(
                "",
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N)
            null_prompt_embed = text_encoder(null_text_input_ids)
            null_prompt_embed = null_prompt_embed[0].squeeze(0)  # (N, D)

        elif MODEL_NAME in ["sdxl"]:
            null_text_inputs = tokenizer(
                "",
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N)
            null_prompt_embed = text_encoder(null_text_input_ids, output_hidden_states=True)
            null_prompt_embed_1 = null_prompt_embed.hidden_states[-2].squeeze(0)  # (N, D1); `-2` because SDXL always indexes from the penultimate layer

            # Text encoder 2
            null_text_inputs = tokenizer_2(
                "",
                padding="max_length",
                max_length=tokenizer_2.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N)
            null_prompt_embed = text_encoder_2(null_text_input_ids, output_hidden_states=True)
            null_pooled_prompt_embed = null_prompt_embed.text_embeds.squeeze(0)  # (D2)
            null_prompt_embed_2 = null_prompt_embed.hidden_states[-2].squeeze(0)  # (N, D2); `-2` because SDXL always indexes from the penultimate layer

            null_prompt_embed = torch.cat([null_prompt_embed_1, null_prompt_embed_2], dim=1)  # (N, D1+D2)

        elif MODEL_NAME in ["paa", "pas"]:
            max_length = {"paa": 120, "pas": 300}  # hard-coded for PAA and PAS
            null_text_inputs = tokenizer(
                "",
                padding="max_length",
                max_length=max_length[MODEL_NAME],
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N)
            null_attention_mask = null_text_inputs.attention_mask  # (1, N)
            null_prompt_embed = text_encoder(null_text_input_ids, attention_mask=null_attention_mask)
            null_prompt_embed = null_prompt_embed[0].squeeze(0)  # (N, D)
            null_attention_mask = null_attention_mask.squeeze(0)  # (N)

        elif MODEL_NAME in ["sd3m", "sd35m", "sd35l"]:
            null_text_inputs = tokenizer(
                "",
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N)
            null_prompt_embed = text_encoder(null_text_input_ids, output_hidden_states=True)
            null_pooled_prompt_embed_1 = null_prompt_embed.text_embeds.squeeze(0)  # (D1)
            null_prompt_embed_1 = null_prompt_embed.hidden_states[-2].squeeze(0)  # (N, D1); `-2` because SD3(.5) always indexes from the penultimate layer

            # Text encoder 2
            null_text_inputs = tokenizer_2(
                "",
                padding="max_length",
                max_length=tokenizer_2.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N)
            null_prompt_embed = text_encoder_2(null_text_input_ids, output_hidden_states=True)
            null_pooled_prompt_embed_2 = null_prompt_embed.text_embeds.squeeze(0)  # (D2)
            null_pooled_prompt_embed = torch.cat([null_pooled_prompt_embed_1, null_pooled_prompt_embed_2], dim=-1)  # (D1+D2)
            null_prompt_embed_2 = null_prompt_embed.hidden_states[-2].squeeze(0)  # (N, D2); `-2` because SD3(.5) always indexes from the penultimate layer

            null_clip_prompt_embed = torch.cat([null_prompt_embed_1, null_prompt_embed_2], dim=1)  # (N, D1+D2)

            # Text encoder 3
            null_text_inputs = tokenizer_3(
                "",
                padding="max_length",
                max_length=256,  # hard-coded for SD3(.5)
                truncation=True,
                return_tensors="pt",
            )
            null_text_input_ids = null_text_inputs.input_ids  # (1, N3)
            null_prompt_embed = text_encoder_3(null_text_input_ids)
            null_prompt_embed_3 = null_prompt_embed[0].squeeze(0)  # (N3, D3)

            # Zero-pad the CLIP feature dimension up to D3 so both can be stacked along the token axis
            null_clip_prompt_embed = F.pad(
                null_clip_prompt_embed,
                (0, null_prompt_embed_3.shape[-1] - null_clip_prompt_embed.shape[-1]),
            )  # (N, D3)
            null_prompt_embed = torch.cat([null_clip_prompt_embed, null_prompt_embed_3], dim=-2)  # (N+N3, D3)

    DATASET_NAME = {
        "gobj265k": "GObjaverse",
        "gobj83k": "GObjaverse",
    }[dataset_name]
    out_dir = f"/tmp/{DATASET_NAME}_{MODEL_NAME}_prompt_embeds"
    os.makedirs(out_dir, exist_ok=True)

    np.save(f"{out_dir}/null.npy", null_prompt_embed.float().cpu().numpy())
    if MODEL_NAME in ["sdxl", "sd3m", "sd35m", "sd35l"]:
        np.save(f"{out_dir}/null_pooled.npy", null_pooled_prompt_embed.float().cpu().numpy())
    if MODEL_NAME in ["paa", "pas"]:
        np.save(f"{out_dir}/null_attention_mask.npy", null_attention_mask.float().cpu().numpy())