Spaces:
Running
on
Zero
Running
on
Zero
| # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # // | |
| # // Licensed under the Apache License, Version 2.0 (the "License"); | |
| # // you may not use this file except in compliance with the License. | |
| # // You may obtain a copy of the License at | |
| # // | |
| # // http://www.apache.org/licenses/LICENSE-2.0 | |
| # // | |
| # // Unless required by applicable law or agreed to in writing, software | |
| # // distributed under the License is distributed on an "AS IS" BASIS, | |
| # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # // See the License for the specific language governing permissions and | |
| # // limitations under the License. | |
| import re | |
| from dataclasses import dataclass | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from transformers import AutoTokenizer, Qwen2ForCausalLM | |
| from tok.mm_autoencoder import MMAutoEncoder | |
@dataclass
class T2IConfig:
    """Configuration for text-to-image inference.

    Groups the language-model checkpoint, the visual tokenizer checkpoints,
    device/dtype placement, and sampling parameters. As a dataclass, every
    field can be overridden at construction, e.g. ``T2IConfig(device="cpu")``;
    ``T2IConfig()`` still yields the defaults below.
    """

    model_path: str = "ByteDance-Seed/Tar-1.5B"

    # Visual tokenizer config.
    # Forwarded to MMAutoEncoder as ``ar_path_dict``; the parameter name
    # suggests a dict of checkpoint paths (TODO confirm). None by default.
    ar_path: "dict | None" = None
    encoder_path: str = 'ta_tok.pth'
    decoder_path: str = 'vq_ds16_t2i.pt'
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16

    # Generation parameters.
    scale: int = 0  # choose from [0, 1, 2]
    seq_len: int = 729  # choose from [729, 169, 81]
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 1200
    cfg_scale: float = 4.0
class TextToImageInference:
    """Generate images from text prompts with a Tar causal language model.

    The LM is prompted with a chat template followed by ``<im_start><S{scale}>``
    and sampled; integer codes are parsed from ``<I{n}>`` tokens in its output
    and passed to the ``MMAutoEncoder`` visual tokenizer's
    ``decode_from_encoder_indices``, whose first output is turned into a PIL image.
    """

    # Matches image-code tokens such as "<I123>" and captures the integer code.
    _IMAGE_CODE_RE = re.compile(r'<I(\d+)>')

    def __init__(self, config: "T2IConfig"):
        """Store *config*, resolve its device, and load all models.

        Args:
            config: Checkpoint paths, device/dtype placement, and sampling defaults.
        """
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        """Load the LM, its text tokenizer, and the visual tokenizer onto ``self.device``."""
        self.model = Qwen2ForCausalLM.from_pretrained(
            self.config.model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        # Initialize the visual tokenizer from the configured checkpoints.
        vq_kwargs = dict(
            ar_path_dict=self.config.ar_path,
            encoder_path=self.config.encoder_path,
            decoder_path=self.config.decoder_path,
            encoder_args={'input_type': 'rec'},
            decoder_args={},
        )
        self.visual_tokenizer = (
            MMAutoEncoder(**vq_kwargs)
            .eval()
            .to(dtype=self.config.dtype, device=self.device)
        )
        # Give every AR model the code-sequence length the LM emits
        # (presumably its conditioning-token count — confirm in MMAutoEncoder).
        for ar_model in self.visual_tokenizer.ar_model.values():
            ar_model.cls_token_num = self.config.seq_len
        # Config scale 0/1/2 maps to encoder pool_scale 1/2/3.
        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1

    def _build_input_text(self, prompt):
        """Wrap *prompt* in the chat template and append ``<im_start><S{scale}>``."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return input_text + f"<im_start><S{self.config.scale}>"

    @classmethod
    def _parse_codes(cls, text, seq_len):
        """Return the ``<I{n}>`` codes found in *text* as ints, truncated or
        right-padded with 0 to exactly *seq_len* entries."""
        codes = [int(c) for c in cls._IMAGE_CODE_RE.findall(text)][:seq_len]
        return codes + [0] * (seq_len - len(codes))

    def generate_image(self, prompt, resolution, top_p=None, top_k=None, cfg_scale=None) -> "Image.Image":
        """Sample image codes for *prompt* and decode them into a PIL image.

        Args:
            prompt: User text describing the image.
            resolution: Forwarded to the visual tokenizer as ``'resolution'``.
            top_p: Nucleus-sampling threshold; ``None`` uses ``config.top_p``.
            top_k: Top-k sampling cutoff; ``None`` uses ``config.top_k``.
            cfg_scale: Forwarded to the visual tokenizer as ``'cfg_scale'``;
                ``None`` uses ``config.cfg_scale``.

        Returns:
            The decoded image (first element of the decoder output).
        """
        if top_p is None:
            top_p = self.config.top_p
        if top_k is None:
            top_k = self.config.top_k
        if cfg_scale is None:
            cfg_scale = self.config.cfg_scale

        inputs = self.tokenizer(self._build_input_text(prompt), return_tensors="pt")
        input_ids = inputs.input_ids.to(self.device)
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            gen_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.config.seq_len,
                do_sample=True,
                temperature=self.config.temperature,
                top_p=top_p,
                top_k=top_k,
            )
            # For a decoder-only LM, generate() returns prompt + continuation.
            # Decode only the continuation so "<I…>" text inside the user's
            # prompt is never parsed as image codes.
            new_ids = gen_ids[:, input_ids.shape[1]:]
            gen_text = self.tokenizer.batch_decode(new_ids)[0]
            codes = self._parse_codes(gen_text, self.config.seq_len)
            gen_code = torch.tensor(codes).unsqueeze(0).to(self.device)
            gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
                gen_code,
                {'cfg_scale': cfg_scale, 'resolution': resolution},
            )
        # detach().cpu() is a no-op for a CPU tensor and keeps .numpy() safe if
        # the decoder returns a tensor on another device.
        return Image.fromarray(gen_tensor[0].detach().cpu().numpy())