"""Gradio demo for the Sana 1.6B text-to-image model.

On start-up this script clones NVlabs/Sana (if missing), installs its
dependencies, patches Sana's builder/config to use the
``unsloth/gemma-2-2b-it`` text encoder, and serves a Gradio UI. The pipeline
is loaded lazily on the first request and dropped again after
``TIMEOUT_SECONDS`` without use.
"""
import logging
import os
import secrets
import subprocess
import sys
import time
from threading import Timer

import numpy as np
import spaces
import yaml
from PIL import Image

from nsfw_detector import NSFWDetector, create_error_image

# No basicConfig here: warnings/errors still reach stderr via logging's
# last-resort handler without turning on INFO noise from other libraries.
logger = logging.getLogger(__name__)

TIMEOUT_SECONDS = 120  # 2 minutes of inactivity before the model is unloaded
SANA_REPO_URL = "https://github.com/NVlabs/Sana.git"
# Paths below are relative to the Sana checkout (we chdir into it).
CONFIG_PATH = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml"
BUILDER_PATH = "diffusion/model/builder.py"
CHECKPOINT_URI = "hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"
TEXT_ENCODER_NAME = "unsloth-gemma-2-2b-it"
TEXT_ENCODER_REPO = "unsloth/gemma-2-2b-it"

# Mutable state shared between request handlers and the unload timer thread.
global_model = None    # cached SanaPipeline, or None when not loaded
last_use_time = None   # time.monotonic() of the most recent use
unload_timer = None    # pending threading.Timer that may unload the model

# Clone the repository; fail loudly here rather than at the chdir below.
if not os.path.exists("Sana"):
    subprocess.run(["git", "clone", SANA_REPO_URL], check=True)

# Change to Sana directory
os.chdir("Sana")


# Workarounds
def modify_builder():
    """Register ``TEXT_ENCODER_NAME`` in Sana's ``text_encoder_dict``.

    Idempotent: if the alias is already present in the builder file (e.g. on
    a restart with an existing checkout) the file is left untouched. If the
    ``text_encoder_dict = {`` marker cannot be found, a warning is printed and
    the file is not rewritten.
    """
    entry = f'    "{TEXT_ENCODER_NAME}": "{TEXT_ENCODER_REPO}",\n'
    with open(BUILDER_PATH, "r", encoding="utf-8") as f:
        content = f.readlines()

    if any(f'"{TEXT_ENCODER_NAME}"' in line for line in content):
        return  # already patched

    for i, line in enumerate(content):
        if "text_encoder_dict = {" in line:
            # Insert right after the opening brace: always inside the dict
            # literal, unlike a fixed line offset that breaks whenever the
            # upstream dict gains or loses entries.
            content.insert(i + 1, entry)
            break
    else:
        print(f"Warning: text_encoder_dict not found in {BUILDER_PATH}; {TEXT_ENCODER_NAME} not registered")
        return

    with open(BUILDER_PATH, "w", encoding="utf-8") as f:
        f.writelines(content)


def modify_config():
    """Point the Sana 1600M/1024px config at the unsloth text encoder and bf16."""
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Update text encoder
    config["text_encoder"]["text_encoder_name"] = TEXT_ENCODER_NAME
    config["model"]["mixed_precision"] = "bf16"

    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False)


# Run environment setup commands
setup_commands = [
    "pip install torch",  # init raw torch
    "pip install -U pip",  # update pip
    "pip install -U xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121",  # fast attn
    "pip install pyyaml",
    "pip install -e .",  # install sana
]

for cmd in setup_commands:
    print(f"Running: {cmd}")
    # `python -m pip` installs into the interpreter running this app, not
    # whichever `pip` happens to be first on PATH.
    result = subprocess.run([sys.executable, "-m", *cmd.split()])
    if result.returncode != 0:
        # Best effort: report and continue; a missing hard dependency will
        # surface at the imports below.
        print(f"Warning: '{cmd}' exited with code {result.returncode}")

import torch  # noqa: E402  (must come after the installs above)
import gradio as gr  # noqa: E402

sys.path.append(".")

# Modify config and builder before importing SanaPipeline
modify_config()
modify_builder()

from Sana.app.sana_pipeline import SanaPipeline  # noqa: E402


def unload_model():
    """Drop the cached pipeline if it has been idle for ``TIMEOUT_SECONDS``.

    Runs on the ``unload_timer`` thread. The idle check guards against a
    timer that fired just as a new request re-armed it.

    Returns:
        A status string when the model was unloaded, otherwise None.
    """
    global global_model
    if last_use_time is None:
        return None
    if time.monotonic() - last_use_time < TIMEOUT_SECONDS:
        return None
    logger.info("Unloading model due to inactivity...")
    global_model = None
    torch.cuda.empty_cache()
    return "Model unloaded due to inactivity"


