AlekseyCalvin commited on
Commit
0e9fbd7
·
verified ·
1 Parent(s): 26b76a9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ import logging
5
+ import torch
6
+ from PIL import Image
7
+ import spaces
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
9
+ import copy
10
+ import random
11
+ import time
12
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
13
+ from huggingface_hub import HfFileSystem, ModelCard
14
+ from huggingface_hub import login, hf_hub_download
15
+ import safetensors.torch
16
+ from safetensors.torch import load_file
17
+ hf_token = os.environ.get("HF_TOKEN")
18
+ login(token=hf_token)
19
+
20
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
21
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
22
+ os.environ["HF_HUB_CACHE"] = cache_path
23
+ os.environ["HF_HOME"] = cache_path
24
+
25
+ torch.set_float32_matmul_precision("high")
26
+
27
+ # Load LoRAs from JSON file
28
+ with open('loras.json', 'r') as f:
29
+ loras = json.load(f)
30
+
31
+ # Initialize the base model
32
+ dtype = torch.bfloat16
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
36
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
37
+
38
+ pipe = DiffusionPipeline.from_pretrained(
39
+ "jimmycarter/LibreFLUX",
40
+ custom_pipeline="jimmycarter/LibreFLUX",
41
+ use_safetensors=True,
42
+ torch_dtype=torch.bfloat16,
43
+ trust_remote_code=True,
44
+ ).to(device)
45
+
46
+ MAX_SEED = 2**32-1
47
+
48
+ class calculateDuration:
49
+ def __init__(self, activity_name=""):
50
+ self.activity_name = activity_name
51
+
52
+ def __enter__(self):
53
+ self.start_time = time.time()
54
+ return self
55
+
56
+ def __exit__(self, exc_type, exc_value, traceback):
57
+ self.end_time = time.time()
58
+ self.elapsed_time = self.end_time - self.start_time
59
+ if self.activity_name:
60
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
61
+ else:
62
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
63
+
64
+
65
+ def update_selection(evt: gr.SelectData, width, height):
66
+ selected_lora = loras[evt.index]
67
+ new_placeholder = f"Type a prompt for {selected_lora['title']}"
68
+ lora_repo = selected_lora["repo"]
69
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
70
+ if "aspect" in selected_lora:
71
+ if selected_lora["aspect"] == "portrait":
72
+ width = 768
73
+ height = 1024
74
+ elif selected_lora["aspect"] == "landscape":
75
+ width = 1024
76
+ height = 768
77
+ return (
78
+ gr.update(placeholder=new_placeholder),
79
+ updated_text,
80
+ evt.index,
81
+ width,
82
+ height,
83
+ )
84
+
85
+ @spaces.GPU(duration=70)
86
+ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
87
+ pipe.to("cuda")
88
+ generator = torch.Generator(device="cuda").manual_seed(seed)
89
+
90
+ with calculateDuration("Generating image"):
91
+ # Generate image
92
+ image = pipe(
93
+ prompt=f"{prompt} {trigger_word}",
94
+ num_inference_steps=steps,
95
+ guidance_scale=cfg_scale,
96
+ width=width,
97
+ height=height,
98
+ generator=generator,
99
+ joint_attention_kwargs={"scale": lora_scale},
100
+ ).images[0]
101
+ return image
102
+
103
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
104
+ if selected_index is None:
105
+ raise gr.Error("You must select a LoRA before proceeding.")
106
+
107
+ selected_lora = loras[selected_index]
108
+ lora_path = selected_lora["repo"]
109
+ trigger_word = selected_lora["trigger_word"]
110
+ if(trigger_word):
111
+ if "trigger_position" in selected_lora:
112
+ if selected_lora["trigger_position"] == "prepend":
113
+ prompt_mash = f"{trigger_word} {prompt}"
114
+ else:
115
+ prompt_mash = f"{prompt} {trigger_word}"
116
+ else:
117
+ prompt_mash = f"{trigger_word} {prompt}"
118
+ else:
119
+ prompt_mash = prompt
120
+
121
+ # Load LoRA weights
122
+ with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
123
+ if "weights" in selected_lora:
124
+ pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
125
+ else:
126
+ pipe.load_lora_weights(lora_path)
127
+
128
+ # Set random seed for reproducibility
129
+ with calculateDuration("Randomizing seed"):
130
+ if randomize_seed:
131
+ seed = random.randint(0, MAX_SEED)
132
+
133
+ image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
134
+ pipe.to("cpu")
135
+ pipe.unload_lora_weights()
136
+ return image, seed
137
+
138
+ run_lora.zerogpu = True
139
+
140
+ css = '''
141
+ #gen_btn{height: 100%}
142
+ #title{text-align: center}
143
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
144
+ #title img{width: 100px; margin-right: 0.5em}
145
+ #gallery .grid-wrap{height: 10vh}
146
+ '''
147
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
148
+ title = gr.HTML(
149
+ """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> SOONfactory </h1>""",
150
+ elem_id="title",
151
+ )
152
+ # Info blob stating what the app is running
153
+ info_blob = gr.HTML(
154
+ """<div id="info_blob"> Novorealist LoRa-stocked Birthweek-inspired Img Manufactory for Dunova, Dunovas, & Dunovaists!</div>"""
155
+ )
156
+
157
+ # Info blob stating what the app is running
158
+ info_blob = gr.HTML(
159
+ """<div id="info_blob">Trigger LoRAs by Pre-phrasing Prompts w/: 1-5. ADU person (/'ADU woman') photo |6-15. HST style |16. how2draw |17-20.HST |21. HST Austin Osman Spare style |22. RCA |23. propaganda poster |24. SOTS art |25. pficonics |26. wh3r3sw4ld0 |27. vintage cover |28. crisp photo |29. retrofuturism |30. Film Photo |31. TOK hybrid |32. 2004 photo |33. Unexpected photo |34. flmft |35. TOK portra |36. Yearbook photo |37. Akhmatova |38. Tsvetaeva |39. Blok |40. LEN Lenin |41. Trotsky |42. Rosa Luxemburg </div>"""
160
+ )
161
+ selected_index = gr.State(None)
162
+ with gr.Row():
163
+ with gr.Column(scale=3):
164
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
165
+ with gr.Column(scale=1, elem_id="gen_column"):
166
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
167
+ with gr.Row():
168
+ with gr.Column(scale=3):
169
+ selected_info = gr.Markdown("")
170
+ gallery = gr.Gallery(
171
+ [(item["image"], item["title"]) for item in loras],
172
+ label="LoRA Inventory",
173
+ allow_preview=False,
174
+ columns=3,
175
+ elem_id="gallery"
176
+ )
177
+
178
+ with gr.Column(scale=4):
179
+ result = gr.Image(label="Generated Image")
180
+
181
+ with gr.Row():
182
+ with gr.Accordion("Advanced Settings", open=True):
183
+ with gr.Column():
184
+ with gr.Row():
185
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=20, step=0.5, value=3.0)
186
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=12)
187
+
188
+ with gr.Row():
189
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
190
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1088)
191
+
192
+ with gr.Row():
193
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
194
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
195
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.0, step=0.01, value=1.05)
196
+
197
+ gallery.select(
198
+ update_selection,
199
+ inputs=[width, height],
200
+ outputs=[prompt, selected_info, selected_index, width, height]
201
+ )
202
+
203
+ gr.on(
204
+ triggers=[generate_button.click, prompt.submit],
205
+ fn=run_lora,
206
+ inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
207
+ outputs=[result, seed]
208
+ )
209
+
210
+ app.queue(default_concurrency_limit=2).launch(show_error=True)
211
+ app.launch()