# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
| import re | |
| from dataclasses import dataclass | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from transformers import AutoTokenizer, Qwen2ForCausalLM | |
| from tok.mm_autoencoder import MMAutoEncoder | |
@dataclass
class T2IConfig:
    """Settings for the text-to-image pipeline.

    Every field has a default, so ``T2IConfig()`` works as before and
    individual attributes can still be reassigned afterwards. As a
    dataclass, the settings can also be passed as keyword arguments,
    e.g. ``T2IConfig(scale=1, seq_len=169)``.
    """

    # Identifier handed to ``from_pretrained`` for the language model and its tokenizer.
    model_path: str = "csuhan/Tar-1.5B"

    # visual tokenizer config (checkpoint paths forwarded to MMAutoEncoder)
    ar_path: str = 'ar_dtok_lp_256px.pth'
    encoder_path: str = 'ta_tok.pth'
    decoder_path: str = 'vq_ds16_t2i.pt'

    # Device and dtype that both models are moved to.
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16

    # generation parameters
    scale: int = 0  # choose from [0, 1, 2]
    seq_len: int = 729  # choose from [729, 169, 81]
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 1200
    cfg_scale: float = 4.0
class TextToImageInference:
    """Generate an image from a text prompt.

    A Qwen2 causal language model samples a sequence of image-code tokens
    (written ``<I123>`` in decoded text); the visual tokenizer then turns
    those codes into pixels.
    """

    # One image-code token in decoded text, capturing its integer id.
    _IMAGE_CODE_PATTERN = re.compile(r'<I(\d+)>')

    def __init__(self, config: "T2IConfig"):
        """Store *config* and load both models onto ``config.device``."""
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        """Load the language model, its tokenizer, and the visual tokenizer."""
        self.model = Qwen2ForCausalLM.from_pretrained(
            self.config.model_path,
            torch_dtype=self.config.dtype,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        # Initialize visual tokenizer
        # (named distinctly so it does not shadow ``self.config``)
        tokenizer_kwargs = dict(
            ar_path=self.config.ar_path,
            encoder_path=self.config.encoder_path,
            decoder_path=self.config.decoder_path,
            encoder_args={'input_type': 'rec'},
            decoder_args={},
        )
        self.visual_tokenizer = MMAutoEncoder(**tokenizer_kwargs).eval().to(
            dtype=self.config.dtype,
            device=self.device,
        )
        self.visual_tokenizer.ar_model.cls_token_num = self.config.seq_len
        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1

    @staticmethod
    def _codes_from_text(text, seq_len):
        """Return exactly *seq_len* image codes parsed from *text*.

        Codes are taken in order of appearance. Extra codes are dropped and
        a short sequence is right-padded with code 0.
        """
        codes = [int(match) for match in TextToImageInference._IMAGE_CODE_PATTERN.findall(text)]
        codes = codes[:seq_len]
        return codes + [0] * (seq_len - len(codes))

    def generate_image(self, prompt: str) -> "Image.Image":
        """Sample image codes for *prompt* and decode them into a PIL image.

        Args:
            prompt: User text describing the desired image.

        Returns:
            The decoded image.
        """
        # Prepare prompt
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True)
        input_text += f"<im_start><S{self.config.scale}>"

        # Generate tokens
        inputs = self.tokenizer(input_text, return_tensors="pt")
        input_ids = inputs.input_ids.to(self.device)
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        gen_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            top_k=self.config.top_k)

        # Process generated tokens
        # generate() returns prompt + completion for a decoder-only model.
        # Decode only the completion so that text in the user's prompt that
        # happens to look like "<I12>" is never mistaken for an image code.
        new_ids = gen_ids[:, input_ids.shape[1]:]
        gen_text = self.tokenizer.batch_decode(new_ids)[0]
        gen_code = self._codes_from_text(gen_text, self.config.seq_len)
        gen_code = torch.tensor(gen_code).unsqueeze(0).to(self.device)

        # Decoding is inference only; skip autograd bookkeeping.
        with torch.no_grad():
            gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
                gen_code,
                {'cfg_scale': self.config.cfg_scale}
            )
        # .numpy() needs a CPU tensor without grad; both calls are no-ops if
        # that already holds. NOTE(review): presumably the decoder returns an
        # image-shaped uint8 tensor that Image.fromarray accepts - confirm.
        gen_image = Image.fromarray(gen_tensor[0].detach().cpu().numpy())
        return gen_image
def main():
    """Fetch the tokenizer weights, render one demo prompt, and save the result as a PNG."""
    # (config attribute, Hub repo id, filename) - fetched in this order.
    weight_sources = (
        ("ar_path", "csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
        ("encoder_path", "csuhan/TA-Tok", "ta_tok.pth"),
        ("decoder_path", "peizesun/llamagen_t2i", "vq_ds16_t2i.pt"),
    )

    settings = T2IConfig()
    for attr_name, repo_id, filename in weight_sources:
        # Each config path is overwritten with the value hf_hub_download returns.
        setattr(settings, attr_name, hf_hub_download(repo_id, filename))

    pipeline = TextToImageInference(settings)
    picture = pipeline.generate_image("A photo of a macaw")
    picture.save("generated_image.png")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()