|
from cmath import log |
|
import torch |
|
from torch import nn |
|
|
|
import sys |
|
|
|
from ldm.data.personalized import per_img_token_list |
|
from transformers import CLIPTokenizer |
|
from functools import partial |
|
|
|
# Default list of placeholder strings.
DEFAULT_PLACEHOLDER_TOKEN = ["*"]

# Divisor for the progressive-words schedule: in EmbeddingManager.forward the
# number of vectors in use is 1 + progressive_counter // PROGRESSIVE_SCALE.
PROGRESSIVE_SCALE = 2_000
|
|
|
|
|
def get_clip_token_for_string(tokenizer, string):
    """Return the first token id that *string* encodes to with a CLIP tokenizer.

    The string is tokenized with truncation and ``max_length`` padding to 77
    positions. The id at position 1 is returned as a 0-d tensor.

    Note: strings that encode to several sub-word tokens are NOT rejected;
    only the first sub-word token is returned.

    Args:
        tokenizer: callable tokenizer returning a mapping with ``'input_ids'``
            (e.g. a Hugging Face ``CLIPTokenizer``).
        string: text to tokenize.

    Returns:
        The token id at ``input_ids[0, 1]``.
    """
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding='max_length',
        return_tensors='pt',
    )
    tokens = batch_encoding['input_ids']
    # Position 0 presumably holds the start-of-text token, so the first token
    # of *string* itself is at index 1.
    return tokens[0, 1]
|
|
|
|
|
def get_bert_token_for_string(tokenizer, string):
    """Tokenize *string* and return the element at index ``[0, 1]``.

    Args:
        tokenizer: callable mapping a string to an indexable token array.
            Presumably a (batch, seq) tensor whose position 0 is a special
            token; TODO confirm against the BERT embedder's ``tknz_fn``.
        string: text to tokenize.

    Returns:
        The token at row 0, column 1 of the tokenizer output.
    """
    return tokenizer(string)[0, 1]
|
|
|
|
|
def get_embedding_for_clip_token(embedder, token):
    """Embed a single token id and return its embedding vector.

    The token is given a leading batch axis before it is passed to
    *embedder*, and the result is indexed at ``[0, 0]``.

    Args:
        embedder: callable embedding module (e.g. CLIP text embeddings).
        token: token-id tensor (0-d in the callers visible here).

    Returns:
        ``embedder(token.unsqueeze(0))[0, 0]``.
    """
    batched_token = token.unsqueeze(0)
    embedded = embedder(batched_token)
    return embedded[0, 0]
