|
import os |
|
import json |
|
import torch |
|
import random |
|
import numpy as np |
|
|
|
# Reference palette for nearest-color lookup: color name -> RGB triplet in
# the 0-255 range. Insertion order is preserved and used by
# find_nearest_color, so keep the ordering stable.
COLORS = dict(
    brown=[165, 42, 42],
    red=[255, 0, 0],
    pink=[253, 108, 158],
    orange=[255, 165, 0],
    yellow=[255, 255, 0],
    purple=[128, 0, 128],
    green=[0, 128, 0],
    blue=[0, 0, 255],
    white=[255, 255, 255],
    gray=[128, 128, 128],
    black=[0, 0, 0],
)
|
|
|
|
|
def seed_everything(seed):
    r"""
    Seed every relevant RNG so runs are reproducible.

    Covers Python's `random`, NumPy, and PyTorch (CPU and CUDA).

    Args:
        seed (int): the seed value applied to all generators.
    """
    random.seed(seed)
    # NOTE: setting PYTHONHASHSEED here only affects interpreters spawned
    # after this point, not the hash randomization of the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Seed every CUDA device, not just the current one, so multi-GPU runs
    # are reproducible too. Safe no-op on CPU-only machines.
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def hex_to_rgb(hex_string, return_nearest_color=False):
    r"""
    Convert a hex triplet (e.g. '#ff0000' or the '#f00' shorthand) to a
    normalized RGB tensor of shape (1, 3, 1, 1) on the GPU.

    Args:
        hex_string (str): hex color, with or without a leading '#'.
        return_nearest_color (bool): when True, also return the name of the
            nearest palette color (see find_nearest_color).

    Returns:
        torch.FloatTensor on CUDA, or (tensor, nearest_color_name) when
        return_nearest_color is True.
    """
    hex_string = hex_string.lstrip('#')
    # Expand the CSS 3-digit shorthand ('f00' -> 'ff0000'); previously this
    # input crashed on int('', 16).
    if len(hex_string) == 3:
        hex_string = ''.join(ch * 2 for ch in hex_string)

    red = int(hex_string[0:2], 16)
    green = int(hex_string[2:4], 16)
    blue = int(hex_string[4:6], 16)
    rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
    if return_nearest_color:
        nearest_color = find_nearest_color(rgb)
        return rgb.cuda(), nearest_color
    return rgb.cuda()
|
|
|
|
|
def find_nearest_color(rgb):
    r"""
    Find the nearest neighbor color given the RGB value.
    """
    # Accept a raw 0-255 triplet and normalize it to the tensor layout
    # (1, 3, 1, 1) used everywhere else.
    if isinstance(rgb, (list, tuple)):
        rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.

    def distance_to(name):
        reference = torch.FloatTensor(COLORS[name])[None, :, None, None]/255.
        return np.linalg.norm(rgb - reference)

    # Ties resolve to the earliest palette entry, matching argmin's
    # first-minimum behavior in the original implementation.
    return min(COLORS.keys(), key=distance_to)
|
|
|
|
|
def font2style(font):
    r"""
    Convert the font name to the style name.
    """
    font_to_style = {
        'mirza': 'Claud Monet, impressionism, oil on canvas',
        'roboto': 'Ukiyoe',
        'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
        'sofia': 'Pop Art, masterpiece, andy warhol',
        'slabo': 'Vincent Van Gogh',
        'inconsolata': 'Pixel Art, 8 bits, 16 bits',
        'ubuntu': 'Rembrandt',
        'Monoton': 'neon art, colorful light, highly details, octane render',
        'Akronim': 'Abstract Cubism, Pablo Picasso',
    }
    # KeyError on an unknown font, identical to the original inline lookup.
    return font_to_style[font]
