# Stand-In / prompters / omost.py
# Uploaded by fffiloni — migrated from GitHub (commit 26557da, verified).
from transformers import AutoTokenizer, TextIteratorStreamer
import difflib
import torch
import numpy as np
import re
from models.model_manager import ModelManager
from PIL import Image
# Colour names accepted for `HTML_web_color_name`, mapped to (r, g, b) values in 0-255.
# Canvas.set_global_description / add_local_description fuzzy-match the LLM's colour
# name against these keys with closest_name() and store the RGB value as a uint8 array.
# Includes both "gray"/"grey" spellings and a few aliases (e.g. "navyblue").
valid_colors = {  # r, g, b
    "aliceblue": (240, 248, 255),
    "antiquewhite": (250, 235, 215),
    "aqua": (0, 255, 255),
    "aquamarine": (127, 255, 212),
    "azure": (240, 255, 255),
    "beige": (245, 245, 220),
    "bisque": (255, 228, 196),
    "black": (0, 0, 0),
    "blanchedalmond": (255, 235, 205),
    "blue": (0, 0, 255),
    "blueviolet": (138, 43, 226),
    "brown": (165, 42, 42),
    "burlywood": (222, 184, 135),
    "cadetblue": (95, 158, 160),
    "chartreuse": (127, 255, 0),
    "chocolate": (210, 105, 30),
    "coral": (255, 127, 80),
    "cornflowerblue": (100, 149, 237),
    "cornsilk": (255, 248, 220),
    "crimson": (220, 20, 60),
    "cyan": (0, 255, 255),
    "darkblue": (0, 0, 139),
    "darkcyan": (0, 139, 139),
    "darkgoldenrod": (184, 134, 11),
    "darkgray": (169, 169, 169),
    "darkgrey": (169, 169, 169),
    "darkgreen": (0, 100, 0),
    "darkkhaki": (189, 183, 107),
    "darkmagenta": (139, 0, 139),
    "darkolivegreen": (85, 107, 47),
    "darkorange": (255, 140, 0),
    "darkorchid": (153, 50, 204),
    "darkred": (139, 0, 0),
    "darksalmon": (233, 150, 122),
    "darkseagreen": (143, 188, 143),
    "darkslateblue": (72, 61, 139),
    "darkslategray": (47, 79, 79),
    "darkslategrey": (47, 79, 79),
    "darkturquoise": (0, 206, 209),
    "darkviolet": (148, 0, 211),
    "deeppink": (255, 20, 147),
    "deepskyblue": (0, 191, 255),
    "dimgray": (105, 105, 105),
    "dimgrey": (105, 105, 105),
    "dodgerblue": (30, 144, 255),
    "firebrick": (178, 34, 34),
    "floralwhite": (255, 250, 240),
    "forestgreen": (34, 139, 34),
    "fuchsia": (255, 0, 255),
    "gainsboro": (220, 220, 220),
    "ghostwhite": (248, 248, 255),
    "gold": (255, 215, 0),
    "goldenrod": (218, 165, 32),
    "gray": (128, 128, 128),
    "grey": (128, 128, 128),
    "green": (0, 128, 0),
    "greenyellow": (173, 255, 47),
    "honeydew": (240, 255, 240),
    "hotpink": (255, 105, 180),
    "indianred": (205, 92, 92),
    "indigo": (75, 0, 130),
    "ivory": (255, 255, 240),
    "khaki": (240, 230, 140),
    "lavender": (230, 230, 250),
    "lavenderblush": (255, 240, 245),
    "lawngreen": (124, 252, 0),
    "lemonchiffon": (255, 250, 205),
    "lightblue": (173, 216, 230),
    "lightcoral": (240, 128, 128),
    "lightcyan": (224, 255, 255),
    "lightgoldenrodyellow": (250, 250, 210),
    "lightgray": (211, 211, 211),
    "lightgrey": (211, 211, 211),
    "lightgreen": (144, 238, 144),
    "lightpink": (255, 182, 193),
    "lightsalmon": (255, 160, 122),
    "lightseagreen": (32, 178, 170),
    "lightskyblue": (135, 206, 250),
    "lightslategray": (119, 136, 153),
    "lightslategrey": (119, 136, 153),
    "lightsteelblue": (176, 196, 222),
    "lightyellow": (255, 255, 224),
    "lime": (0, 255, 0),
    "limegreen": (50, 205, 50),
    "linen": (250, 240, 230),
    "magenta": (255, 0, 255),
    "maroon": (128, 0, 0),
    "mediumaquamarine": (102, 205, 170),
    "mediumblue": (0, 0, 205),
    "mediumorchid": (186, 85, 211),
    "mediumpurple": (147, 112, 219),
    "mediumseagreen": (60, 179, 113),
    "mediumslateblue": (123, 104, 238),
    "mediumspringgreen": (0, 250, 154),
    "mediumturquoise": (72, 209, 204),
    "mediumvioletred": (199, 21, 133),
    "midnightblue": (25, 25, 112),
    "mintcream": (245, 255, 250),
    "mistyrose": (255, 228, 225),
    "moccasin": (255, 228, 181),
    "navajowhite": (255, 222, 173),
    "navy": (0, 0, 128),
    "navyblue": (0, 0, 128),
    "oldlace": (253, 245, 230),
    "olive": (128, 128, 0),
    "olivedrab": (107, 142, 35),
    "orange": (255, 165, 0),
    "orangered": (255, 69, 0),
    "orchid": (218, 112, 214),
    "palegoldenrod": (238, 232, 170),
    "palegreen": (152, 251, 152),
    "paleturquoise": (175, 238, 238),
    "palevioletred": (219, 112, 147),
    "papayawhip": (255, 239, 213),
    "peachpuff": (255, 218, 185),
    "peru": (205, 133, 63),
    "pink": (255, 192, 203),
    "plum": (221, 160, 221),
    "powderblue": (176, 224, 230),
    "purple": (128, 0, 128),
    "rebeccapurple": (102, 51, 153),
    "red": (255, 0, 0),
    "rosybrown": (188, 143, 143),
    "royalblue": (65, 105, 225),
    "saddlebrown": (139, 69, 19),
    "salmon": (250, 128, 114),
    "sandybrown": (244, 164, 96),
    "seagreen": (46, 139, 87),
    "seashell": (255, 245, 238),
    "sienna": (160, 82, 45),
    "silver": (192, 192, 192),
    "skyblue": (135, 206, 235),
    "slateblue": (106, 90, 205),
    "slategray": (112, 128, 144),
    "slategrey": (112, 128, 144),
    "snow": (255, 250, 250),
    "springgreen": (0, 255, 127),
    "steelblue": (70, 130, 180),
    "tan": (210, 180, 140),
    "teal": (0, 128, 128),
    "thistle": (216, 191, 216),
    "tomato": (255, 99, 71),
    "turquoise": (64, 224, 208),
    "violet": (238, 130, 238),
    "wheat": (245, 222, 179),
    "white": (255, 255, 255),
    "whitesmoke": (245, 245, 245),
    "yellow": (255, 255, 0),
    "yellowgreen": (154, 205, 50),
}
# Anchor point (x, y) of each named placement on the 90x90 layout grid.
# add_local_description centres the region rectangle on this point
# (after applying the chosen entry of valid_offsets).
valid_locations = dict(
    [
        ("in the center", (45, 45)),
        ("on the left", (15, 45)),
        ("on the right", (75, 45)),
        ("on the top", (45, 15)),
        ("on the bottom", (45, 75)),
        ("on the top-left", (15, 15)),
        ("on the top-right", (75, 15)),
        ("on the bottom-left", (15, 75)),
        ("on the bottom-right", (75, 75)),
    ]
)
# Shift (dx, dy), in grid cells of the 90x90 layout, added to a location's
# anchor point. Negative dy moves the region up, negative dx moves it left.
valid_offsets = dict(
    [
        ("no offset", (0, 0)),
        ("slightly to the left", (-10, 0)),
        ("slightly to the right", (10, 0)),
        ("slightly to the upper", (0, -10)),
        ("slightly to the lower", (0, 10)),
        ("slightly to the upper-left", (-10, -10)),
        ("slightly to the upper-right", (10, -10)),
        ("slightly to the lower-left", (-10, 10)),
        ("slightly to the lower-right", (10, 10)),
    ]
)
# Size (width, height) of a region rectangle, in grid cells of the 90x90 layout.
valid_areas = dict(
    [
        ("a small square area", (50, 50)),
        ("a small vertical area", (40, 60)),
        ("a small horizontal area", (60, 40)),
        ("a medium-sized square area", (60, 60)),
        ("a medium-sized vertical area", (50, 80)),
        ("a medium-sized horizontal area", (80, 50)),
        ("a large square area", (70, 70)),
        ("a large vertical area", (60, 90)),
        ("a large horizontal area", (90, 60)),
    ]
)
def safe_str(x):
    """Return *x* trimmed of surrounding commas, periods and spaces, ending in one period.

    An input made only of those characters (or empty) becomes ``"."``.
    """
    trimmed = x.strip(",. ")
    return f"{trimmed}."
