|
import os |
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False" |
|
os.environ["TOKENIZERS_PARALLELISM"] = "true" |
|
import tempfile |
|
from share_btn import share_js, save_js |
|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
from omegaconf import OmegaConf |
|
from transformers import AutoTokenizer |
|
|
|
from models import Showo, MAGVITv2, get_mask_chedule |
|
from prompting_utils import UniversalPrompting, create_attention_mask_predict_next |
|
|
|
|
|
|
|
# Demo configuration: model paths, preprocessing limits and generation settings.
config = OmegaConf.load("configs/showo_demo.yaml")

# Every model and tensor created below is placed on this device (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Text tokenizer for the Show-o language backbone; pads on the left.
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")

# Builds model input ids for the different task formats (e.g. 't2i_gen') around the special tokens below.
uni_prompting = UniversalPrompting(
    tokenizer,
    max_text_len=config.dataset.preprocessing.max_seq_length,
    special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
    ignore_id=-100,
    cond_dropout_prob=config.training.cond_dropout_prob,
)

# Image tokenizer / decoder. `from_pretrained` is invoked on the class itself:
# instantiating MAGVITv2 with the type string first would build (or fail to
# build) a model that `from_pretrained` immediately discards.
# NOTE(review): config.model.vq_model.type is not consulted here; this assumes
# the config always selects MAGVITv2 — confirm.
vq_model = MAGVITv2.from_pretrained(config.model.vq_model.vq_model_name).to(device)
vq_model.requires_grad_(False)
vq_model.eval()

# Show-o generator model, inference mode only.
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
model.eval()

# Token id used to fill image positions before text-to-image generation.
mask_token_id = model.config.mask_token_id
|
|
|
|
|
# Page CSS for the demo: chat area height, gradient save/share buttons, and
# hover-to-zoom thumbnails in the example gallery. Devices without hover
# (`@media (hover: none)`) get the zoom and its padding reset.
css = """
#chatbot { min-height: 300px; }
#save-btn {
    background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#save-btn:hover {
    background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#share-btn {
    background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#share-btn:hover {
    background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#gallery { z-index: 999999; }
#gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
#gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
@media (hover: none) {
    #gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
}
.html2canvas-container { width: 3000px !important; height: 3000px !important; }
"""
|
|
|
|
|
def upload_image(state, image_input):
    """Add an uploaded image to the conversation.

    The uploaded file is resized to 224x224, converted to RGB, written back
    over the uploaded path, and appended to both the displayed conversation
    and the chat history.

    Args:
        state: Two-element list ``[conversation, chat_history]``.
        image_input: Uploaded file object; only its ``.name`` path is used.

    Returns:
        ``(new_state, conversation)`` for the ``[gr_state, chatbot]`` outputs.
    """
    import html

    conversation = state[0]
    chat_history = state[1]

    # Open via a context manager so the source handle is released before the
    # same path is overwritten with the resized image.
    with Image.open(image_input.name) as uploaded:
        input_image = uploaded.resize((224, 224)).convert('RGB')
    input_image.save(image_input.name)

    # Escape the path: characters such as quotes in an uploaded filename must
    # not be able to break out of the src attribute.
    safe_path = html.escape(image_input.name, quote=True)
    conversation += [(f'<img src="./file={safe_path}" style="display: inline-block;">', "")]

    # Two chat-history entries per upload (the image plus an empty placeholder).
    return [conversation, chat_history + [input_image, ""]], conversation
|
|
|
|
|
def reset():
    """Start a fresh session.

    Returns:
        A new empty ``[conversation, chat_history]`` state and an empty chatbot
        transcript, for the ``[gr_state, chatbot]`` outputs.
    """
    fresh_state = [list(), list()]
    empty_transcript = list()
    return fresh_state, empty_transcript
