Spaces:
Runtime error
Runtime error
import argparse | |
import datetime | |
import json | |
import os | |
import re | |
import time | |
import gradio as gr | |
import requests | |
from llava.conversation import (default_conversation, conv_templates, | |
SeparatorStyle) | |
from llava.constants import LOGDIR | |
from llava.utils import (build_logger, server_error_msg, | |
violates_moderation, moderation_msg) | |
import hashlib | |
import numpy as np | |
from PIL import Image, ImageDraw | |
from copy import deepcopy | |
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers attached to the streaming generation requests sent to workers.
headers = {"User-Agent": "LLaVA Client"}

# Reusable button updates returned by the event handlers below.
no_change_btn, enable_btn, disable_btn = (
    gr.Button(),
    gr.Button(interactive=True),
    gr.Button(interactive=False),
)

# Explicit sort keys for get_model_list(); unlisted models sort by their own name.
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}
# Layout-generation prompt shared by all three dataset tabs. The *_add_text
# handlers remove the newline right after "<image>", turn every remaining
# newline into a literal backslash-n sequence, and then fill the placeholders:
#   {N}            number of elements to place
#   {resolution}   background size as [width, height] (from PIL's image.size)
#   {domain_name}  dataset-specific design domain phrase
#   {cons_data}    free-text user constraints
#   {json_data}    JSON list of {"label": ..., "box": []} placeholders
# NOTE(review): the wording (including "3 additional user requirements") is
# presumably what the model was fine-tuned on; change it only with that in mind.
prompt_template = '''
<image>
Hello! Could you please help me to place {N} foreground elements over the background image of resolution {resolution} to craft an aesthetically pleasing, harmonious, balanced, and visually appealing {domain_name}?
Finding semantic-meaningful objects or visual foci on the background image at first might help in designing, and you should avoid any unnecessary blocking of them.
For each layout, there are 3 additional user requirements and you are expected to generate a layout corresponding to them. Here is the user requirements: {cons_data}
Please return the result by completing the following JSON file. Each element's location and size should be represented by a bounding box described as [left, top, right, bottom], and each number is a continuous digit from 0 to 1.
Here is the initial JSON file: {json_data}
'''
# Element labels per dataset. The order matters: init_json pairs the i-th label
# with the i-th count, and the *_add_text handlers build their count lists in
# exactly this order.
ELEM_CLASSES = {
    "QB-Poster": [
        "title",
        "decoration",
        "subtitle",
        "itemtitle",
        "itemlogo",
        "item",
        "text",
        "textbackground",
        "object",
        "frame",
    ],
    "CGL": [
        "text",
        "underlay",
        "embellishment",
    ],
    "Ad Banners": [
        "header",
        "preheader",
        "postheader",
        "body text",
        "disclaimer / footnote",
        "button",
        "callout",
        "logo",
    ],
}
# Box color per element label for each dataset; values are color names handed
# to PIL's ImageDraw when drawing layouts.
# NOTE(review): the "false" entries presumably color unrecognized labels — confirm.
CLS2COLOR = {
    "QB-Poster": {
        "title": "red",
        "subtitle": "green",
        "itemlogo": "orange",
        "item": "blue",
        "itemtitle": "yellow",
        "object": "purple",
        "textbackground": "pink",
        "decoration": "brown",
        "frame": "gray",
        "text": "cyan",
        "false": "black",
    },
    "CGL": {
        "text": "red",
        "underlay": "green",
        "embellishment": "blue",
        "false": "black",
    },
    "Ad Banners": {
        "header": "red",
        "preheader": "green",
        "postheader": "blue",
        "body text": "orange",
        "disclaimer / footnote": "purple",
        "button": "pink",
        "callout": "brown",
        "logo": "gray",
        "false": "black",
    },
}
def get_json_response(response):
    """Extract and parse the JSON layout list embedded in a model response.

    The candidate span runs from the last ``[{`` to the end of the last ``}]``
    in *response*. Single quotes are turned into double quotes before parsing
    so Python-style quoting emitted by the model still parses.

    Args:
        response: Raw text generated by the model.

    Returns:
        The parsed JSON value (a list when parsing succeeds), or ``None`` when
        no ``[{ ... }]`` span exists or the span is not valid JSON.
    """
    lo = response.rfind("[{")
    hi = response.rfind("}]")
    if lo == -1 or hi == -1 or hi < lo:
        return None
    candidate = response[lo:hi + 2].replace("'", '"')
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None
def draw_box(img, elems, elems2, cls2color):
    """Overlay labeled boxes on *img*: solid outlines plus a faint color fill.

    Args:
        img: PIL image to draw on. It is not modified; copies are drawn on.
        elems: Iterable of ``(label, [left, top, right, bottom])`` pairs whose
            coordinates are fractions of the image width/height.
        elems2: Unused; kept so existing callers keep working.
        cls2color: Mapping from label to a PIL color. Labels missing from it
            are drawn with its ``"false"`` entry (``"black"`` if that is absent)
            instead of raising ``KeyError``.

    Returns:
        A new RGBA PIL image.
    """
    W, H = img.size
    drawn_outline = img.copy()
    drawn_fill = img.copy()
    draw_ol = ImageDraw.Draw(drawn_outline)
    draw_f = ImageDraw.Draw(drawn_fill)
    # Outline width scales with the canvas: 10 px at 1242x1660, never below 1 px.
    line_width = max(10 * (W + H) // (1242 + 1660), 1)
    fallback_color = cls2color.get("false", "black")
    for cls, box in elems:
        color = cls2color.get(cls, fallback_color)
        left, top, right, bottom = box
        # Order the corners: Pillow requires x1 >= x0 and y1 >= y0, and model
        # output may come back inverted.
        x0, x1 = sorted((int(left * W), int(right * W)))
        y0, y1 = sorted((int(top * H), int(bottom * H)))
        _box = (x0, y0, x1, y1)
        draw_ol.rectangle(_box, fill=None, outline=color, width=line_width)
        draw_f.rectangle(_box, fill=color)
    drawn_outline = drawn_outline.convert("RGBA")
    drawn_fill = drawn_fill.convert("RGBA")
    # Make the whole filled copy ~10% opaque, so compositing it over the
    # outlined copy only lightly tints the box interiors.
    drawn_fill.putalpha(int(256 * 0.1))
    return Image.alpha_composite(drawn_outline, drawn_fill)
def draw_boxmap(json_response, background_image, cls2color):
    """Render a parsed layout prediction over its background image.

    Args:
        json_response: Parsed model output, expected to be a list of dicts
            with a string ``"label"`` and a 4-number ``"box"``
            ``[left, top, right, bottom]`` in fractions of the image size.
            Entries that do not match this shape are skipped.
        background_image: PIL image the layout was generated for.
        cls2color: Mapping from label to a PIL color.

    Returns:
        An RGB PIL image with the boxes drawn.
    """
    # Draw on an RGB copy of the background so every input mode is handled the same.
    pic = background_image.convert("RGB")
    cls_box = []
    for elem in json_response:
        # Model output is untrusted: skip malformed entries rather than failing the request.
        if not isinstance(elem, dict) or not isinstance(elem.get("label"), str):
            continue
        box = elem.get("box")
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        if not all(isinstance(v, (int, float)) for v in box):
            continue
        cls_box.append((elem["label"], box))
    logger.info(f"boxes to draw: {cls_box}")
    drawn = draw_box(pic, cls_box, cls_box, cls2color)
    return drawn.convert("RGB")
def get_conv_log_filename():
    """Return the path of today's conversation log: ``LOGDIR/YYYY-MM-DD-conv.json``."""
    today = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{today:%Y-%m-%d}-conv.json")
def get_model_list():
    """Ask the controller to refresh its workers, then return the model names.

    Reads the controller URL from the module-level ``args``. Names are sorted
    by their ``priority`` entry, falling back to the name itself.

    Returns:
        List of model-name strings.

    Raises:
        requests.HTTPError: If either controller call returns an error status.
    """
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    # raise_for_status() instead of assert: asserts are stripped under `python -O`.
    ret.raise_for_status()
    ret = requests.post(args.controller_url + "/list_models")
    ret.raise_for_status()
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
# JavaScript run in the browser when the page loads (passed to demo.load as
# `_js` in "once" mode). It returns the page's URL query parameters as a plain
# object, which arrives in load_demo as `url_params`.
# NOTE(review): `url_params` is assigned without const/let, creating an implicit
# global; harmless in sloppy mode but would throw under strict mode.
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Initialize a session on page load ("once" model-list mode).

    If the URL carries ``?model=<name>`` and that name is in the global
    ``models`` list, it is preselected in the model dropdown.

    Returns:
        ``(fresh_conversation, dropdown_update)``.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    if "model" in url_params and url_params["model"] in models:
        dropdown_update = gr.Dropdown(value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.Dropdown(visible=True)
    return default_conversation.copy(), dropdown_update
def load_demo_refresh_model_list(request: gr.Request):
    """Initialize a session on page load ("reload" model-list mode).

    Re-queries the controller for models and repopulates the dropdown,
    selecting the first model (or "" when none are available).

    Returns:
        ``(fresh_conversation, dropdown_update)``.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    conversation = default_conversation.copy()
    first_model = available[0] if available else ""
    return conversation, gr.Dropdown(choices=available, value=first_model)
def init_json(elem_list, dataset):
    """Build the placeholder element list that is embedded in the prompt.

    Args:
        elem_list: One count per label of ``ELEM_CLASSES[dataset]``, in the
            same order. Each count is truncated with ``int()``, so bools and
            floats coming from the UI widgets are accepted.
        dataset: Key of ``ELEM_CLASSES``.

    Returns:
        A list with ``count`` fresh ``{"label": label, "box": []}`` dicts per
        label, grouped in label order.
    """
    return [
        {"label": label, "box": []}
        for position, label in enumerate(ELEM_CLASSES[dataset])
        for _ in range(int(elem_list[position]))
    ]
def init_conv(request: gr.Request):
    """Reset the conversation (fired when the background image changes).

    Returns:
        ``(fresh_conversation, chatbot_messages, btn, btn, btn)`` with all three
        generate buttons enabled.
    """
    logger.info(f"init_conversation. ip: {request.client.host}")
    conversation = default_conversation.copy()
    buttons = (enable_btn, enable_btn, enable_btn)
    return (conversation, conversation.to_gradio_chatbot(), *buttons)
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append a user turn (optionally carrying an image) and an empty bot turn.

    With an image, the user message becomes ``(text, image, image_process_mode)``.
    If *state* already holds an image, a fresh conversation is started first,
    so only the new image is part of the request.

    Returns:
        ``(state, chatbot_messages, btn, btn, btn)`` with all three generate
        buttons disabled while the request runs.
    """
    logger.info(f"add_text. ip: {request.client.host}.")
    message = text
    if image is not None:
        message = (text, image, image_process_mode)
        already_has_image = len(state.get_images(return_pil=True)) > 0
        if already_has_image:
            state = default_conversation.copy()
    user_role, bot_role = state.roles[0], state.roles[1]
    state.append_message(user_role, message)
    state.append_message(bot_role, None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn)
def _compose_layout_prompt(state, image, json_data, user_cons, domain_name, default_resolution):
    """Fill ``prompt_template`` for one layout-generation request.

    The resolution comes from the uploaded *image*, else from the last image
    already in *state*, else from *default_resolution*.
    """
    if image is not None:
        resolution = list(image.size)
    else:
        try:
            resolution = list(state.get_images(return_pil=True)[-1].size)
        except Exception:
            # Best effort: no usable image in the conversation (or no state yet).
            resolution = list(default_resolution)
    # Drop the newline after <image>, then turn the remaining newlines into
    # literal backslash-n sequences before the placeholders are filled.
    template = prompt_template.replace('<image>\n', '<image>').replace('\n', '\\n')
    return template.format(
        N=len(json_data),
        resolution=resolution,
        domain_name=domain_name,
        cons_data=user_cons,
        json_data=json.dumps(json_data),
    )


def qb_add_text(state, title_num, decoration_num, subtitle_num, itemtitle_num, itemlogo_num, item_num, text_num,
                textbackground_num, object_num, frame_num, image, user_cons, image_process_mode, request: gr.Request):
    """Build a QB-Poster prompt from the element counts and append it as a user turn."""
    elem_list = [title_num, decoration_num, subtitle_num, itemtitle_num, itemlogo_num,
                 item_num, text_num, textbackground_num, object_num, frame_num]
    json_data = init_json(elem_list, dataset='QB-Poster')
    # "xiaohonshu" (sic) is kept verbatim: http_bot identifies the QB-Poster
    # dataset by this exact substring of the prompt.
    text = _compose_layout_prompt(state, image, json_data, user_cons,
                                  "poster with xiaohonshu style", [1242, 1660])
    return add_text(state, text, image, image_process_mode, request)


def cgl_add_text(state, text_num, underlay_num, embellishment_num, image, user_cons, image_process_mode, request: gr.Request):
    """Build a CGL / PosterLayout prompt from the element counts and append it as a user turn."""
    elem_list = [text_num, underlay_num, embellishment_num]
    json_data = init_json(elem_list, dataset='CGL')
    text = _compose_layout_prompt(state, image, json_data, user_cons,
                                  "commercial poster", [513, 750])
    return add_text(state, text, image, image_process_mode, request)


def banners_add_text(state, header_num, preheader_num, postheader_num, body_text, disclaimer_num, button_num,
                     callout_num, logo_num, image, user_cons, image_process_mode, request: gr.Request):
    """Build an Ad Banners prompt from the element counts and append it as a user turn."""
    elem_list = [header_num, preheader_num, postheader_num, body_text, disclaimer_num, button_num, callout_num, logo_num]
    json_data = init_json(elem_list, dataset='Ad Banners')
    text = _compose_layout_prompt(state, image, json_data, user_cons,
                                  "commercial banner", [1080, 1080])
    return add_text(state, text, image, image_process_mode, request)
def _detect_dataset(prompt):
    """Map a layout prompt back to its dataset key, or ``None`` if unrecognized.

    Matches the ``domain_name`` phrases the ``*_add_text`` handlers put into the
    prompt ("xiaohonshu" is spelled exactly as they spell it).
    """
    for marker, dataset in (
        ("xiaohonshu", "QB-Poster"),
        ("commercial poster", "CGL"),
        ("commercial banner", "Ad Banners"),
    ):
        if marker in prompt:
            return dataset
    return None


def _save_images(images, image_hashes):
    """Save each image once to ``LOGDIR/serve_images/<YYYY-MM-DD>/<md5>.jpg``."""
    for image, image_hash in zip(images, image_hashes):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)


def _estimate_response_length(prompt):
    """Roughly estimate one response's length in characters (used only for progress).

    Adds the initial JSON placeholder embedded in *prompt*, four printed
    coordinates per element, and the expected reply preamble. Always > 0.
    """
    match = re.search(r'\[\{.*?\}\]', prompt)
    initial_json = match.group(0) if match else ""  # "[]" prompts (zero elements) have no match
    elems_num = initial_json.count("label")
    return (elems_num * len('0.0000, 0.0000, 0.0000, 0.0000') + len(initial_json)
            + len('Sure! Here is the design results: '))


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, repeat_times, request: gr.Request, progress=gr.Progress()):
    """Generate ``repeat_times`` layouts from the model worker and stream them to the UI.

    Each finished response is parsed with ``get_json_response``; when it parses,
    a box map is drawn over the last conversation image and added to the gallery.

    Yields:
        Tuples ``(state, chatbot_messages, boxmaps_or_None, btn, btn, btn)``
        matching the outputs ``[state, chatbot, imagebox_boxmap] + btn_list``.
        The generate buttons stay disabled until the run finishes or fails.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        yield (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 3
        return

    if len(state.messages) == state.offset + 2:
        # First round: rebuild the conversation on the "llava_v1" template,
        # carrying over the user turn.
        template_name = "llava_v1"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None, disable_btn, disable_btn, disable_btn)
        return

    prompt = state.get_prompt()
    current_dataset = _detect_dataset(prompt)

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    _save_images(all_images, all_image_hash)

    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new_tokens),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")
    # Log only the hashes above; send the actual images to the worker.
    pload['images'] = state.get_images()

    # Slider values may arrive as floats; range() needs an int.
    repeat_times = int(repeat_times)
    expected_total = _estimate_response_length(prompt) * max(repeat_times, 1)
    boxmaps, all_responses = [], []
    output = ""
    for t in range(repeat_times):
        # Reset per repetition so an empty stream cannot re-append the previous result.
        output = ""
        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     headers=headers, json=pload, stream=True, timeout=20)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                data = json.loads(chunk.decode())
                if data["error_code"] != 0:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot(), boxmaps) + (enable_btn,) * 3
                    return
                # The worker streams prompt + generated text; keep the generated part.
                output = data["text"][len(prompt):].strip()
                time.sleep(0.01)
                generated = len(output) + sum(len(r) for r in all_responses)
                progress(min(generated / expected_total, 1.0),
                         desc=f'Generating output {t + 1} of {repeat_times}...')
        except requests.exceptions.RequestException:
            state.messages[-1][-1] = server_error_msg
            yield (state, state.to_gradio_chatbot(), boxmaps) + (enable_btn,) * 3
            return

        all_responses.append(output)
        json_response = get_json_response(output)
        # Drawing needs a background image and a known color scheme.
        if json_response is not None and all_images and current_dataset is not None:
            boxmaps.append(draw_boxmap(json_response, all_images[-1], CLS2COLOR[current_dataset]))
        state.messages[-1][-1] = "".join(
            f"Design Result {i}:\n" + resp + "\n\n" for i, resp in enumerate(all_responses))
        # Show intermediate results but keep the buttons disabled mid-run.
        yield (state, state.to_gradio_chatbot(), boxmaps) + (disable_btn,) * 3

    finish_tstamp = time.time()
    yield (state, state.to_gradio_chatbot(), boxmaps) + (enable_btn,) * 3

    logger.info(f"{output}")
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
# Markdown header rendered at the top of the demo unless embed mode is on.
title_markdown = ("""
# PosterLLaVA: Constructing a Unified Multi-modal Layout Generator with LLM
""")
# Footer markdown (shown unless embed mode is on); currently empty placeholders.
tos_markdown = ("""
""")
learn_more_markdown = ("""
""")
# Extra CSS passed to gr.Blocks: minimum width for buttons inside #buttons.
# NOTE(review): no component in build_demo sets elem_id="buttons", so this rule
# appears to have no effect at the moment.
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None):
    """Build the PosterLLaVA Gradio UI and register its event handlers.

    Args:
        embed_mode: When truthy, the title and footer markdown are hidden.
        cur_dir: Unused; optional so both ``build_demo(embed)`` and
            ``build_demo(embed, cur_dir)`` work.

    Returns:
        The ``gr.Blocks`` demo.

    Raises:
        ValueError: If ``args.model_list_mode`` is neither "once" nor "reload".
    """
    # Created outside the Blocks context and placed below with .render().
    imagebox_boxmap = gr.Gallery(label='Result(结果)', show_label=True, preview=False, columns=2, allow_preview=True, height=550)

    with gr.Blocks(title="PosterLLaVA", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                with gr.Accordion("Parameters", open=True):
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=1024, step=64, interactive=True, label="Max output tokens",)
                    repeat_times = gr.Slider(minimum=1, maximum=64, value=4, step=1, interactive=True, label="Repeat times",)

            with gr.Column(scale=8):
                imagebox_boxmap.render()
                with gr.Tabs():
                    with gr.Tab("QB-Poster"):
                        with gr.Row(variant='compact'):
                            object_num = gr.Checkbox(label='object', value=True)
                            frame_num = gr.Checkbox(label='frame', value=True)
                            title_num = gr.Checkbox(label='title', value=True)
                            decoration_num = gr.Slider(minimum=0, maximum=10, value=2, step=1, label='decoration')
                            subtitle_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='subtitle')
                            itemtitle_num = gr.Slider(minimum=0, maximum=10, value=0, step=1, label='itemtitle')
                            itemlogo_num = gr.Slider(minimum=0, maximum=10, value=3, step=1, label='itemlogo')
                            item_num = gr.Slider(minimum=0, maximum=10, value=3, step=1, label='item')
                            text_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='text')
                            textbackground_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='textbackground')
                        # Order must match qb_add_text's positional count parameters.
                        qb_elem_list = [title_num, decoration_num, subtitle_num, itemtitle_num, itemlogo_num,
                                        item_num, text_num, textbackground_num, object_num, frame_num]
                        with gr.Row():
                            qb_textbox = gr.Textbox(label='User Constraint', show_label=True,
                                                    placeholder="Enter text and press ENTER", container=True)
                            with gr.Column(scale=1, min_width=50):
                                qb_submit_btn = gr.Button(value="Generate", variant="primary", interactive=False)

                    with gr.Tab("CGL / PosterLayout"):
                        with gr.Row(variant='compact'):
                            cgl_text_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='text')
                            underlay_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='underlay')
                            embellishment_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='embellishment')
                        cgl_elem_list = [cgl_text_num, underlay_num, embellishment_num]
                        with gr.Row():
                            cgl_textbox = gr.Textbox(label='User Constraint', show_label=True,
                                                     placeholder="Enter text and press ENTER", container=True)
                            with gr.Column(scale=1, min_width=50):
                                cgl_submit_btn = gr.Button(value="Generate", variant="primary", interactive=False)

                    with gr.Tab("Ad Banners"):
                        with gr.Row(variant='compact'):
                            header_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='header')
                            preheader_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='pre-header')
                            postheader_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='post-header')
                            body_text = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='body text')
                            disclaimer_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='disclaimer / footnote')
                            button_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='button')
                            callout_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='callout')
                            logo_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='logo')
                        banners_elem_list = [header_num, preheader_num, postheader_num, body_text, disclaimer_num,
                                             button_num, callout_num, logo_num]
                        with gr.Row():
                            banners_textbox = gr.Textbox(label='User Constraint', show_label=True,
                                                         placeholder="Enter your customized design requirements separated by ';' to control the sizes and positions of elements", container=True)
                            with gr.Column(scale=1, min_width=50):
                                banners_submit_btn = gr.Button(value="Generate", variant="primary", interactive=False)

                with gr.Accordion("Intermediate results", open=False):
                    gr.Markdown("The layout generation process with LLM")
                    chatbot = gr.Chatbot(elem_id="chatbot", label="LLM Conversations", height=550)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [qb_submit_btn, cgl_submit_btn, banners_submit_btn]
        # A new background image resets the conversation and enables the buttons.
        imagebox.change(
            init_conv,
            None,
            [state, chatbot] + btn_list,
            queue=False
        )

        bot_inputs = [state, model_selector, temperature, top_p, max_output_tokens, repeat_times]
        bot_outputs = [state, chatbot, imagebox_boxmap] + btn_list
        # Every tab: build the prompt (unqueued), then stream generations.
        for submit_btn, add_fn, elem_list, textbox in (
            (qb_submit_btn, qb_add_text, qb_elem_list, qb_textbox),
            (cgl_submit_btn, cgl_add_text, cgl_elem_list, cgl_textbox),
            (banners_submit_btn, banners_add_text, banners_elem_list, banners_textbox),
        ):
            submit_btn.click(
                add_fn,
                [state, *elem_list, imagebox, textbox, image_process_mode],
                [state, chatbot] + btn_list,
                queue=False
            ).then(
                http_bot,
                bot_inputs,
                bot_outputs
            )

        if args.model_list_mode == "once":
            # NOTE(review): Gradio 4 renamed this keyword to `js`; `_js` only
            # works on Gradio 3.x — confirm against the installed version.
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # `args` and `models` become module globals read by the handlers and build_demo.
    args = parser.parse_args()
    logger.info(f"args: {args}")
    models = get_model_list()

    # build_demo's signature is (embed_mode, cur_dir); pass the script directory.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    demo = build_demo(args.embed, cur_dir)
    # NOTE(review): `concurrency_count` is a Gradio 3.x queue() option that
    # Gradio 4 removed (use `default_concurrency_limit`); confirm against the
    # installed Gradio version.
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )