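# Hugging Face Space demo for BAAI's Emu3 models: vision-language understanding with
# Emu3-Chat (active) and text-to-image generation with Emu3-Gen (disabled further down).
# Gradio is pinned at runtime before it is imported.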
import subprocess

subprocess.run("pip install gradio==4.44.0", shell=True)

from PIL import Image
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoImageProcessor,
    AutoModel,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation import (
    LogitsProcessorList,
    PrefixConstrainedLogitsProcessor,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
import torch

from emu3.mllm.processing_emu3 import Emu3Processor
import spaces
import io
import base64
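# Render a PIL image as an inline <img> tag (base64-encoded PNG) so it can be embedded
# directly in the Gradio chatbot history.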
def image2str(image):
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    i_str = base64.b64encode(buf.getvalue()).decode()
    return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>'
# Install flash attention, skipping CUDA build if necessary
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

print(gr.__version__)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Model paths
EMU_GEN_HUB = "BAAI/Emu3-Gen"
EMU_CHAT_HUB = "BAAI/Emu3-Chat"
VQ_HUB = "BAAI/Emu3-VisionTokenizer"
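# The Emu3-Gen text-to-image pipeline below is disabled by wrapping it in a string
# literal, so this Space only loads Emu3-Chat.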
# uncomment to use gen model
"""
# Prepare models and processors
# Emu3-Gen model and processor
gen_model = AutoModelForCausalLM.from_pretrained(
    EMU_GEN_HUB,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
image_tokenizer = AutoModel.from_pretrained(
    VQ_HUB, device_map="cpu", trust_remote_code=True
).eval()

print(device)
gen_model.to(device)
image_tokenizer.to(device)

processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
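# Text-to-image generation with Emu3-Gen: build positive- and negative-prompt token
# sequences, sample image tokens under classifier-free guidance with a prefix
# constraint on valid image-token positions, then decode the result into a PIL image.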
@spaces.GPU(duration=300)
def generate_image(prompt):
    POSITIVE_PROMPT = " masterpiece, film grained, best quality."
    NEGATIVE_PROMPT = (
        "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, "
        "fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, "
        "signature, watermark, username, blurry."
    )
    classifier_free_guidance = 3.0
    full_prompt = prompt + POSITIVE_PROMPT

    kwargs = dict(
        mode="G",
        ratio="1:1",
        image_area=gen_model.config.image_area,
        return_tensors="pt",
    )
    pos_inputs = processor(text=full_prompt, **kwargs)
    neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)

    # Prepare hyperparameters
    GENERATION_CONFIG = GenerationConfig(
        use_cache=True,
        eos_token_id=gen_model.config.eos_token_id,
        pad_token_id=gen_model.config.pad_token_id,
        max_new_tokens=40960,
        do_sample=True,
        top_k=2048,
    )

    h, w = pos_inputs.image_size[0]
    constrained_fn = processor.build_prefix_constrained_fn(h, w)
    logits_processor = LogitsProcessorList(
        [
            UnbatchedClassifierFreeGuidanceLogitsProcessor(
                classifier_free_guidance,
                gen_model,
                unconditional_ids=neg_inputs.input_ids.to(device),
            ),
            PrefixConstrainedLogitsProcessor(
                constrained_fn,
                num_beams=1,
            ),
        ]
    )

    # Generate
    outputs = gen_model.generate(
        pos_inputs.input_ids.to(device),
        generation_config=GENERATION_CONFIG,
        logits_processor=logits_processor,
    )

    mm_list = processor.decode(outputs[0])
    for idx, im in enumerate(mm_list):
        if isinstance(im, Image.Image):
            return im
    return None
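# Gen-mode chat handler: text prompts are turned into images; uploaded images are
# rejected because Emu3-Gen only accepts text input.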
def chat(history, user_input, user_image):
    if user_image is not None:
        history = history + [("", "Sorry, the gen model does not accept image input.")]
    else:
        # Use Emu3-Gen for image generation
        generated_image = generate_image(user_input)
        if generated_image is not None:
            # Append the user input and generated image to the history
            history = history + [(user_input, image2str(generated_image))]
        else:
            # If image generation failed, respond with an error message
            history = history + [
                (user_input, "Sorry, I could not generate an image.")
            ]
    return history, history, gr.update(value=None)
"""
# Emu3-Chat model and processor
chat_model = AutoModelForCausalLM.from_pretrained(
    EMU_CHAT_HUB,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
image_tokenizer = AutoModel.from_pretrained(
    VQ_HUB, device_map="cpu", trust_remote_code=True
).eval()

print(device)
chat_model.to(device)
image_tokenizer.to(device)

processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
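# Answer a question about an image with Emu3-Chat: the processor encodes the image with
# the vision tokenizer together with the text prompt (mode "U"), the model generates up
# to 320 new tokens, and only those new tokens are decoded into the response.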
def vision_language_understanding(image, text):
    inputs = processor(
        text=text,
        image=image,
        mode="U",
        padding_side="left",
        padding="longest",
        return_tensors="pt",
    )

    # Prepare hyperparameters
    GENERATION_CONFIG = GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=320,
    )

    # Generate
    outputs = chat_model.generate(
        inputs.input_ids.to(device),
        generation_config=GENERATION_CONFIG,
    )
    outputs = outputs[:, inputs.input_ids.shape[-1] :]
    response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return response
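# Chat handler: image + text is routed to Emu3-Chat for understanding; text-only input
# is rejected since the generation model is disabled above.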
def chat(history, user_input, user_image):
    if user_image is not None:
        # Use Emu3-Chat for vision-language understanding
        response = vision_language_understanding(user_image, user_input)
        # Append the user input and response to the history
        history = history + [(image2str(user_image) + "<br>" + user_input, response)]
    else:
        history = history + [
            (user_input, "Sorry, please provide a valid image for vision-language understanding.")
        ]
    return history, history, gr.update(value=None)


def clear_input():
    return gr.update(value="")
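# Gradio UI: a chatbot pane, a message box with a Send button, and an optional image upload.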
with gr.Blocks() as demo:
    gr.Markdown("# Emu3 Chatbot Demo")
    gr.Markdown(
        "This is a chatbot demo for image generation and vision-language understanding using Emu3 models."
    )

    chatbot = gr.Chatbot()
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=17):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                lines=2,
                container=False,
            )
        with gr.Column(scale=3, min_width=0):
            submit_btn = gr.Button("Send")

    user_image = gr.Image(
        sources=["upload"], type="pil", label="Upload an image (optional)"
    )
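    # Both the Send button and pressing Enter in the textbox run the chat handler and
    # then clear the textbox.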
    submit_btn.click(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],
    ).then(fn=clear_input, inputs=[], outputs=user_input)

    user_input.submit(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],
    ).then(fn=clear_input, inputs=[], outputs=user_input)

demo.launch()