ovi054 commited on
Commit
c028ee2
·
verified ·
1 Parent(s): 21b7cc6

Create demo.py

Browse files
Files changed (1) hide show
  1. demo.py +199 -0
demo.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ # import spaces
5
+ import torch
6
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
7
+ # from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
8
+ from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
9
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
+ import multiprocessing as mp
11
+
12
+ import os
13
+ import requests
14
+ import tempfile
15
+ import shutil
16
+ from urllib.parse import urlparse
17
+
18
+ dtype = torch.bfloat16
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ #black-forest-labs/FLUX.1-Krea-dev
21
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
22
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
23
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
24
+ # srpo_128_base_oficial_model_fp16.safetensors
25
+ # pipe.load_lora_weights('Alissonerdx/flux.1-dev-SRPO-LoRas', weight_name='srpo_16_base_oficial_model_fp16.safetensors')
26
+ # pipe.fuse_lora()
27
+ torch.cuda.empty_cache()
28
+
29
+ MAX_SEED = np.iinfo(np.int32).max
30
+ MAX_IMAGE_SIZE = 2048
31
+
32
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
33
+
34
+ def load_lora_auto(pipe, lora_input):
35
+ lora_input = lora_input.strip()
36
+ if not lora_input:
37
+ return
38
+
39
+ # If it's just an ID like "author/model"
40
+ if "/" in lora_input and not lora_input.startswith("http"):
41
+ pipe.load_lora_weights(lora_input)
42
+ return
43
+
44
+ if lora_input.startswith("http"):
45
+ url = lora_input
46
+
47
+ # Repo page (no blob/resolve)
48
+ if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
49
+ repo_id = urlparse(url).path.strip("/")
50
+ pipe.load_lora_weights(repo_id)
51
+ return
52
+
53
+ # Blob link → convert to resolve link
54
+ if "/blob/" in url:
55
+ url = url.replace("/blob/", "/resolve/")
56
+
57
+ # Download direct file
58
+ tmp_dir = tempfile.mkdtemp()
59
+ local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
60
+
61
+ try:
62
+ print(f"Downloading LoRA from {url}...")
63
+ resp = requests.get(url, stream=True)
64
+ resp.raise_for_status()
65
+ with open(local_path, "wb") as f:
66
+ for chunk in resp.iter_content(chunk_size=8192):
67
+ f.write(chunk)
68
+ print(f"Saved LoRA to {local_path}")
69
+ pipe.load_lora_weights(local_path)
70
+ finally:
71
+ shutil.rmtree(tmp_dir, ignore_errors=True)
72
+
73
+ # @spaces.GPU(duration=25)
74
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
75
+ if randomize_seed:
76
+ seed = random.randint(0, MAX_SEED)
77
+ generator = torch.Generator().manual_seed(seed)
78
+
79
+
80
+ # for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
81
+ # prompt=prompt,
82
+ # guidance_scale=guidance_scale,
83
+ # num_inference_steps=num_inference_steps,
84
+ # width=width,
85
+ # height=height,
86
+ # generator=generator,
87
+ # output_type="pil",
88
+ # good_vae=good_vae,
89
+ # ):
90
+ # yield img, seed
91
+
92
+ # Handle LoRA loading
93
+ # Load LoRA weights and prepare joint_attention_kwargs
94
+ if lora_id and lora_id.strip() != "":
95
+ pipe.unload_lora_weights()
96
+ # pipe.load_lora_weights(lora_id.strip())
97
+ load_lora_auto(pipe, lora_id.strip())
98
+ joint_attention_kwargs = {"scale": lora_scale}
99
+ else:
100
+ joint_attention_kwargs = None
101
+
102
+ # apply_cache_on_pipe(
103
+ # pipe,
104
+ # # residual_diff_threshold=0.2,
105
+ # )
106
+
107
+ try:
108
+ # Call the custom pipeline function with the correct keyword argument
109
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
110
+ prompt=prompt,
111
+ guidance_scale=guidance_scale,
112
+ num_inference_steps=num_inference_steps,
113
+ width=width,
114
+ height=height,
115
+ generator=generator,
116
+ output_type="pil",
117
+ good_vae=good_vae, # Assuming good_vae is defined elsewhere
118
+ joint_attention_kwargs=joint_attention_kwargs, # Fixed parameter name
119
+ ):
120
+ yield img, seed
121
+ finally:
122
+ # Unload LoRA weights if they were loaded
123
+ if lora_id:
124
+ pipe.unload_lora_weights()
125
+
126
+ examples = [
127
+ "a tiny astronaut hatching from an egg on the moon",
128
+ "a cat holding a sign that says hello world",
129
+ "an anime illustration of a wiener schnitzel",
130
+ ]
131
+
132
+ css = """
133
+ #col-container {
134
+ margin: 0 auto;
135
+ max-width: 960px;
136
+ }
137
+ .generate-btn {
138
+ background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
139
+ border: none !important;
140
+ color: white !important;
141
+ }
142
+ .generate-btn:hover {
143
+ transform: translateY(-2px);
144
+ box-shadow: 0 5px 15px rgba(0,0,0,0.2);
145
+ }
146
+ """
147
+
148
+ with gr.Blocks(css=css) as app:
149
+ gr.HTML("<center><h1>FLUX.1-Dev with LoRA support</h1></center>")
150
+ with gr.Column(elem_id="col-container"):
151
+ with gr.Row():
152
+ with gr.Column():
153
+ with gr.Row():
154
+ text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=3, elem_id="prompt-text-input")
155
+ with gr.Row():
156
+ custom_lora = gr.Textbox(label="Custom LoRA (optional)", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
157
+ with gr.Row():
158
+ with gr.Accordion("Advanced Settings", open=False):
159
+ lora_scale = gr.Slider(
160
+ label="LoRA Scale",
161
+ minimum=0,
162
+ maximum=2,
163
+ step=0.01,
164
+ value=0.95,
165
+ )
166
+ with gr.Row():
167
+ width = gr.Slider(label="Width", value=1024, minimum=64, maximum=2048, step=8)
168
+ height = gr.Slider(label="Height", value=1024, minimum=64, maximum=2048, step=8)
169
+ seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=4294967296, step=1)
170
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
171
+ with gr.Row():
172
+ steps = gr.Slider(label="Inference steps steps", value=28, minimum=1, maximum=100, step=1)
173
+ cfg = gr.Slider(label="Guidance Scale", value=3.5, minimum=1, maximum=20, step=0.5)
174
+ # method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])
175
+
176
+ with gr.Row():
177
+ # text_button = gr.Button("Run", variant='primary', elem_id="gen-button")
178
+ text_button = gr.Button("✨ Generate Image", variant='primary', elem_classes=["generate-btn"])
179
+ with gr.Column():
180
+ with gr.Row():
181
+ image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
182
+
183
+ # gr.Markdown(article_text)
184
+ with gr.Column():
185
+ gr.Examples(
186
+ examples = examples,
187
+ inputs = [text_prompt],
188
+ )
189
+ gr.on(
190
+ triggers=[text_button.click, text_prompt.submit],
191
+ fn = infer,
192
+ inputs=[text_prompt, seed, randomize_seed, width, height, cfg, steps, custom_lora, lora_scale],
193
+ outputs=[image_output, seed]
194
+ )
195
+
196
+ # text_button.click(query, inputs=[custom_lora, text_prompt, steps, cfg, randomize_seed, seed, width, height], outputs=[image_output,seed_output, seed])
197
+ # text_button.click(infer, inputs=[text_prompt, seed, randomize_seed, width, height, cfg, steps, custom_lora, lora_scale], outputs=[image_output,seed_output, seed])
198
+
199
+ app.launch(share=True)