|
|
|
|
|
def reset_last(state):
    """Undo the most recent turn.

    Drops the final displayed conversation entry and the final two
    chat-history entries, without mutating ``state``.

    Returns:
        ``(new_state, conversation)`` for the ``[gr_state, chatbot]`` outputs;
        the conversation list is the same object in both positions.
    """
    trimmed_convo = state[0][:-1]
    trimmed_history = state[1][:-2]
    new_state = [trimmed_convo, trimmed_history]
    return new_state, trimmed_convo
|
|
|
|
|
def save_image_to_local(image: "Image.Image"):
    """Save ``image`` as a uniquely named PNG in the current working directory.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The new file's path relative to the current working directory, suitable
        for embedding as ``./file=<path>`` in chatbot HTML.

    Raises:
        Whatever ``image.save`` raises; the partially created file is removed.
    """
    # mkstemp creates the file atomically with a unique name, so concurrent
    # requests cannot pick the same filename (the private
    # tempfile._get_candidate_names() gives no such guarantee).
    fd, path = tempfile.mkstemp(suffix='.png', dir=os.getcwd())
    try:
        with os.fdopen(fd, 'wb') as fp:
            image.save(fp, format='PNG')
    except BaseException:
        os.remove(path)
        raise
    return os.path.relpath(path)
|
|
|
|
|
def _t2i_attention_mask(input_ids, uncond_input_ids=None):
    """Build the predict-next attention mask for t2i input ids.

    When ``uncond_input_ids`` is given it is concatenated after ``input_ids``
    along dim 0, so one mask covers both the conditional and the
    unconditional (empty-prompt) sequences.
    """
    ids = input_ids if uncond_input_ids is None else torch.cat([input_ids, uncond_input_ids], dim=0)
    return create_attention_mask_predict_next(
        ids,
        pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
        soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
        eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
        rm_pad_in_image=True,
    )


def _t2i_mask_schedule():
    """Return the mask schedule from ``config.mask_schedule`` if present, else ``config.training`` (default "cosine")."""
    if config.get("mask_schedule", None) is not None:
        schedule_params = config.mask_schedule.get("params", {})
        return get_mask_chedule(config.mask_schedule.schedule, **schedule_params)
    return get_mask_chedule(config.training.get("mask_schedule", "cosine"))


def text_to_image_generation(input_text, state, guidance_scale, generation_timesteps):
    """Generate an image for ``input_text`` with the Show-o model.

    Args:
        input_text: Text prompt.
        state: Unused.
        guidance_scale: Guidance strength; when > 0 an unconditional
            (empty-prompt) sequence is also passed to ``model.t2i_generate``.
        generation_timesteps: Step count passed to ``model.t2i_generate``.

    Returns:
        A list with one ``PIL.Image.Image`` per prompt (a single image here).
    """
    prompts = [input_text]
    # UI sliders may deliver floats; the value is presumably used as an integer
    # step count downstream, so cast defensively.
    generation_timesteps = int(generation_timesteps)

    # Mirror the run settings into the shared config, which is also handed to
    # t2i_generate below.
    # NOTE(review): this mutates a module-level object shared by every request;
    # it is only race-free while the queue runs a single worker.
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    # Every image position starts out as the mask token.
    image_tokens = torch.full(
        (len(prompts), config.model.showo.num_vq_tokens), mask_token_id, dtype=torch.long, device=device
    )

    input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
    if guidance_scale > 0:
        uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
    else:
        uncond_input_ids = None
    attention_mask = _t2i_attention_mask(input_ids, uncond_input_ids)
    mask_schedule = _t2i_mask_schedule()

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )
        # Keep ids inside the VQ codebook range before decoding.
        gen_token_ids = torch.clamp(gen_token_ids, min=0, max=config.model.showo.codebook_size - 1)
        images = vq_model.decode_code(gen_token_ids)

    # Map decoder output (presumably in [-1, 1]) to [0, 255]; assumes NCHW
    # layout, permuted to NHWC for PIL.
    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) * 255.0
    # float -> uint8 truncates toward zero, the same as numpy's astype(np.uint8).
    images = images.permute(0, 2, 3, 1).to(torch.uint8).cpu().numpy()
    return [Image.fromarray(image) for image in images]
