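"""Hugging Face Space: text-to-image generation with FLUX.1-dev plus a personal LoRA.

The script loads generation defaults from config.yaml, caches generated images on disk
(image_cache.pkl) keyed by the full set of generation parameters, and exposes a small
Gradio UI. Example prompts are pre-generated at startup so their cached results return
immediately when selected.
"""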
import os
import random
import gradio as gr
from huggingface_hub import login, hf_hub_download
import spaces
import torch
from diffusers import DiffusionPipeline
import hashlib
import pickle
import yaml
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Load config file
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Authenticate using the token stored in Hugging Face Spaces secrets
if 'HF_TOKEN' in os.environ:
    login(token=os.environ['HF_TOKEN'])
    logging.info("Successfully logged in with HF_TOKEN")
else:
    logging.warning("HF_TOKEN not found in environment variables. Some functionality may be limited.")
# Correctly access the config values
process_config = config['config']['process'][0] # Assuming the first process is the one we want
base_model = "black-forest-labs/FLUX.1-dev"
lora_model = "sagar007/sagar_flux" # This isn't in the config, so we're keeping it as is
trigger_word = process_config['trigger_word']
logging.info(f"Base model: {base_model}")
logging.info(f"LoRA model: {lora_model}")
logging.info(f"Trigger word: {trigger_word}")
# Global variables
pipe = None
cache = {}
CACHE_FILE = "image_cache.pkl"
# Example prompts
example_prompts = [
    "Photos of sagar as superman flying in the sky, cape billowing in the wind, sagar",
    "Professional photo of sagar for LinkedIn headshot, DSLR quality, neutral background, sagar",
    "Sagar as an astronaut exploring a distant alien planet, vibrant colors, sagar",
    "Sagar hiking in a lush green forest, sunlight filtering through the trees, sagar",
    "Sagar as a wizard casting a spell, magical energy swirling around, sagar",
    "Sagar scoring a goal in a dramatic soccer match, stadium lights shining, sagar",
    "Sagar as a Roman emperor addressing a crowd, wearing a toga and laurel wreath, sagar"
]
def initialize_model():
    global pipe
    if pipe is None:
        try:
            logging.info(f"Attempting to load model: {base_model}")
            pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16, use_safetensors=True)
            # Attach the LoRA adapter so the trigger word actually affects generation
            logging.info(f"Loading LoRA weights: {lora_model}")
            pipe.load_lora_weights(lora_model)
            logging.info("Moving model to CUDA...")
            pipe = pipe.to("cuda")
            logging.info(f"Successfully loaded model: {base_model}")
        except Exception as e:
            logging.error(f"Error loading model {base_model}: {str(e)}")
            raise
def load_cache():
    global cache
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'rb') as f:
            cache = pickle.load(f)
        logging.info(f"Loaded {len(cache)} cached images")
def save_cache():
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(cache, f)
    logging.info(f"Saved {len(cache)} cached images")
def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
    return hashlib.md5(f"{prompt}{cfg_scale}{steps}{seed}{width}{height}{lora_scale}".encode()).hexdigest()
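# Example (parameter values are illustrative): get_cache_key("sagar as superman", 3.5, 28, 42, 1024, 1024, 0.75)
# returns a deterministic 32-character hex digest used as the key into the cache dict.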
@spaces.GPU(duration=80)
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    global pipe, cache
    if randomize_seed:
        seed = random.randint(0, 2**32 - 1)
    cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)
    if cache_key in cache:
        logging.info("Using cached image")
        return cache[cache_key], seed
    try:
        logging.info(f"Starting run_lora with prompt: {prompt}")
        if pipe is None:
            logging.info("Initializing model...")
            initialize_model()
        logging.info(f"Using seed: {seed}")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        full_prompt = f"{prompt} {trigger_word}"
        logging.info(f"Full prompt: {full_prompt}")
        logging.info("Starting image generation...")
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            # Apply the LoRA scale from the UI; the Flux pipeline reads it from joint_attention_kwargs
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        logging.info("Image generation completed successfully")
        # Cache the generated image
        cache[cache_key] = image
        save_cache()
        return image, seed
    except Exception as e:
        logging.error(f"Error during generation: {str(e)}")
        import traceback
        logging.error(traceback.format_exc())
        return None, seed
def update_prompt(example):
    return example
# Load cache at startup
load_cache()
# Pre-generate and cache example images
def cache_example_images():
    for prompt in example_prompts:
        run_lora(prompt, process_config['sample']['guidance_scale'], process_config['sample']['sample_steps'],
                 process_config['sample']['walk_seed'], process_config['sample']['seed'],
                 process_config['sample']['width'], process_config['sample']['height'], 0.75)
# Gradio interface setup
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with FLUX (ZeroGPU)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            example_dropdown = gr.Dropdown(choices=example_prompts, label="Example Prompts")
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=20, value=process_config['sample']['guidance_scale'], step=0.1, label="CFG Scale")
        steps = gr.Slider(minimum=1, maximum=100, value=process_config['sample']['sample_steps'], step=1, label="Steps")
    with gr.Row():
        width = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['width'], step=64, label="Width")
        height = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['height'], step=64, label="Height")
    with gr.Row():
        seed = gr.Number(label="Seed", value=process_config['sample']['seed'], precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=process_config['sample']['walk_seed'])
        lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")

    example_dropdown.change(update_prompt, inputs=[example_dropdown], outputs=[prompt])
    run_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
# Launch the app
if __name__ == "__main__":
    logging.info("Starting the Gradio app...")
    logging.info("Pre-generating example images...")
    cache_example_images()
    app.launch(share=True)
    logging.info("Gradio app launched successfully")