|
|
|
|
|
class EmbeddingManager(nn.Module):
    """Learnable embeddings for placeholder strings, spliced into embedded text.

    Each placeholder string is mapped to:

    * one tokenizer id (``string_to_token_dict``), and
    * a trainable ``(num_vectors_per_token, dim)`` parameter
      (``string_to_param_dict``).

    When an initializer word is supplied for a placeholder, a frozen copy of
    that word's embedding is stored in ``initial_embeddings``. It serves as the
    reference point for :meth:`embedding_to_coarse_loss`.
    """

    def __init__(
        self,
        embedder,
        placeholder_strings=None,
        initializer_words=None,
        per_image_tokens=False,
        num_vectors_per_token=1,
        progressive_words=False,
        **kwargs,
    ):
        """Build the placeholder -> token / parameter mappings.

        Args:
            embedder: text conditioner. If it has a ``tokenizer`` attribute it
                is treated as CLIP (uses ``embedder.tokenizer`` and
                ``embedder.transformer.text_model.embeddings``). Otherwise it
                is treated as the BERT-style embedder (uses
                ``embedder.tknz_fn`` and ``embedder.transformer.token_emb``).
            placeholder_strings: strings to learn embeddings for. Defaults to
                ``DEFAULT_PLACEHOLDER_TOKEN``. The caller's sequence is copied
                and never modified.
            initializer_words: optional words whose embeddings initialize the
                placeholder at the same index. Placeholders without one are
                initialized with ``torch.rand``.
            per_image_tokens: if true, also learn embeddings for every string
                in ``per_img_token_list``.
            num_vectors_per_token: number of embedding vectors per placeholder.
            progressive_words: if true, the number of vectors used by
                :meth:`forward` grows with an internal call counter.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.embedder = embedder

        self.string_to_token_dict = {}
        self.string_to_param_dict = nn.ParameterDict()
        # Frozen copies of initializer-word embeddings. Only placeholders that
        # were given an initializer word appear here.
        self.initial_embeddings = nn.ParameterDict()

        self.progressive_words = progressive_words
        self.progressive_counter = 0

        self.max_vectors_per_token = num_vectors_per_token

        if hasattr(embedder, 'tokenizer'):
            self.is_clip = True
            get_token_for_string = partial(
                get_clip_token_for_string, embedder.tokenizer
            )
            get_embedding_for_tkn = partial(
                get_embedding_for_clip_token,
                embedder.transformer.text_model.embeddings,
            )
        else:
            self.is_clip = False
            get_token_for_string = partial(
                get_bert_token_for_string, embedder.tknz_fn
            )
            get_embedding_for_tkn = embedder.transformer.token_emb

        # NOTE(review): both embedder kinds use the same hard-coded size. It is
        # only used to randomly initialize placeholders that have no
        # initializer word. Confirm it matches the embedder's hidden size.
        token_dim = 1280

        # Work on a private copy so that appending per-image tokens never
        # mutates the caller's list (e.g. one shared through a config object).
        if placeholder_strings is None:
            placeholder_strings = DEFAULT_PLACEHOLDER_TOKEN
        placeholder_strings = list(placeholder_strings)
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):
            token = get_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                init_word_token = get_token_for_string(initializer_words[idx])

                with torch.no_grad():
                    init_word_embedding = get_embedding_for_tkn(
                        init_word_token.cpu()
                    )

                init_vectors = init_word_embedding.unsqueeze(0).repeat(
                    num_vectors_per_token, 1
                )
                # The trainable parameter gets its own storage so optimizing
                # it can never alter the frozen reference copy below.
                token_params = torch.nn.Parameter(
                    init_vectors.clone(), requires_grad=True
                )
                self.initial_embeddings[placeholder_string] = torch.nn.Parameter(
                    init_vectors, requires_grad=False
                )
            else:
                token_params = torch.nn.Parameter(
                    torch.rand(
                        size=(num_vectors_per_token, token_dim),
                        requires_grad=True,
                    )
                )

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        """Overwrite placeholder positions of *embedded_text* with learned vectors.

        Args:
            tokenized_text: integer token ids. Its shape must unpack into
                exactly two values, presumably ``(batch, seq)``. When more than
                one vector per token is used, it is modified in place: each
                placeholder occurrence is expanded to ``num_vectors`` copies
                and the row is truncated back to its original length.
            embedded_text: embeddings aligned with *tokenized_text*. Modified
                in place.

        Returns:
            ``embedded_text`` (the same tensor object, modified).
        """
        _, n = tokenized_text.shape
        device = tokenized_text.device

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[
                placeholder_string
            ].to(device)

            if self.max_vectors_per_token == 1:
                placeholder_idx = torch.where(
                    tokenized_text == placeholder_token.to(device)
                )
                embedded_text[placeholder_idx] = placeholder_embedding
                continue

            if self.progressive_words:
                # NOTE(review): the counter advances once per placeholder
                # string per call, not once per call. Confirm this is intended.
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(
                placeholder_embedding.shape[0], max_step_tokens
            )

            placeholder_rows, placeholder_cols = torch.where(
                tokenized_text == placeholder_token.to(device)
            )
            if placeholder_rows.nelement() == 0:
                continue

            # Handle occurrences right-to-left. Expanding a placeholder only
            # shifts the positions to its right, so the columns still waiting
            # to be processed (which are smaller) remain valid.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]

            for row, col in zip(sorted_rows, sorted_cols):
                new_token_row = torch.cat(
                    [
                        tokenized_text[row][:col],
                        placeholder_token.repeat(num_vectors_for_token).to(device),
                        tokenized_text[row][col + 1:],
                    ],
                    axis=0,
                )[:n]
                new_embed_row = torch.cat(
                    [
                        embedded_text[row][:col],
                        placeholder_embedding[:num_vectors_for_token],
                        embedded_text[row][col + 1:],
                    ],
                    axis=0,
                )[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    def save(self, ckpt_path):
        """Save the token and parameter mappings to *ckpt_path* with ``torch.save``."""
        torch.save(
            {
                'string_to_token': self.string_to_token_dict,
                'string_to_param': self.string_to_param_dict,
            },
            ckpt_path,
        )

    def load(self, ckpt_path):
        """Replace the token and parameter mappings with those stored at *ckpt_path*.

        ``initial_embeddings`` is not touched.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. The saved ParameterDict is a pickled module, which
        # PyTorch releases whose default is ``weights_only=True`` will refuse
        # to load.
        ckpt = torch.load(ckpt_path, map_location='cpu')

        self.string_to_token_dict = ckpt["string_to_token"]
        self.string_to_param_dict = ckpt["string_to_param"]

    def get_embedding_norms_squared(self):
        """Return the squared L2 norm of every learned vector, concatenated over placeholders."""
        all_params = torch.cat(
            list(self.string_to_param_dict.values()), axis=0
        )
        return (all_params * all_params).sum(axis=-1)

    def embedding_parameters(self):
        """Return an iterator over the trainable placeholder parameters only."""
        return self.string_to_param_dict.parameters()

    def embedding_to_coarse_loss(self):
        """Mean (over initialized placeholders) squared distance to the initial embeddings.

        For each placeholder with an initializer word, sum the squared
        differences between its learned vectors and their frozen initial copy.
        Then average over those placeholders. The result is a scalar tensor,
        or ``0.0`` when no placeholder had an initializer word.
        """
        loss = 0.0
        num_embeddings = len(self.initial_embeddings)

        for key, coarse in self.initial_embeddings.items():
            optimized = self.string_to_param_dict[key]
            diff = optimized - coarse.to(optimized.device)
            # Sum of per-vector squared distances. The previous
            # ``diff @ diff.T`` yielded a (k, k) Gram matrix with cross terms
            # whenever a placeholder had more than one vector.
            loss = loss + (diff * diff).sum() / num_embeddings

        return loss
|
|