import gradio as gr
import numpy as np
import random
import json
import requests
from huggingface_hub import InferenceClient
import torch
from dotenv import load_dotenv
import os
import spaces
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images

load_dotenv()
token = os.environ.get("HF_TOKEN")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Load Janus-Pro-7B once at startup; eager attention avoids a flash-attention dependency.
model_paths = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_paths)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_paths,
    language_config=language_config,
    trust_remote_code=True,
)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_paths)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'


def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 1,
             cfg_weight: float = 5, image_token_num_per_image: int = 576, patch_size: int = 16):
    """Autoregressively sample image tokens with classifier-free guidance, then decode them to image patches."""
    torch.cuda.empty_cache()
    # Duplicate each prompt: even rows keep the text (conditional), odd rows are padded (unconditional).
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: mix conditional and unconditional logits.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed the sampled token back into both the conditional and unconditional streams.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size],
    )
    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=1):
    """Convert decoded patches from [-1, 1] float tensors to uint8 RGB images."""
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img
def load_seeds():
    """Load prompt seeds from seeds.json; return an empty dict if the file is missing."""
    try:
        with open('seeds.json', 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        print("seeds.json not found")
        return {}


def generate_prompt(prompt_seed):
    """Expand a short seed phrase into a detailed prompt via the Flux-Prompt-Enhance Inference API."""
    headers = {
        "Authorization": f"Bearer {token}",
        "x-use-cache": "0",
        "x-wait-for-model": "1",
        "Content-Type": "application/json",
    }
    API_URL = "https://api-inference.huggingface.co/models/gokaygokay/Flux-Prompt-Enhance"
    apiData = {"inputs": prompt_seed, "parameters": {"max_new_tokens": 250}, "stream": "0"}
    try:
        response = requests.post(API_URL, headers=headers, data=json.dumps(apiData))
        return response.json()[0]["generated_text"]
    except Exception as e:
        print(f"Error generating prompt: {e}")
        return "Error generating prompt. Please try again."
def prompt_generator():
    seeds = load_seeds()
    if seeds:
        seed = random.choice(seeds["seeds"])  # Randomly select a seed from the JSON
        return generate_prompt(seed)
    return "Unable to generate prompt - no seeds available"


@torch.inference_mode()
@spaces.GPU()
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    temperature,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Apply the seed so the same settings reproduce the same image.
    torch.manual_seed(seed)

    with torch.no_grad():
        messages = [
            {'role': '<|User|>', 'content': prompt},
            {'role': '<|Assistant|>', 'content': ''},
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        text += vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        # Round the requested resolution down to a multiple of the 16-pixel patch size.
        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16,
                                   cfg_weight=guidance_scale, parallel_size=1, temperature=temperature)
        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=1)
        torch.cuda.empty_cache()
        return images


examples = [
    "A cyberpunk street market at night, neon signs reflecting in puddles, steam rising from food stalls, highly detailed, cinematic lighting, 8k resolution, shallow depth of field",
    "Portrait of an elderly wise woman with intricate tribal face paint, wrinkled skin with fine details, wearing ornate traditional jewelry, soft natural lighting, studio photography style, 85mm lens",
    "Mystical underwater ruins of an ancient temple, bioluminescent coral growing on marble columns, schools of colorful fish swimming by, rays of sunlight penetrating the water, photorealistic render",
    "A cozy cottage kitchen at sunrise, morning light streaming through windows, steam rising from fresh bread and coffee, vintage copper pots hanging overhead, hyperrealistic style, warm color palette",
    "An impossible M.C. Escher-inspired geometric cityscape where stairs connect in physically impossible ways, architectural details, isometric perspective, clean lines, muted colors, high contrast",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("""