|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageDraw, ImageFont |
|
import cv2 |
|
from typing import Optional, Union, Tuple, List, Callable, Dict |
|
|
|
|
|
|
|
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
    """Write a 0/1 time-window mask into ``alpha`` for a single prompt.

    Sets ``alpha[t, prompt_ind, word_inds]`` to 1 for diffusion steps ``t``
    falling inside the fractional window ``bounds`` and to 0 outside it.

    Args:
        alpha: Tensor of shape (num_steps, num_prompts, num_words); mutated
            in place.
        bounds: Either a scalar end fraction (treated as ``(0, bounds)``) or
            an explicit ``(start, end)`` pair of fractions of the step axis.
        prompt_ind: Index along the prompt axis to update.
        word_inds: Word indices to update; defaults to all words.

    Returns:
        The mutated ``alpha`` tensor (same object that was passed in).
    """
    # Bug fix: the original `type(bounds) is float` check rejected plain
    # ints (e.g. cross_replace_steps=1), crashing on `bounds[0]` below.
    # Accept any real scalar; exclude bool to avoid masking caller mistakes.
    if isinstance(bounds, (int, float)) and not isinstance(bounds, bool):
        bounds = 0, bounds
    # Convert fractional bounds to absolute step indices.
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
|
|
|
def get_word_inds(text: str, word_place: int, tokenizer):
    """Locate the token indices that cover the given word(s) of ``text``.

    ``word_place`` may be a word string (every whitespace-split position that
    equals it is used), a single integer word position, or a pre-built list
    of positions. Returned indices are shifted by +1 to account for the
    tokenizer's leading special token.
    """
    words = text.split(" ")
    if type(word_place) is str:
        targets = [idx for idx, w in enumerate(words) if w == word_place]
    elif type(word_place) is int:
        targets = [word_place]
    else:
        targets = word_place
    token_inds = []
    if len(targets) > 0:
        # Decode every token, drop BOS/EOS, and strip subword markers so the
        # character counts line up with the plain words.
        pieces = [tokenizer.decode([tok]).strip("#") for tok in tokenizer.encode(text)][1:-1]
        consumed, word_ptr = 0, 0
        for tok_i, piece in enumerate(pieces):
            consumed += len(piece)
            if word_ptr in targets:
                token_inds.append(tok_i + 1)
            # Advance to the next word once its characters are exhausted.
            if consumed >= len(words[word_ptr]):
                word_ptr += 1
                consumed = 0
    return np.array(token_inds)
|
|
|
def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    """Build per-step, per-word cross-attention replacement weights.

    Produces a tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1,
    max_num_words)`` whose entry ``[t, p, 0, 0, w]`` is 1 when word ``w`` of
    edited prompt ``p`` should use replaced attention at step ``t``.

    ``cross_replace_steps`` is either a single window applied to all words,
    or a dict mapping a word (or ``"default_"``) to its own window.
    """
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    # Guarantee a fallback window covering every step.
    cross_replace_steps.setdefault("default_", (0., 1.))
    num_edits = len(prompts) - 1
    alphas = torch.zeros(num_steps + 1, num_edits, max_num_words)
    # Pass 1: apply the default window to every word of every edited prompt.
    for p in range(num_edits):
        alphas = update_alpha_time_word(alphas, cross_replace_steps["default_"], p)
    # Pass 2: per-word overrides (any key other than "default_").
    for word, window in cross_replace_steps.items():
        if word == "default_":
            continue
        per_prompt_inds = [get_word_inds(prompts[p], word, tokenizer)
                           for p in range(1, len(prompts))]
        for p, inds in enumerate(per_prompt_inds):
            if len(inds) > 0:
                alphas = update_alpha_time_word(alphas, window, p, inds)
    # Add singleton head/query axes for broadcasting over attention maps.
    return alphas.reshape(num_steps + 1, num_edits, 1, 1, max_num_words)
|
|
|
|