File size: 6,686 Bytes
e3164da
f0e2bd1
e3164da
7650250
3e2182a
de8afe4
a021913
ad1f99d
 
6d39491
3cabadc
 
 
 
6d39491
 
 
 
9234004
 
 
 
3cabadc
9234004
3cabadc
f0e2bd1
74f4c1e
1d7ab4c
9452c41
3cabadc
6d39491
1d7ab4c
1042ff4
3cabadc
 
 
 
ad1f99d
fe5ff04
ad1f99d
 
fe5ff04
67e5720
 
 
 
 
 
 
 
 
 
 
3e2182a
6a499ee
3e2182a
 
3cabadc
 
 
3e2182a
3cabadc
3e2182a
3cabadc
3e2182a
f0e2bd1
ad1f99d
 
 
 
 
3cabadc
ad1f99d
 
 
 
3cabadc
ad1f99d
 
 
fe5ff04
5fd12b3
3e2182a
ad1f99d
6d39491
ad1f99d
 
6d39491
ad1f99d
6d39491
ad1f99d
3cabadc
ad1f99d
6d39491
7650250
3cabadc
de8afe4
3cabadc
de8afe4
6d39491
3cabadc
6d39491
de8afe4
6d39491
6a499ee
3cabadc
6d39491
3cabadc
7650250
6a499ee
7650250
 
 
 
 
 
3cabadc
6d39491
ad1f99d
 
 
6d39491
7650250
 
3cabadc
de8afe4
3cabadc
7650250
6a499ee
67e5720
 
 
ad1f99d
 
 
 
 
 
1d7ab4c
 
 
ad1f99d
1042ff4
 
af3ee62
1042ff4
 
 
67e5720
1042ff4
 
 
 
1d7ab4c
 
1042ff4
1d7ab4c
 
1042ff4
1d7ab4c
 
1042ff4
6d39491
67e5720
6d39491
1042ff4
 
 
 
 
c83e1ca
9234004
de8afe4
3cabadc
 
ad1f99d
a021913
3cabadc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import random
import gradio as gr
from huggingface_hub import login, hf_hub_download
import spaces
import torch
from diffusers import DiffusionPipeline
import hashlib
import pickle
import yaml
import logging

# Configure the root logger: INFO level with timestamped, levelled messages
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Parse the YAML configuration (path is relative to the working directory)
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Log in to the Hugging Face Hub when a token is available in the
# environment; otherwise carry on unauthenticated with a warning
if 'HF_TOKEN' not in os.environ:
    logging.warning("HF_TOKEN not found in environment variables. Some functionality may be limited.")
else:
    login(token=os.environ['HF_TOKEN'])
    logging.info("Successfully logged in with HF_TOKEN")

# Only the first entry of the config's process list is used
process_config = config['config']['process'][0]

# Model identifiers; the LoRA repository is hard-coded because the config
# does not provide it
base_model = "black-forest-labs/FLUX.1-dev"
lora_model = "sagar007/sagar_flux"
trigger_word = process_config['trigger_word']

logging.info(f"Base model: {base_model}")
logging.info(f"LoRA model: {lora_model}")
logging.info(f"Trigger word: {trigger_word}")

# Module-level state shared by the functions below
pipe = None                      # set by initialize_model()
cache = {}                       # cache key -> generated image
CACHE_FILE = "image_cache.pkl"   # on-disk pickle of `cache`

# Prompts offered in the "Example Prompts" dropdown
example_prompts = [
    "Photos of sagar as superman flying in the sky, cape billowing in the wind, sagar",
    "Professional photo of sagar for LinkedIn headshot, DSLR quality, neutral background, sagar",
    "Sagar as an astronaut exploring a distant alien planet, vibrant colors, sagar",
    "Sagar hiking in a lush green forest, sunlight filtering through the trees, sagar",
    "Sagar as a wizard casting a spell, magical energy swirling around, sagar",
    "Sagar scoring a goal in a dramatic soccer match, stadium lights shining, sagar",
    "Sagar as a Roman emperor addressing a crowd, wearing a toga and laurel wreath, sagar",
]