|
|
|
|
|
def parse_json(json_str):
    r"""
    Convert a Quill-style JSON delta (``{'ops': [...]}``) to attributes.

    Each span's formatting maps to a prompt attribute: font -> style region,
    link -> footnote, size/strike -> token weight, color -> gradient
    guidance target.

    Args:
        json_str (dict): parsed delta with an 'ops' list; each op has an
            'insert' string and an optional 'attributes' dict.

    Returns:
        Tuple of (base_text_prompt, style_text_prompts,
        footnote_text_prompts, footnote_target_tokens, color_text_prompts,
        color_names, color_rgbs, size_text_prompts_and_sizes,
        use_grad_guidance).
    """
    base_text_prompt = ''
    style_text_prompts = []
    footnote_text_prompts = []
    footnote_target_tokens = []
    color_text_prompts = []
    color_rgbs = []
    color_names = []
    size_text_prompts_and_sizes = []

    prev_style = None
    prev_color_rgb = None
    use_grad_guidance = False
    for span in json_str['ops']:
        text_prompt = span['insert'].rstrip('\n')
        base_text_prompt += text_prompt
        if text_prompt == ' ':
            continue
        if 'attributes' in span:
            attrs = span['attributes']
            if 'font' in attrs:
                style = font2style(attrs['font'])
                if prev_style == style:
                    # Merge consecutive spans that share a style into one
                    # region prompt.
                    prev_text_prompt = style_text_prompts[-1].split(
                        'in the style of')[0]
                    style_text_prompts[-1] = prev_text_prompt + \
                        ' ' + text_prompt + f' in the style of {style}'
                else:
                    style_text_prompts.append(
                        text_prompt + f' in the style of {style}')
                prev_style = style
            else:
                prev_style = None
            if 'link' in attrs:
                footnote_text_prompts.append(attrs['link'])
                footnote_target_tokens.append(text_prompt)
            # Font size scales token impact; strike-through negates it.
            # Plain spans keep the neutral weight of 1.
            font_size = 1
            if 'size' in attrs and 'strike' not in attrs:
                font_size = float(attrs['size'][:-2])/3.
            elif 'size' in attrs and 'strike' in attrs:
                font_size = -float(attrs['size'][:-2])/3.
            if 'color' in attrs:
                use_grad_guidance = True
                color_rgb, nearest_color = hex_to_rgb(attrs['color'], True)
                # BUGFIX: prev_color_rgb was never assigned, so the merge
                # branch for consecutive same-color spans could never run
                # (and tensor `==` would not yield a plain bool anyway).
                # Track the previous color and compare with torch.equal,
                # mirroring the prev_style handling above.
                if prev_color_rgb is not None and \
                        torch.equal(prev_color_rgb, color_rgb):
                    color_text_prompts[-1] = color_text_prompts[-1] + \
                        ' ' + text_prompt
                else:
                    color_rgbs.append(color_rgb)
                    color_names.append(nearest_color)
                    color_text_prompts.append(text_prompt)
                prev_color_rgb = color_rgb
            else:
                prev_color_rgb = None
            if font_size != 1:
                size_text_prompts_and_sizes.append([text_prompt, font_size])
    return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
        color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
|
|
|
|
|
def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
                               footnote_target_tokens, color_text_prompts, color_names):
    r"""
    Algorithm 1 in the paper.

    Build per-region prompts and, for each region, the 1-based positions of
    the base-prompt tokens it controls. A final catch-all region pairs the
    plain base prompt with every token not claimed by any other region.

    Returns:
        (region_text_prompts, region_target_token_ids, base_tokens) where
        region_target_token_ids is a list of LongTensors of 1-based token
        positions.
    """
    region_text_prompts = []
    region_target_token_ids = []
    base_tokens = model.tokenizer._tokenize(base_text_prompt)

    def _token_positions(text):
        # 1-based positions of `text`'s tokens within the base prompt.
        # NOTE(review): list.index returns the FIRST occurrence, so a token
        # repeated in the base prompt always maps to its first position —
        # confirm that is the intended behavior.
        return [base_tokens.index(tok) + 1
                for tok in model.tokenizer._tokenize(text)]

    # Style regions: the region prompt keeps the style suffix, but only the
    # text before 'in the style of' is matched against base tokens.
    for text_prompt in style_text_prompts:
        region_text_prompts.append(text_prompt)
        region_target_token_ids.append(
            _token_positions(text_prompt.split('in the style of')[0]))

    # Footnote regions: the linked footnote text becomes the region prompt;
    # the annotated span supplies the target tokens.
    for footnote_text_prompt, text_prompt in zip(footnote_text_prompts,
                                                 footnote_target_tokens):
        region_text_prompts.append(footnote_text_prompt)
        region_target_token_ids.append(_token_positions(text_prompt))

    # Color regions: prepend the nearest palette color name to the prompt.
    for color_text_prompt, color_name in zip(color_text_prompts, color_names):
        region_text_prompts.append(color_name + ' ' + color_text_prompt)
        region_target_token_ids.append(_token_positions(color_text_prompt))

    # Catch-all region: the base prompt controls every unclaimed token.
    # (Set membership instead of the original O(n^2) list scan.)
    region_text_prompts.append(base_text_prompt)
    claimed = {tok_id for ids in region_target_token_ids for tok_id in ids}
    region_target_token_ids.append(
        [tok_id for tok_id in range(1, len(base_tokens) + 1)
         if tok_id not in claimed])

    region_target_token_ids = [torch.LongTensor(ids)
                               for ids in region_target_token_ids]
    return region_text_prompts, region_target_token_ids, base_tokens
|
|
|
|
|
def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
    r"""
    Control the token impact using font sizes.
    """
    positions = []
    weights = []
    # Every token of a sized span gets that span's font-size weight, keyed
    # by its 1-based position in the base prompt.
    for prompt_text, size in size_text_prompts_and_sizes:
        for token in model.tokenizer._tokenize(prompt_text):
            positions.append(base_tokens.index(token) + 1)
            weights.append(size)

    if positions:
        word_pos = torch.LongTensor(positions).cuda()
        font_sizes = torch.FloatTensor(weights).cuda()
    else:
        # No sized spans: signal "no attention control" with None values.
        word_pos = None
        font_sizes = None

    return {'word_pos': word_pos, 'font_size': font_sizes}
|
|
|
|
|
def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
                                guidance_start_step=999, color_guidance_weight=1):
    r"""
    Build the gradient-guidance inputs for color control.

    (The previous docstring, "Control the token impact using font sizes.",
    was copy-pasted from get_attention_control_input.)

    Maps each colored prompt to the 1-based positions of its tokens in the
    base prompt, appends a catch-all group with every remaining token, and
    records the RGB targets plus the guidance schedule in text_format_dict.

    Args:
        model: object exposing ``tokenizer._tokenize``.
        base_tokens (list): tokenized base prompt.
        color_text_prompts (list[str]): text spans with a color attribute.
        color_rgbs (list): RGB target tensors, parallel to the prompts.
        text_format_dict (dict): mutated in place and also returned.
        guidance_start_step (int): denoising step at which guidance starts.
        color_guidance_weight (float): strength of the color guidance.

    Returns:
        (text_format_dict, color_target_token_ids) with token ids as
        LongTensors.
    """
    color_target_token_ids = []
    for text_prompt in color_text_prompts:
        token_ids = []
        for color_token in model.tokenizer._tokenize(text_prompt):
            # NOTE(review): list.index maps repeated tokens to their first
            # occurrence in the base prompt.
            token_ids.append(base_tokens.index(color_token) + 1)
        color_target_token_ids.append(token_ids)
    # Catch-all group: every base token not claimed by a colored span.
    claimed = [tid for ids in color_target_token_ids for tid in ids]
    rest = [tid for tid in range(1, len(base_tokens) + 1)
            if tid not in claimed]
    color_target_token_ids.append(rest)
    color_target_token_ids = [torch.LongTensor(ids)
                              for ids in color_target_token_ids]

    text_format_dict['target_RGB'] = color_rgbs
    text_format_dict['guidance_start_step'] = guidance_start_step
    text_format_dict['color_guidance_weight'] = color_guidance_weight
    return text_format_dict, color_target_token_ids
|
|