# Hugging Face Spaces: Running on Zero
from typing import Tuple, Optional | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
import random | |
from PIL import Image | |
import json | |
import boto3 | |
from io import BytesIO | |
from datetime import datetime | |
from huggingface_hub import login | |
import os | |
from diffusers import FluxKontextPipeline | |
from diffusers.utils import load_image | |
from diffusers.utils import load_image, make_image_grid | |
from datetime import datetime | |
import time | |
# Module-level setup: authenticate with the Hugging Face Hub, define the seed
# range, and load the FLUX.1 Kontext pipeline onto the GPU once at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Only log in when a token is configured: login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
if HF_TOKEN:
    login(token=HF_TOKEN)
# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
class calculateDuration:
    """Context manager that prints when an activity starts and how long it took.

    After the ``with`` block the instance exposes:
        start_time / end_time: wall-clock timestamps from ``time.time()``.
        start_time_formatted / end_time_formatted: the same timestamps as
            local-time ``"%Y-%m-%d %H:%M:%S"`` strings.
        elapsed_time: seconds spent inside the block, measured with the
            monotonic ``time.perf_counter()`` so wall-clock adjustments
            cannot skew it.

    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        # Separate monotonic reading used only for the duration measurement.
        self._start_counter = time.perf_counter()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = time.perf_counter() - self._start_counter
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Implicitly returns None (falsy), so any in-block exception propagates.
@spaces.GPU  # GPU work must run inside a @spaces.GPU function on ZeroGPU hardware.
def infer(
    input_image,
    prompt,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    progress
):
    """Run the FLUX Kontext pipeline and return the first generated image.

    Args:
        input_image: image to edit, or None for prompt-only generation. It is
            converted to RGB, and its size is passed as the output width/height.
        prompt: text instruction for the pipeline.
        seed: generator seed; ignored when ``randomize_seed`` is true.
        randomize_seed: if true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: forwarded to the pipeline unchanged.
        steps: number of inference steps.
        progress: not used in the body (the caller passes its gr.Progress).

    Returns:
        ``pipe(...).images[0]``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders and API clients may deliver floats; manual_seed and the
    # scheduler's step count require integers.
    generator = torch.Generator().manual_seed(int(seed))
    num_steps = int(steps)

    if input_image is not None:
        draft_image = input_image.convert("RGB")
        image = pipe(
            image=draft_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            width=draft_image.size[0],
            height=draft_image.size[1],
            num_inference_steps=num_steps,
            generator=generator,
        ).images[0]
    else:
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            generator=generator,
        ).images[0]
    return image
def process(image_url, prompt, seed, randomize_seed, guidance_scale, steps, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: edit the image at ``image_url`` and optionally upload it to R2.

    Always returns a JSON string containing at least ``"status"`` and ``"message"``:
      * ``"success"``: generation succeeded (plus ``"url"`` when uploaded).
      * ``"fail"``: the input image could not be loaded.
      * ``"failed"``: generation or upload raised an exception.

    The ``"url"`` value is whatever ``upload_image_to_r2`` returns, i.e. the
    object key inside the bucket.
    """
    result = {"status": "false", "message": ""}

    # load_image raises (rather than returning a non-Image) for URLs it cannot
    # fetch or decode, so catch that and report it in the same JSON shape.
    try:
        input_image = load_image(image_url)
    except Exception as e:
        print(f"load_image failed for {image_url!r}: {e!r}")
        input_image = None
    if not isinstance(input_image, Image.Image):
        result["status"] = "fail"
        result["message"] = "Invalid input image url"
        return json.dumps(result)

    try:
        generated_image = infer(input_image, prompt, seed, randomize_seed, guidance_scale, steps, progress)
    except Exception as e:
        # Log the cause instead of silently discarding it.
        print(f"generate image failed: {e!r}")
        result["status"] = "failed"
        result["message"] = "generate image failed"
        generated_image = None

    if generated_image is not None:
        if upload_to_r2:
            try:
                url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
            except Exception as e:
                # Keep the JSON contract on upload errors instead of surfacing a raw exception.
                print(f"upload image failed: {e!r}")
                result = {"status": "failed", "message": "upload image failed"}
            else:
                result = {"status": "success", "message": "upload image success", "url": url}
        else:
            result = {"status": "success", "message": "Image generated but not uploaded"}

    # gr.Progress expects a fraction in [0, 1]; 1.0 marks completion.
    progress(1.0, desc="finish!")
    return json.dumps(result)
def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
    """Upload ``image`` as PNG, plus a 256-px-wide thumbnail, to a Cloudflare R2 bucket.

    The full image is stored under
    ``generated_images/<YYYY/MM/DD/HHMMSS>_<random>.png`` and the thumbnail
    alongside it with a ``_thumbnail.png`` suffix. Both are uploaded with
    ``ContentType: image/png``.

    Returns:
        The object key of the full-size image (not a public URL).
    """
    with calculateDuration("Upload image"):
        # Never log credentials: only the account id and bucket are printed.
        print("upload_image_to_r2", account_id, bucket_name)
        connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
        s3 = boto3.client(
            's3',
            endpoint_url=connectionUrl,
            region_name='auto',
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key
        )
        # An explicit content type lets the stored object be served as an image
        # rather than as a generic binary blob.
        extra_args = {"ContentType": "image/png"}

        current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
        image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
        buffer = BytesIO()
        image.save(buffer, "PNG")
        buffer.seek(0)
        s3.upload_fileobj(buffer, bucket_name, image_file, ExtraArgs=extra_args)
        print("upload finish", image_file)

        # Thumbnail: fixed width, height scaled to keep the aspect ratio.
        thumbnail_width = 256
        aspect_ratio = image.height / image.width
        # Clamp to >= 1 so very wide images do not produce a 0-pixel height,
        # which resize() would reject.
        thumbnail_height = max(1, int(thumbnail_width * aspect_ratio))
        # resize() returns a new image, so the original is left untouched.
        thumbnail = image.resize((thumbnail_width, thumbnail_height), Image.LANCZOS)
        thumbnail_file = image_file.replace(".png", "_thumbnail.png")

        thumbnail_buffer = BytesIO()
        thumbnail.save(thumbnail_buffer, "PNG")
        thumbnail_buffer.seek(0)
        s3.upload_fileobj(thumbnail_buffer, bucket_name, thumbnail_file, ExtraArgs=extra_args)
        print("upload thumbnail finish", thumbnail_file)
    return image_file
def dummy(image_url, prompt, seed, randomize_seed, guidance_scale, steps, upload_to_r2, account_id, access_key, secret_key, bucket):
    """Placeholder handler with the same parameters as ``process``.

    Ignores every input and returns, without raising, a one-element list
    holding a 256x256 all-black RGB image together with a fixed JSON string.
    (Not referenced anywhere in the visible code.)
    """
    placeholder_payload = '{"status":"dummy"}'
    # Image.new fills with color 0 by default, i.e. black for RGB.
    placeholder_image = Image.new("RGB", (256, 256))
    return [placeholder_image], placeholder_payload
# UI: an image URL and an edit prompt on the left (with advanced sampling and
# R2 upload settings), and the JSON result from `process` on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# FLUX.1 Kontext [dev]")
        with gr.Row():
            with gr.Column():
                image_url = gr.Textbox(
                    label="Original image url",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter image url for inpainting",
                    container=False
                )
                with gr.Row():
                    prompt = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button("Run")
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )
                with gr.Accordion("R2 Settings", open=False):
                    upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                    with gr.Row():
                        account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id", value="")
                        bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here", value="")
                    with gr.Row():
                        access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here", value="")
                        # Mask the secret so it is not shown in plain text in the browser.
                        secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", value="", type="password")
            with gr.Column():
                output_json_component = gr.Code(label="JSON Result", language="json", value="{}")

    # Input order must match the positional parameters of `process`.
    run_button.click(
        fn=process,
        inputs=[
            image_url,
            prompt,
            seed,
            randomize_seed,
            guidance_scale,
            steps,
            upload_to_r2,
            account_id,
            access_key,
            secret_key,
            bucket
        ],
        outputs=[
            output_json_component
        ],
        api_name="predict"
    )

demo.queue(api_open=True)
demo.launch(share=True)