def reset_timer():
    """Record a use of the model and (re)arm the inactivity unload timer."""
    global unload_timer, last_use_time
    if unload_timer is not None:
        unload_timer.cancel()
    last_use_time = time.monotonic()
    unload_timer = Timer(TIMEOUT_SECONDS, unload_model)
    # Daemon so a pending unload never keeps the process alive at shutdown.
    unload_timer.daemon = True
    unload_timer.start()


def _get_model():
    """Return the cached SanaPipeline, loading it first if necessary.

    The pipeline is published to ``global_model`` only after its weights have
    loaded, so a failed ``from_pretrained`` is retried on the next request
    instead of leaving a weightless pipeline cached.
    """
    global global_model
    model = global_model
    if model is None:
        logger.info("Loading model...")
        model = SanaPipeline(CONFIG_PATH)
        model.from_pretrained(CHECKPOINT_URI)
        global_model = model
    return model


def _tensor_to_pil(sample):
    """Convert one generated sample tensor into a PIL image.

    ``sample`` must be channel-first (C, H, W), which the ``transpose(1, 2, 0)``
    below relies on. Its values are presumably in [-1, 1] (TODO confirm against
    SanaPipeline); anything outside maps past [0, 255] and is clamped.
    """
    pixels = ((sample + 1) / 2).float().cpu()
    pixels = (pixels * 255).clamp(0, 255).numpy().astype(np.uint8)
    return Image.fromarray(pixels.transpose(1, 2, 0))


@spaces.GPU(duration=110)
def generate_image(prompt, height, width, guidance_scale, pag_guidance_scale, num_inference_steps):
    """Generate one image for ``prompt`` and screen it with the NSFW detector.

    Args:
        prompt: Text prompt.
        height, width: Output size in pixels (cast to int; Gradio sliders may
            deliver floats).
        guidance_scale: Classifier-free guidance scale.
        pag_guidance_scale: PAG guidance scale.
        num_inference_steps: Number of sampling steps (cast to int).

    Returns:
        The generated PIL image if the detector reports ``"SAFE"``, otherwise
        the placeholder from ``create_error_image()``.

    Raises:
        gr.Error: Wrapping any exception raised while generating.
    """
    try:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Local reference: the unload timer may clear the global mid-request.
        model = _get_model()
        reset_timer()

        # Random seed; unlike a per-second timestamp, concurrent or rapid
        # requests don't collide. 63 bits fits torch's seed range.
        generator = torch.Generator(device=device).manual_seed(secrets.randbits(63))

        samples = model(
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=guidance_scale,
            pag_guidance_scale=pag_guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
        )
        image = _tensor_to_pil(samples[0])

        # Measure inactivity from the end of this request, not its start.
        reset_timer()

        # Check for NSFW content
        detector = NSFWDetector()
        _is_nsfw, category, confidence = detector.check_image(image)
        if category == "SAFE":
            return image
        logger.warning("NSFW content detected (%s with %.2f%% confidence)", category, confidence)
        return create_error_image()
    except Exception as e:
        logger.exception("Error in generate_image")
        raise gr.Error(f"Generation failed: {e}") from e


# Gradio Interface
with gr.Blocks(theme=gr.themes.Default(), css=""".center-text {text-align: center;} .footer-link {text-align: center; margin: 20px 0;} .slider-pad {margin-bottom: 24px;}""") as interface:
    with gr.Row(elem_id="banner"):
        with gr.Column():
            gr.Markdown("# Sana 1.6B", elem_classes="center-text")
            gr.Markdown("Generate high-resolution images up to 4096x4096 using the Sana 1.6B model, fast.", elem_classes="center-text")

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
            with gr.Row():
                with gr.Column():
                    height = gr.Slider(minimum=512, maximum=4096, step=64, value=1024, label="Height")
                    width = gr.Slider(minimum=512, maximum=4096, step=64, value=1024, label="Width")
                with gr.Column():
                    guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=5.0, label="Guidance Scale")
                    pag_guidance_scale = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=2.0, label="PAG Guidance Scale")
                    num_inference_steps = gr.Slider(minimum=2, maximum=50, step=1, value=18, label="Number of Steps")
            gr.Markdown("*Note: Higher guidance scales provide stronger adherence to the prompt. PAG guidance helps with image-text alignment.*")
            gr.Markdown("⏱️ Be patient, the model loads into memory slow first time around.")
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=2):
            output = gr.Image(label="Generated Image", height=512)

    # Examples section
    gr.Examples(
        examples=[
            ["a cyberpunk cat with a neon sign that says 'Sana'", 1024, 1024, 5.0, 2.0, 18],
            ["a beautiful sunset over a mountain landscape", 1024, 1024, 5.0, 2.0, 18],
            ["a futuristic city with flying cars", 1024, 1024, 5.0, 2.0, 18],
        ],
        inputs=[prompt, height, width, guidance_scale, pag_guidance_scale, num_inference_steps],
        outputs=output,
        fn=generate_image,
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, height, width, guidance_scale, pag_guidance_scale, num_inference_steps],
        outputs=output,
    )

    gr.Markdown("[link to model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px)", elem_classes="center-text footer-link")

# Launch the interface
interface.launch()