jiuface commited on
Commit
72bb838
·
verified ·
1 Parent(s): 403197b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional
2
+ import gradio as gr
3
+ import numpy as np
4
+ import spaces
5
+ import torch
6
+ import random
7
+ from PIL import Image
8
+ import json
9
+ import boto3
10
+ from io import BytesIO
11
+ from datetime import datetime
12
+ from huggingface_hub import login
13
+ import os
14
+
15
+ from diffusers import FluxKontextPipeline
16
+ from diffusers.utils import load_image
17
+ from diffusers.utils import load_image, make_image_grid
18
+
19
+ HF_TOKEN = os.environ.get("HF_TOKEN")
20
+ login(token=HF_TOKEN)
21
+
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+
24
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
25
+
26
+ @spaces.GPU
27
+ def infer(
28
+ input_image: Image.Image,
29
+ prompt: str,
30
+ seed: int,
31
+ randomize_seed: bool,
32
+ guidance_scale: float,
33
+ steps: int,
34
+ progress
35
+ ):
36
+ if randomize_seed:
37
+ seed = random.randint(0, MAX_SEED)
38
+ if input_image:
39
+ input_image = input_image.convert("RGB")
40
+ image = pipe(
41
+ image=input_image,
42
+ prompt=prompt,
43
+ guidance_scale=guidance_scale,
44
+ width = input_image.size[0],
45
+ height = input_image.size[1],
46
+ num_inference_steps=steps,
47
+ generator=torch.Generator().manual_seed(seed),
48
+ ).images[0]
49
+ else:
50
+ image = pipe(
51
+ prompt=prompt,
52
+ guidance_scale=guidance_scale,
53
+ num_inference_steps=steps,
54
+ generator=torch.Generator().manual_seed(seed),
55
+ ).images[0]
56
+ return image
57
+
58
+
59
+ def process(
60
+ image_url: str,
61
+ prompt: str,
62
+ seed:int,
63
+ randomize_seed:bool,
64
+ guidance_scale:float,
65
+ steps:int,
66
+ upload_to_r2:bool,
67
+ account_id: str,
68
+ access_key: str,
69
+ secret_key: str,
70
+ bucket:str,
71
+ progress=gr.Progress(track_tqdm=True)
72
+ ):
73
+ result = {"status": "false", "message": ""}
74
+ input_image = load_image(image_url)
75
+ if not isinstance(input_image, Image.Image):
76
+ result["status"] = "fail"
77
+ result["message"] = "Invalid input image url"
78
+ return [], json.dumps(result)
79
+
80
+ try:
81
+ generated_image = infer(input_image, prompt, seed, randomize_seed, guidance_scale, steps, progress)
82
+ except Exception as e:
83
+ result["status"] = "faield"
84
+ result["message"] = "generate image failed"
85
+ generated_image = None
86
+
87
+ if generated_image:
88
+ if upload_to_r2:
89
+ url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
90
+ result = {"status": "success", "message": "upload image success", "url": url}
91
+ else:
92
+ result = {"status": "success", "message": "Image generated but not uploaded"}
93
+ final_images = []
94
+
95
+ if isinstance(input_image, Image.Image):
96
+ final_images.append(input_image)
97
+ if isinstance(generated_image, Image.Image):
98
+ final_images.append(generated_image)
99
+
100
+ progress(100, "finish!")
101
+ return final_images, json.dumps(result)
102
+
103
+
104
+ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
105
+ with calculateDuration("Upload image"):
106
+ print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
107
+ connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
108
+ s3 = boto3.client(
109
+ 's3',
110
+ endpoint_url=connectionUrl,
111
+ region_name='auto',
112
+ aws_access_key_id=access_key,
113
+ aws_secret_access_key=secret_key
114
+ )
115
+ current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
116
+ image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
117
+ buffer = BytesIO()
118
+ image.save(buffer, "PNG")
119
+ buffer.seek(0)
120
+ s3.upload_fileobj(buffer, bucket_name, image_file)
121
+ print("upload finish", image_file)
122
+
123
+ # start to generate thumbnail
124
+ thumbnail = image.copy()
125
+ thumbnail_width = 256
126
+ aspect_ratio = image.height / image.width
127
+ thumbnail_height = int(thumbnail_width * aspect_ratio)
128
+ thumbnail = thumbnail.resize((thumbnail_width, thumbnail_height), Image.LANCZOS)
129
+
130
+ # Generate the thumbnail image filename
131
+ thumbnail_file = image_file.replace(".png", "_thumbnail.png")
132
+
133
+ # Save thumbnail to buffer and upload
134
+ thumbnail_buffer = BytesIO()
135
+ thumbnail.save(thumbnail_buffer, "PNG")
136
+ thumbnail_buffer.seek(0)
137
+ s3.upload_fileobj(thumbnail_buffer, bucket_name, thumbnail_file)
138
+ print("upload thumbnail finish", thumbnail_file)
139
+
140
+ return image_file
141
+
142
+
143
+ with gr.Blocks() as demo:
144
+
145
+ with gr.Column():
146
+ gr.Markdown(f"""# FLUX.1 Kontext [dev]
147
+ Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro], [[blog]](https://bfl.ai/announcements/flux-1-kontext-dev) [[model]](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)
148
+ """)
149
+ with gr.Row():
150
+ with gr.Column():
151
+ image_url = gr.Text(
152
+ label="Orginal image url",
153
+ show_label=True,
154
+ max_lines=1,
155
+ placeholder="Enter image url for inpainting",
156
+ container=False
157
+ )
158
+ with gr.Row():
159
+ prompt = gr.Text(
160
+ label="Prompt",
161
+ show_label=False,
162
+ max_lines=1,
163
+ placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
164
+ container=False,
165
+ )
166
+ run_button = gr.Button(value='Submit', variant='primary', scale=0)
167
+
168
+ with gr.Accordion("Advanced Settings", open=False):
169
+
170
+ seed = gr.Slider(
171
+ label="Seed",
172
+ minimum=0,
173
+ maximum=MAX_SEED,
174
+ step=1,
175
+ value=0,
176
+ )
177
+
178
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
179
+
180
+ guidance_scale = gr.Slider(
181
+ label="Guidance Scale",
182
+ minimum=1,
183
+ maximum=10,
184
+ step=0.1,
185
+ value=2.5,
186
+ )
187
+
188
+ steps = gr.Slider(
189
+ label="Steps",
190
+ minimum=1,
191
+ maximum=30,
192
+ value=28,
193
+ step=1
194
+ )
195
+
196
+ with gr.Accordion("R2 Settings", open=False):
197
+ upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
198
+ with gr.Row():
199
+ account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
200
+ bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
201
+
202
+ with gr.Row():
203
+ access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
204
+ secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
205
+
206
+
207
+ with gr.Column():
208
+ generated_images = gr.Gallery(label="Result", show_label=True)
209
+ output_json_component = gr.Code(label="JSON Result", language="json")
210
+
211
+ run_button.click(
212
+ fn = process,
213
+ inputs = [
214
+ image_url,
215
+ prompt,
216
+ seed,
217
+ randomize_seed,
218
+ guidance_scale,
219
+ steps,
220
+ upload_to_r2,
221
+ account_id,
222
+ access_key,
223
+ secret_key,
224
+ bucket
225
+ ],
226
+ outputs = [
227
+ generated_images,
228
+ output_json_component
229
+ ]
230
+ )
231
+
232
+
233
+ demo.queue()
234
+ demo.launch(share=True)