def initialize_model():
    """Load the base pipeline plus the LoRA weights onto the GPU, once.

    Does nothing when the global ``pipe`` is already set. The pipeline is
    built in a local variable and only published to ``pipe`` after every
    step succeeded, so a failure (e.g. while moving to CUDA) never leaves a
    half-initialized pipeline behind that later calls would mistake for a
    ready one.

    Raises:
        Exception: any error raised while loading the base model, loading
            the LoRA weights, or moving the pipeline to CUDA. The error is
            logged and re-raised unchanged.
    """
    global pipe
    if pipe is not None:
        return

    try:
        logging.info(f"Attempting to load model: {base_model}")
        loaded = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16, use_safetensors=True)
        # The LoRA adapter was declared and logged at module level but never
        # attached, so generation silently used the plain base model.
        logging.info(f"Loading LoRA weights: {lora_model}")
        loaded.load_lora_weights(lora_model)
        logging.info("Moving model to CUDA...")
        loaded = loaded.to("cuda")
    except Exception as e:
        logging.error(f"Error loading model {base_model}: {str(e)}")
        raise

    pipe = loaded
    logging.info(f"Successfully loaded model: {base_model}")

def load_cache():
    """Populate the global ``cache`` from ``CACHE_FILE`` when it exists.

    A missing file leaves ``cache`` untouched. An unreadable, truncated or
    otherwise invalid file (or one whose payload is not a dict) is reported
    with a warning and replaced by an empty cache, so a damaged cache file
    cannot stop the app from starting.
    """
    global cache
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, 'rb') as f:
                # SECURITY NOTE: unpickling can execute arbitrary code.
                # CACHE_FILE must only ever be produced by save_cache().
                loaded = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, IndexError) as e:
            logging.warning(f"Could not read cache file {CACHE_FILE}; starting with an empty cache: {e}")
            loaded = {}

        if isinstance(loaded, dict):
            cache = loaded
        else:
            logging.warning(f"Cache file {CACHE_FILE} does not contain a dict; starting with an empty cache")
            cache = {}
    logging.info(f"Loaded {len(cache)} cached images")

def save_cache():
    """Persist the global ``cache`` to ``CACHE_FILE``.

    The pickle is written to a temporary sibling file and then moved over
    ``CACHE_FILE`` with ``os.replace``. Writing in place would truncate the
    existing file first, so an interrupted or failing dump would leave a
    broken cache file on disk.

    Raises:
        Exception: any error raised while pickling or writing; the temporary
            file is removed before the error is re-raised.
    """
    tmp_path = f"{CACHE_FILE}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(cache, f)
        os.replace(tmp_path, CACHE_FILE)
    except Exception:
        # Do not leave a partial temp file behind; the previous CACHE_FILE
        # (if any) is still intact at this point.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    logging.info(f"Saved {len(cache)} cached images")

def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
    """Return a hex MD5 digest identifying one set of generation settings.

    The settings are hashed as the ``repr`` of a tuple rather than glued
    together as text. Plain concatenation has no field boundaries, so e.g.
    ``steps=1, seed=23`` and ``steps=12, seed=3`` would produce the same key
    and return the wrong cached image.

    Note: keys differ from those produced by plain concatenation, so entries
    persisted under the old scheme are simply never hit again.

    Returns:
        str: 32-character lowercase hexadecimal digest.
    """
    settings = (prompt, cfg_scale, steps, seed, width, height, lora_scale)
    # MD5 is used only as a compact cache fingerprint, not for security.
    return hashlib.md5(repr(settings).encode()).hexdigest()