|
|
|
|
|
# Placeholder image-token text removed from replies before display.
_IMG_TOKENS = '[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]'


def _render_model_outputs(model_outputs):
    """Render model outputs as chatbot HTML.

    String outputs are emitted as text, separated by ``<br/>``. Dict outputs are
    saved to disk and embedded as ``<img>`` tags, labelled "(Generated)" when
    ``p['decision'][0] == 'gen'`` and "(Retrieved)" otherwise.
    """
    parts = []
    for output_i, p in enumerate(model_outputs):
        if isinstance(p, str):
            if output_i > 0:
                parts.append('<br/>')
            parts.append(p)
            if len(model_outputs) > 1:
                parts.append('<br/>')
        elif isinstance(p, dict):
            if p['decision'] is not None and p['decision'][0] == 'gen':
                image, source = p['gen'][0][0], 'Generated'
            else:
                image, source = p['ret'][0][0], 'Retrieved'
            filename = save_image_to_local(image)
            parts.append(
                f'<img src="./file={filename}" style="display: inline-block;">'
                f'<p style="font-size: 12px; color: #555; margin-top: 0;">({source})</p>'
            )
    return ''.join(parts)


def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
    """Run one chat turn and append the model's reply to the conversation.

    Args:
        input_text: User message; an empty message leaves the state unchanged.
        state: Two-element list ``[conversation, chat_history]``.
        ret_scale_factor: Passed through to the model unchanged.
        num_words: Maximum words to generate (at least 1 is requested).
        temperature: Sampling temperature; 0.0 uses ``top_p=1.0``, otherwise ``top_p=0.95``.

    Returns:
        ``(new_state, conversation, share_group_update, save_group_update)``,
        one value per output component the UI wires to this handler.
    """
    if not input_text:
        # Must still return four values, one per wired output component.
        return state, state[0], gr.update(visible=True), gr.update(visible=True)

    conversation = state[0]
    chat_history = state[1]
    print('Generating for', chat_history, flush=True)

    # Build a new list: appending to chat_history would mutate the session
    # state even if generation below raises.
    input_prompt = 'Q: ' + input_text + '\nA:'
    model_inputs = [s for s in [*chat_history, input_prompt] if s != '']

    top_p = 1.0 if temperature == 0.0 else 0.95
    # Seeded generator on the configured device, so CPU-only hosts also work.
    generator = torch.Generator(device=device).manual_seed(1337)

    # NOTE(review): generate_for_images_and_texts appears to be the GILL model
    # API, while `model` in this file is loaded as Showo — confirm it exists.
    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(
        model_inputs,
        num_words=max(num_words, 1),
        ret_scale_factor=ret_scale_factor,
        top_p=top_p,
        temperature=temperature,
        max_num_rets=1,
        num_inference_steps=50,
        generator=generator,
    )
    print('model_outputs', model_outputs, ret_scale_factor, flush=True)

    response = _render_model_outputs(model_outputs)
    text_reply = ' '.join(s for s in model_outputs if isinstance(s, str))
    chat_history = model_inputs + [text_reply + '\n']
    conversation.append((input_text, response.replace(_IMG_TOKENS, '')))

    print('state', state, flush=True)
    print('updated state', [conversation, chat_history], flush=True)
    return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)
|
|
|
|
|
# Example conversation images shown in the gallery.
# NOTE(review): assumes they are image files in ./examples — confirm the
# intended location; a missing directory yields an empty gallery instead of
# a startup crash.
_EXAMPLES_DIR = "examples"
examples = (
    sorted(
        os.path.join(_EXAMPLES_DIR, f)
        for f in os.listdir(_EXAMPLES_DIR)
        if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
    )
    if os.path.isdir(_EXAMPLES_DIR)
    else []
)