def closest_name(input_str, options):
    """Return the key of *options* that best fuzzy-matches *input_str*.

    The input is lower-cased and compared with ``difflib.get_close_matches``
    (similarity cutoff 0.5). When the best key differs from the lower-cased
    input, a correction notice is printed.

    Args:
        input_str: candidate name, typically produced by the LLM.
        options: mapping whose keys are the accepted names.

    Returns:
        The matching key of *options*.

    Raises:
        AssertionError: if no key is similar enough. It is raised explicitly
            rather than via ``assert`` so the check still runs under ``python -O``.
    """
    input_str = input_str.lower()
    closest_match = difflib.get_close_matches(input_str, list(options), n=1, cutoff=0.5)
    if not closest_match:
        raise AssertionError(f"The value [{input_str}] is not valid!")
    result = closest_match[0]
    if result != input_str:
        print(f"Automatically corrected [{input_str}] -> [{result}].")
    return result
class Canvas:
    """Regional image layout written by the Omost LLM as Python code.

    The LLM's code calls ``set_global_description`` once and
    ``add_local_description`` per region. ``process`` turns the result into an
    initial 90x90 colour latent plus a list of masked text conditions.

    Validation failures raise ``AssertionError`` explicitly (instead of through
    ``assert`` statements), so they are still reported under ``python -O``.
    """

    @staticmethod
    def from_bot_response(response: str):
        """Build a Canvas by executing the first ```python block in *response*.

        Raises:
            AssertionError: if the response has no python code block, the code
                lacks ``canvas = Canvas()``, or it does not leave a ``Canvas``
                instance in the variable ``canvas``.
        """
        matched = re.search(r"```python\n(.*?)\n```", response, re.DOTALL)
        if not matched:
            raise AssertionError("Response does not contain codes!")
        code_content = matched.group(1)
        if "canvas = Canvas()" not in code_content:
            raise AssertionError("Code block must include valid canvas var!")
        # One dict serves as both globals and locals. With separate dicts, exec
        # runs the code as if inside a class body, so any function the code
        # defines could not see `canvas` or other top-level names.
        # SECURITY(review): this executes model-generated code with full
        # builtins available. Only use it with a trusted model, or sandbox it.
        namespace = {"Canvas": Canvas}
        exec(code_content, namespace)
        canvas = namespace.get("canvas", None)
        if not isinstance(canvas, Canvas):
            raise AssertionError("Code block must produce valid canvas var!")
        return canvas

    def __init__(self):
        # Region dicts appended by add_local_description.
        self.components = []
        # Background colour array; None until set_global_description is called.
        self.color = None
        # When True, tags (and, for regions, atmosphere/style/quality_meta)
        # are appended to the suffixes.
        self.record_tags = True
        self.prefixes = []
        self.suffixes = []

    def set_global_description(
        self,
        description: str,
        detailed_descriptions: list,
        tags: str,
        HTML_web_color_name: str,
    ):
        """Set the whole-image description and the background colour.

        ``HTML_web_color_name`` is fuzzy-matched against ``valid_colors``.
        Every text fragment is normalised with ``safe_str``.

        Raises:
            AssertionError: if an argument has the wrong type or the colour
                name matches nothing.
        """
        if not isinstance(description, str):
            raise AssertionError("Global description is not valid!")
        if not (
            isinstance(detailed_descriptions, list)
            and all(isinstance(item, str) for item in detailed_descriptions)
        ):
            raise AssertionError("Global detailed_descriptions is not valid!")
        if not isinstance(tags, str):
            raise AssertionError("Global tags is not valid!")
        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
        # Shape (1, 1, 3), so it broadcasts over the (90, 90, 3) latent in process().
        self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
        self.prefixes = [description]
        self.suffixes = detailed_descriptions
        if self.record_tags:
            self.suffixes = self.suffixes + [tags]
        self.prefixes = [safe_str(x) for x in self.prefixes]
        self.suffixes = [safe_str(x) for x in self.suffixes]

    def add_local_description(
        self,
        location: str,
        offset: str,
        area: str,
        distance_to_viewer: float,
        description: str,
        detailed_descriptions: list,
        tags: str,
        atmosphere: str,
        style: str,
        quality_meta: str,
        HTML_web_color_name: str,
    ):
        """Add one region to the canvas.

        ``location``, ``offset``, ``area`` and ``HTML_web_color_name`` are
        fuzzy-matched against their ``valid_*`` tables. The region's prompt
        prefixes are the global prefixes followed by ``description``.

        Raises:
            AssertionError: if an argument has the wrong type,
                ``distance_to_viewer`` is not positive, or a name matches nothing.
        """
        if not isinstance(description, str):
            raise AssertionError("Local description is wrong!")
        if not (
            isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0
        ):
            raise AssertionError(
                f"The distance_to_viewer for [{description}] is not positive float number!"
            )
        if not (
            isinstance(detailed_descriptions, list)
            and all(isinstance(item, str) for item in detailed_descriptions)
        ):
            raise AssertionError(
                f"The detailed_descriptions for [{description}] is not valid!"
            )
        if not isinstance(tags, str):
            raise AssertionError(f"The tags for [{description}] is not valid!")
        if not isinstance(atmosphere, str):
            raise AssertionError(f"The atmosphere for [{description}] is not valid!")
        if not isinstance(style, str):
            raise AssertionError(f"The style for [{description}] is not valid!")
        if not isinstance(quality_meta, str):
            raise AssertionError(
                f"The quality_meta for [{description}] is not valid!"
            )
        location = closest_name(location, valid_locations)
        offset = closest_name(offset, valid_offsets)
        area = closest_name(area, valid_areas)
        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
        xb, yb = valid_locations[location]
        xo, yo = valid_offsets[offset]
        w, h = valid_areas[area]
        # (row_start, row_end, col_start, col_end) on the 90x90 grid, centred on
        # the offset anchor and clipped to the grid bounds.
        rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
        rect = [max(0, min(90, i)) for i in rect]
        color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
        prefixes = self.prefixes + [description]
        suffixes = detailed_descriptions
        if self.record_tags:
            suffixes = suffixes + [tags, atmosphere, style, quality_meta]
        prefixes = [safe_str(x) for x in prefixes]
        suffixes = [safe_str(x) for x in suffixes]
        self.components.append(
            dict(
                rect=rect,
                distance_to_viewer=distance_to_viewer,
                color=color,
                prefixes=prefixes,
                suffixes=suffixes,
                location=location,
            )
        )

    def process(self):
        """Render the canvas into an initial latent and masked conditions.

        Components are re-sorted in place by descending ``distance_to_viewer``.
        They are then painted in that order onto a 90x90 RGB image filled with
        the global colour, each blended as 70% component colour and 30% of what
        is already there.

        Returns:
            dict with ``initial_latent`` (90x90x3 uint8) and
            ``bag_of_conditions``. The first condition is global (all-ones
            mask, location ``"full"``), followed by one condition per component
            with a 0/1 float32 mask over its rect.

        Raises:
            AssertionError: if ``set_global_description`` was never called.
        """
        if self.color is None:
            raise AssertionError("Global description must be set before process()!")
        self.components = sorted(
            self.components, key=lambda x: x["distance_to_viewer"], reverse=True
        )
        initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
        for component in self.components:
            top, bottom, left, right = component["rect"]
            initial_latent[top:bottom, left:right] = (
                0.7 * component["color"] + 0.3 * initial_latent[top:bottom, left:right]
            )
        initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
        bag_of_conditions = [
            dict(
                mask=np.ones(shape=(90, 90), dtype=np.float32),
                prefixes=self.prefixes,
                suffixes=self.suffixes,
                location="full",
            )
        ]
        for component in self.components:
            top, bottom, left, right = component["rect"]
            mask = np.zeros(shape=(90, 90), dtype=np.float32)
            mask[top:bottom, left:right] = 1.0
            bag_of_conditions.append(
                dict(
                    mask=mask,
                    prefixes=component["prefixes"],
                    suffixes=component["suffixes"],
                    location=component["location"],
                )
            )
        return dict(
            initial_latent=initial_latent,
            bag_of_conditions=bag_of_conditions,
        )
