# PoseModifier / UniAnimate /tools /modules /embedding_manager.py
# Evgeny Zhukov
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# 2ba4412
import torch
from torch import nn
import torch.nn.functional as F
import open_clip
from functools import partial
from utils.registry_class import EMBEDMANAGER
# Single-element list holding the default placeholder string "*".
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
# Divisor applied to EmbeddingManager.progressive_counter in forward(): when
# progressive_words is enabled, the number of vectors used per placeholder is
# 1 + progressive_counter // PROGRESSIVE_SCALE.
PROGRESSIVE_SCALE = 2000
# 22 single-character placeholder strings that EmbeddingManager.__init__
# appends to its placeholder list when per_image_tokens is True.
# NOTE(review): these characters look like mis-decoded text (presumably the
# 22 Hebrew letters read with the wrong codec) -- verify against upstream
# before changing them: they are runtime strings that get tokenized and used
# as dictionary / checkpoint keys.
per_img_token_list = [
'讗', '讘', '讙', '讚', '讛', '讜', '讝', '讞', '讟', '讬', '讻', '诇', '诪', '谞', '住', '注', '驻', '爪', '拽', '专', '砖', '转',
]
def get_clip_token_for_string(string):
    """Tokenize ``string`` with open_clip and return one token id.

    Only the element at row 0, position 1 of the tokenizer output is
    returned; any further tokens produced for ``string`` are ignored.
    """
    first_prompt = open_clip.tokenize(string)[0]
    # Position 0 is skipped -- presumably the start-of-text marker that the
    # tokenizer prepends (TODO confirm against the open_clip tokenizer).
    return first_prompt[1]
def get_embedding_for_clip_token(embedder, token):
    """Run ``embedder`` on a single ``token`` and strip the added batch axis.

    ``token`` gets a leading axis of size one before the call, and the first
    (only) entry along that axis of the result is returned.
    """
    batch_of_one = token.unsqueeze(0)
    embedded = embedder(batch_of_one)
    return embedded[0]