@spaces.GPU(duration=80)
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    """Generate (or fetch from cache) an image for the given settings.

    Args:
        prompt: User prompt; the configured trigger word is appended to it.
        cfg_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        seed: Seed for the CUDA generator. ``None`` (e.g. an emptied seed
            field) is treated like ``randomize_seed``.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        lora_scale: LoRA strength, forwarded to the pipeline through
            ``joint_attention_kwargs``.

    Returns:
        tuple: ``(image, seed)`` where ``seed`` is the seed actually used.
        ``image`` is ``None`` when generation failed; the error is logged
        rather than raised.
    """
    # A missing seed would make manual_seed() fail, so fall back to a random one.
    if randomize_seed or seed is None:
        seed = random.randint(0, 2**32-1)
    seed = int(seed)

    cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)

    if cache_key in cache:
        logging.info("Using cached image")
        return cache[cache_key], seed

    try:
        logging.info(f"Starting run_lora with prompt: {prompt}")
        if pipe is None:
            logging.info("Initializing model...")
            initialize_model()

        logging.info(f"Using seed: {seed}")

        generator = torch.Generator(device="cuda").manual_seed(seed)

        full_prompt = f"{prompt} {trigger_word}"
        logging.info(f"Full prompt: {full_prompt}")

        logging.info("Starting image generation...")
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            # lora_scale was part of the cache key but never reached the
            # pipeline, so the slider had no effect on the output.
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        logging.info("Image generation completed successfully")
    except Exception as e:
        # logging.exception records the message together with the traceback
        logging.exception(f"Error during generation: {str(e)}")
        return None, seed

    # Cache the generated image. Persisting is best-effort: a disk error must
    # not throw away an image that was generated successfully.
    cache[cache_key] = image
    try:
        save_cache()
    except Exception as e:
        logging.exception(f"Failed to persist image cache: {str(e)}")

    return image, seed

def update_prompt(example):
    """Return the selected example unchanged, for use as the prompt text."""
    selected = example
    return selected

# Restore previously persisted images as soon as the module is loaded
load_cache()

# Pre-generate and cache example images
def cache_example_images():
    """Run every example prompt through ``run_lora`` with the config's sample settings.

    The config's ``walk_seed`` value is passed as ``randomize_seed`` and the
    LoRA scale is fixed at 0.75. Return values are discarded.
    """
    for prompt in example_prompts:
        sample = process_config['sample']
        run_lora(
            prompt,
            sample['guidance_scale'],
            sample['sample_steps'],
            sample['walk_seed'],
            sample['seed'],
            sample['width'],
            sample['height'],
            0.75,
        )

# Gradio interface setup
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with FLUX (ZeroGPU)")

    # First row: prompt controls in one column, generated image in the other
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            example_dropdown = gr.Dropdown(label="Example Prompts", choices=example_prompts)
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")

    # Sampling controls; defaults come from the config's sample section
    with gr.Row():
        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.1, value=process_config['sample']['guidance_scale'])
        steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=process_config['sample']['sample_steps'])

    # Output size controls
    with gr.Row():
        width = gr.Slider(label="Width", minimum=128, maximum=1024, step=64, value=process_config['sample']['width'])
        height = gr.Slider(label="Height", minimum=128, maximum=1024, step=64, value=process_config['sample']['height'])

    # Seed controls
    with gr.Row():
        seed = gr.Number(label="Seed", precision=0, value=process_config['sample']['seed'])
        randomize_seed = gr.Checkbox(label="Randomize seed", value=process_config['sample']['walk_seed'])

    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.75)

    # Choosing an example copies it into the prompt box
    example_dropdown.change(fn=update_prompt, inputs=[example_dropdown], outputs=[prompt])

    # "Generate" runs the pipeline and writes back the image and the seed used
    run_button.click(
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed],
    )

# Launch the app
if __name__ == "__main__":
    logging.info("Starting the Gradio app...")
    logging.info("Pre-generating example images...")
    cache_example_images()
    # launch() blocks the main thread while the server is running, so the
    # line after it only executes once the server has stopped. Logging
    # "launched successfully" there was misleading.
    app.launch(share=True)
    logging.info("Gradio app has shut down")