class OmostPromter(torch.nn.Module):
    """Expand a text prompt into a regional layout using an Omost LLM.

    ``__call__`` prompts the LLM with ``template`` plus the user prompt, parses
    the returned Canvas code with ``Canvas.from_bot_response`` and adds a global
    prompt, per-region prompts and PIL masks to the prompt dict.
    """

    def __init__(self, model=None, tokenizer=None, template="", device="cpu"):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        if template == "":
            # NOTE(review): in this copy the example class inside the system
            # prompt has no indentation or blank lines; confirm against the
            # upstream Omost system prompt.
            template = r"""You are a helpful AI assistant to compose images using the below python class `Canvas`:
```python
class Canvas:
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
pass
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
assert distance_to_viewer > 0
pass
```"""
        self.template = template

    @staticmethod
    def from_model_manager(model_manager: ModelManager):
        """Build an OmostPromter from the manager's "omost_prompt" model.

        The tokenizer is loaded from that model's path and the manager's
        device is reused.
        """
        model, model_path = model_manager.fetch_model(
            "omost_prompt", require_model_path=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return OmostPromter(
            model=model, tokenizer=tokenizer, device=model_manager.device
        )

    def __call__(self, prompt_dict: dict):
        """Extend ``prompt_dict["prompt"]`` with an Omost regional layout.

        Updates *prompt_dict* in place and returns it. It receives the output
        of ``Canvas.process()`` plus:

        - ``prompt``: the global prompt;
        - ``prompts``: one prompt per region;
        - ``masks``: RGB PIL images (global first) that are 255 inside each
          region and 0 elsewhere.

        Each prompt joins a condition's prefixes with its first two suffixes.
        Errors from ``Canvas.from_bot_response`` (e.g. no code block in the
        LLM output) propagate to the caller.
        """
        raw_prompt = prompt_dict["prompt"]
        conversation = [
            {"role": "system", "content": self.template},
            {"role": "user", "content": raw_prompt},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        ).to(self.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        # transformers documents attention_mask as a 0/1 integer tensor; derive
        # it from input_ids so dtype (long) and device match.
        attention_mask = torch.ones_like(input_ids)
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            do_sample=True,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # NOTE(review): no max_new_tokens is passed, so the length limit comes
        # from the model's generation config. A too-short limit would truncate
        # the code block and make Canvas.from_bot_response fail.
        # generate() runs synchronously here, so by the time it returns the
        # streamer has queued all decoded text; joining it drains that queue.
        self.model.generate(**generate_kwargs)
        llm_outputs = "".join(streamer)

        canvas = Canvas.from_bot_response(llm_outputs)
        canvas_output = canvas.process()
        bag_of_conditions = canvas_output["bag_of_conditions"]
        prompts = [
            " ".join(condition["prefixes"] + condition["suffixes"][:2])
            for condition in bag_of_conditions
        ]
        canvas_output["prompt"] = prompts[0]
        canvas_output["prompts"] = prompts[1:]

        masks = []
        for condition in bag_of_conditions:
            # Build a new array instead of writing 255 into condition["mask"].
            # Writing in place would rescale the 0/1 masks that are also
            # returned in bag_of_conditions.
            binary = np.where(condition["mask"] > 0.5, 255, 0).astype(np.uint8)
            masks.append(Image.fromarray(np.stack([binary] * 3, axis=-1)))
        canvas_output["masks"] = masks
        prompt_dict.update(canvas_output)

        print("Your prompt is extended by Omost:\n")
        for index, (condition, prompt) in enumerate(
            zip(bag_of_conditions, prompts), start=1
        ):
            loc = condition["location"]
            print(f"Component {index} - Location : {loc}\nPrompt:{prompt}\n")
        return prompt_dict