@EMBEDMANAGER.register_class()
class EmbeddingManager(nn.Module):
    """Learnable embeddings for placeholder strings inside tokenized prompts.

    Every placeholder string is mapped to a single token id (see
    ``get_clip_token_for_string``) and to a trainable parameter of shape
    ``(num_vectors_per_token, embedding_dim)``. The ``forward*`` methods look
    for that token id in ``tokenized_text`` and overwrite, or add to, the
    matching entries of ``embedded_text`` in place.
    """

    def __init__(
        self,
        embedder,
        placeholder_strings=None,
        initializer_words=None,
        per_image_tokens=False,
        num_vectors_per_token=1,
        progressive_words=False,
        temporal_prompt_length=1,
        token_dim=1024,
        **kwargs
    ):
        """Create one trainable embedding per placeholder string.

        Args:
            embedder: Object exposing ``model.token_embedding``; that
                sub-module is called to look up initializer-word embeddings.
            placeholder_strings: Iterable of placeholder strings. It is
                copied, so the caller's object is never modified. ``None``
                falls back to ``DEFAULT_PLACEHOLDER_TOKEN``.
            initializer_words: Optional sequence; the placeholder at index
                ``i`` starts from the embedding of ``initializer_words[i]``
                when that entry exists, otherwise from ``torch.rand`` values.
            per_image_tokens: When true, ``per_img_token_list`` is appended
                to the placeholder strings.
            num_vectors_per_token: Number of embedding rows stored for each
                placeholder.
            progressive_words: When true (and more than one vector per token
                is configured), ``forward`` uses a number of vectors that
                grows with the number of calls.
            temporal_prompt_length: Accepted but not used by this class.
            token_dim: Embedding width used for randomly initialized
                placeholders only.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.string_to_token_dict = {}
        self.string_to_param_dict = nn.ParameterDict()
        self.initial_embeddings = nn.ParameterDict()  # These should not be optimized
        self.progressive_words = progressive_words
        self.progressive_counter = 0
        self.max_vectors_per_token = num_vectors_per_token

        # NOTE(review): if ``token_embedding`` is an ``nn.Module``, ``.cpu()``
        # moves it in place, so the embedder itself is affected -- confirm
        # that callers move the embedder back to its device afterwards.
        get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.model.token_embedding.cpu())

        # Work on a private copy: extending the caller's list would make a
        # shared config list grow on every construction, and the declared
        # ``None`` default would otherwise crash below.
        if placeholder_strings is None:
            placeholder_strings = DEFAULT_PLACEHOLDER_TOKEN
        placeholder_strings = list(placeholder_strings)
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):
            token = get_clip_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                init_word_token = get_clip_token_for_string(initializer_words[idx])
                with torch.no_grad():
                    init_word_embedding = get_embedding_for_tkn(init_word_token)
                # Two independent ``repeat`` results: the trainable copy and
                # the frozen reference must not share storage.
                token_params = torch.nn.Parameter(
                    init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1),
                    requires_grad=True,
                )
                self.initial_embeddings[placeholder_string] = torch.nn.Parameter(
                    init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1),
                    requires_grad=False,
                )
            else:
                token_params = torch.nn.Parameter(
                    torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)
                )

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        """Replace placeholder positions in ``embedded_text`` with learned vectors.

        ``tokenized_text`` must be 2-D. ``embedded_text`` is modified in place
        and returned; it presumably has shape ``(batch, seq_len, dim)`` --
        TODO confirm against the caller. When more than one vector per token
        is configured, ``tokenized_text`` is modified in place as well and
        every row is truncated back to its original length after insertion.
        """
        # Unpacking keeps the original requirement that the input is 2-D.
        _, n = tokenized_text.shape
        device = tokenized_text.device

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            placeholder_token = placeholder_token.to(device)

            if self.max_vectors_per_token == 1:  # If there's only one vector per token, we can do a simple replacement
                placeholder_idx = torch.where(tokenized_text == placeholder_token)
                embedded_text[placeholder_idx] = placeholder_embedding
                continue

            # otherwise, need to insert and keep track of changing indices
            if self.progressive_words:
                # Incremented once per placeholder string per call.
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

            placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token)
            if placeholder_rows.nelement() == 0:
                continue

            # Process matches from the rightmost column first so an insertion
            # never shifts the column of a match that is still pending.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]

            inserted_tokens = placeholder_token.repeat(num_vectors_for_token)
            inserted_embeddings = placeholder_embedding[:num_vectors_for_token]

            for row, col in zip(sorted_rows, sorted_cols):
                new_token_row = torch.cat(
                    [tokenized_text[row][:col], inserted_tokens, tokenized_text[row][col + 1:]],
                    dim=0,
                )[:n]
                new_embed_row = torch.cat(
                    [embedded_text[row][:col], inserted_embeddings, embedded_text[row][col + 1:]],
                    dim=0,
                )[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    def forward_with_text_img(
        self,
        tokenized_text,
        embedded_text,
        embedded_img,
    ):
        """Add ``embedded_img`` and the learned vector at placeholder positions.

        ``embedded_text`` is modified in place and returned. ``embedded_img``
        has to broadcast against the selected rows of ``embedded_text`` --
        exact expected shape is not visible here, confirm with the caller.
        """
        device = tokenized_text.device
        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
            embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding
        return embedded_text

    def forward_with_text(
        self,
        tokenized_text,
        embedded_text
    ):
        """Add the learned vector to ``embedded_text`` at placeholder positions.

        ``embedded_text`` is modified in place and returned.
        """
        device = tokenized_text.device
        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
            embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding
        return embedded_text

    def save(self, ckpt_path):
        """Write the token mapping and the trainable parameters to ``ckpt_path``.

        The ``nn.ParameterDict`` object itself is pickled, not a state dict.
        """
        torch.save({"string_to_token": self.string_to_token_dict,
                    "string_to_param": self.string_to_param_dict}, ckpt_path)

    def load(self, ckpt_path):
        """Merge a checkpoint written by ``save`` into this manager.

        Entries from the checkpoint overwrite entries with the same key;
        keys that exist only on this instance are kept.
        """
        # NOTE(review): this unpickles arbitrary objects, so only load
        # trusted checkpoints. Recent PyTorch releases presumably default to
        # ``weights_only=True``, which may refuse the pickled ParameterDict
        # -- confirm against the PyTorch version in use.
        ckpt = torch.load(ckpt_path, map_location='cpu')

        string_to_token = ckpt["string_to_token"]
        string_to_param = ckpt["string_to_param"]
        for string, token in string_to_token.items():
            self.string_to_token_dict[string] = token
        for string, param in string_to_param.items():
            self.string_to_param_dict[string] = param

    def get_embedding_norms_squared(self):
        """Return the squared L2 norm of every stored embedding row."""
        all_params = torch.cat(list(self.string_to_param_dict.values()), dim=0)  # num_placeholders x embedding_dim
        param_norm_squared = (all_params * all_params).sum(dim=-1)  # num_placeholders
        return param_norm_squared

    def embedding_parameters(self):
        """Return an iterator over the trainable placeholder parameters."""
        return self.string_to_param_dict.parameters()

    def embedding_to_coarse_loss(self):
        """Penalize drift of the learned embeddings from their initial values.

        Only placeholders that were initialized from an initializer word
        contribute. Each contribution is ``delta @ delta.T`` divided by the
        number of such placeholders, so with ``k`` vectors per token it is a
        ``(k, k)`` matrix rather than a scalar. Returns the float ``0.`` when
        no placeholder has an initial embedding.
        """
        loss = 0.
        num_embeddings = len(self.initial_embeddings)

        for key in self.initial_embeddings:
            optimized = self.string_to_param_dict[key]
            coarse = self.initial_embeddings[key].to(optimized.device)
            delta = optimized - coarse
            loss = loss + delta @ delta.T / num_embeddings

        return loss