with gr.Blocks(css=css) as demo:
    gr.HTML("""
        <h1>🐟 GILL</h1>
        <p>This is the official Gradio demo for the GILL model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.</p>

        <strong>Paper:</strong> <a href="https://arxiv.org/abs/2305.17216" target="_blank">Generating Images with Multimodal Language Models</a>
        <br/>
        <strong>Project Website:</strong> <a href="https://jykoh.com/gill" target="_blank">GILL Website</a>
        <br/>
        <strong>Code and Models:</strong> <a href="https://github.com/kohjingyu/gill" target="_blank">GitHub</a>
        <br/>
        <br/>

        <strong>Tips:</strong>
        <ul>
        <li>Start by inputting either image or text prompts (or both) and chat with GILL to get image-and-text replies.</li>
        <li>Tweak the level of sensitivity to images and text using the parameters on the right.</li>
        <li>Check out cool conversations in the examples or community tab for inspiration and share your own!</li>
        <li>If the model outputs a blank image, it is because Stable Diffusion's safety filter detected inappropriate content. Please try again with a different prompt.</li>
        <li>Outputs may differ slightly from the paper due to slight implementation differences. For reproducing paper results, please use our <a href="https://github.com/kohjingyu/gill" target="_blank">official code</a>.</li>
        <li>For faster inference without waiting in queue, you may duplicate the space and use your own GPU: <a href="https://huggingface.co/spaces/jykoh/gill?duplicate=true"><img style="display: inline-block; margin-top: 0em; margin-bottom: 0em" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></li>
        </ul>
    """)

    # Session state: [conversation shown in the chatbot, chat history fed to the model].
    gr_state = gr.State([[], []])

    with gr.Row():
        with gr.Column(scale=0.7, min_width=500):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chatbot", label="🐟 GILL Chatbot")
            with gr.Row():
                image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])

                text_input = gr.Textbox(label="Message", placeholder="Type a message")

                with gr.Column():
                    submit_btn = gr.Button("Submit", interactive=True, variant="primary")
                    clear_last_btn = gr.Button("Undo")
                    clear_btn = gr.Button("Reset All")
                    # Hidden until the first reply; generate_for_prompt makes them visible.
                    with gr.Row(visible=False) as save_group:
                        save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn")

                    with gr.Row(visible=False) as share_group:
                        share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")

        with gr.Column(scale=0.3, min_width=400):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.3, step=0.1, interactive=True,
                                         label="Frequency multiplier for returning images (higher means more frequent)")
            gr_max_len = gr.Slider(minimum=1, maximum=64, value=32,
                                   step=1, interactive=True, label="Max # of words")
            gr_temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)")

            gallery = gr.Gallery(
                value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery",
            ).style(grid=[2], height="auto")

    # Shared wiring for both ways of sending a message (Enter key and Submit button).
    chat_inputs = [text_input, gr_state, ret_scale_factor, gr_max_len, gr_temperature]
    chat_outputs = [gr_state, chatbot, share_group, save_group]

    text_input.submit(generate_for_prompt, chat_inputs, chat_outputs)
    text_input.submit(lambda: "", None, text_input)

    submit_btn.click(generate_for_prompt, chat_inputs, chat_outputs)
    submit_btn.click(lambda: "", None, text_input)

    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])
    # Client-side only: these run the share/save JavaScript in the browser.
    share_button.click(None, [], [], _js=share_js)
    save_button.click(None, [], [], _js=save_js)
|
|
|
|
|
# Process queued requests with a single worker (the models are shared
# module-level globals), hold at most 16 waiting requests, and keep direct
# REST calls from bypassing the queue.
demo.queue(max_size=16, api_open=False, concurrency_count=1)
# Bind on all network interfaces, with debug output enabled.
demo.launch(server_name="0.0.0.0", debug=True)
|
|