# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import os

import gradio as gr
import spaces
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download, snapshot_download, login

from tok.ar_dtok.ar_model import ARModel
from t2i_inference import T2IConfig, TextToImageInference
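
# Image understanding path: the TA-Tok visual tokenizer encodes the input image
# into a sequence of discrete codes (its "bottleneck_rep"), which are rendered
# as special <I{code}> tokens and spliced into the chat prompt, so the language
# model reads the image the same way it reads text.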
def generate_text(self, image, prompt: str) -> str:
    # `image` is a PIL image (gr.Image with type="pil"), not a file path.
    image = image.convert('RGB')
    image = to_tensor(image).unsqueeze(0).to(self.device)
    # Encode the image into discrete TA-Tok codes.
    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"}
    ]
    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(input_text, return_tensors="pt")
    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True)
    # Strip the prompt tokens and decode only the newly generated answer.
    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
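
# ---- Model setup ----
# Download the Tar-7B checkpoint plus the TA-Tok encoder, the AR de-tokenizer
# heads for 512px/1024px decoding, and the LlamaGen VQ decoder from the Hub.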
login(token=os.getenv('HF_TOKEN'))

config = T2IConfig()
config.model = snapshot_download("ByteDance-Seed/Tar-7B")
config.ar_path = {
    "1024px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)
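
# ---- Gradio callbacks ----
# This Space targets ZeroGPU (hence the `spaces` import), where GPU-using
# functions must be wrapped with @spaces.GPU to be allocated a device for the
# duration of the call. The decorators below are added on that assumption; the
# original lines imported `spaces` without using it.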
@spaces.GPU
def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
    image = inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
    return image

def clear_inputs_t2i():
    return "", None

@spaces.GPU
def understand_image(image, prompt):
    # generate_text is a plain module-level function, so the inference object
    # is passed explicitly in place of `self`.
    return generate_text(inference, image, prompt)

def clear_inputs_i2t():
    return None, "", ""
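
# ---- UI: one tab for text-to-image generation, one for image understanding ----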
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div align="center">

        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations

        [🕸️ Project Page](http://tar.csuhan.com) • [📄 Paper](http://arxiv.org/abs/2506.18898) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/collections/ByteDance-Seed/tar-6864cf0d9fe59a3b91cc4260)

        </div>
        """,
        elem_id="title",
    )
| with gr.Tab("Image Generation"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt", default="A photo of a macaw") | |
| with gr.Accordion("Advanced Settings", open=False): | |
| resolution = gr.Radio( | |
| ["512px", "1024px"], value="1024px", label="Resolution" | |
| ) | |
| top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p") | |
| top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k") | |
| cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale") | |
| with gr.Row(): | |
| generate_btn = gr.Button("Generate") | |
| clear_btn = gr.Button("Clear") | |
| with gr.Column(scale=2): | |
| output_image = gr.Image(label="Generated Image") | |
| generate_btn.click( | |
| generate_image, | |
| inputs=[prompt, resolution, top_p, top_k, cfg_scale], | |
| outputs=output_image | |
| ) | |
| clear_btn.click( | |
| clear_inputs_t2i, | |
| outputs=[prompt, output_image] | |
| ) | |
| with gr.Tab("Image Understanding"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| image_input = gr.Image(label="Upload Image", type="pil", value="https://raw.githubusercontent.com/csuhan/Tar/refs/heads/main/asset/dog_cat.jpg") | |
| question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.") | |
| with gr.Row(): | |
| qa_btn = gr.Button("Generate") | |
| clear_btn_i2t = gr.Button("Clear") | |
| with gr.Column(scale=1): | |
| answer_output = gr.Textbox(label="Response", lines=4) | |
| qa_btn.click( | |
| understand_image, | |
| inputs=[image_input, question_input], | |
| outputs=answer_output | |
| ) | |
| clear_btn_i2t.click( | |
| clear_inputs_i2t, | |
| outputs=[image_input, question_input, answer_output] | |
| ) | |
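
# Note: share=True is ignored on Hugging Face Spaces (Gradio logs a warning);
# it only takes effect when running this script locally.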
demo.launch(share=True)