Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| from gradio_client import Client | |
| from diffusers import AutoencoderKL, StableDiffusionXLPipeline | |
| import torch | |
| import concurrent.futures | |
| import spaces | |
# Remote Gradio Spaces, queried through their public APIs by the
# get_lighting_result / get_hyper_result helpers.
client_lightning = Client("AP123/SDXL-Lightning")
client_hyper = Client("ByteDance/Hyper-SDXL-1Step-T2I")

# VAE loaded in float16 and handed to the locally-run Turbo pipeline below.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

### SDXL Turbo ####
# DiffusionPipeline.to() returns the pipeline itself, so the device move can
# be chained directly onto construction.
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
def get_lighting_result(prompt):
    """Generate an image for *prompt* via the remote SDXL-Lightning Space.

    Calls the ``/generate_image`` endpoint of ``client_lightning`` with the
    prompt and the ``"1-Step"`` inference-steps option, and returns the
    Gradio client's result unchanged.
    """
    # Positional order matters: (prompt, number-of-inference-steps choice).
    return client_lightning.predict(prompt, "1-Step", api_name="/generate_image")
def get_hyper_result(prompt):
    """Generate an image for *prompt* via the remote Hyper-SDXL 1-step Space.

    Requests a single 1024x1024 image from the ``/process_image`` endpoint of
    ``client_hyper`` using the fixed seed 3413, and returns the Gradio
    client's result unchanged.
    """
    request = dict(
        num_images=1,
        height=1024,
        width=1024,
        prompt=prompt,
        seed=3413,  # fixed seed, so the same prompt gives a repeatable image
        api_name="/process_image",
    )
    return client_hyper.predict(**request)
@spaces.GPU
def get_turbo_result(prompt):
    """Generate an image for *prompt* locally with the SDXL Turbo pipeline.

    Runs ``pipe_turbo`` for a single denoising step with
    ``guidance_scale=0`` and returns the first image of the output batch.

    Decorated with ``spaces.GPU`` because this Space runs on ZeroGPU
    hardware, where CUDA work must happen inside a ``@spaces.GPU`` function.
    NOTE(review): this function is invoked from a ThreadPoolExecutor worker
    in run_comparison; confirm ZeroGPU quota attribution works from there.
    """
    output = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0)
    return output.images[0]
def run_comparison(prompt):
    """Run the SDXL Turbo, Lightning and Hyper backends on *prompt* concurrently.

    The two remote Space calls and the local Turbo pipeline are submitted to
    a thread pool so their latencies overlap.

    Returns:
        ``(turbo_image, lightning_result, hyper_result)``, in the same order
        as the ``outputs`` list wired to the Run button.

    Raises:
        The exception of the first failing task, checked in the order
        Lightning, Hyper, Turbo (re-raised by ``Future.result()``).
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        future_lighting = executor.submit(get_lighting_result, prompt)
        future_hyper = executor.submit(get_hyper_result, prompt)
        future_turbo = executor.submit(get_turbo_result, prompt)
        # Future.result() blocks until its task finishes, so no separate
        # concurrent.futures.wait() call is needed.
        result_lighting = future_lighting.result()
        result_hyper = future_hyper.result()
        image_turbo = future_turbo.result()
    # Debug output of the raw remote-client results.
    # NOTE(review): confirm the Hyper endpoint returns something gr.Image can
    # display (e.g. a file path) and not a gallery-style list.
    print(result_lighting)
    print(result_hyper)
    return image_turbo, result_lighting, result_hyper
# Cap the app width so the three result images sit in a compact row.
css = '''
.gradio-container{max-width: 768px !important}
'''

with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # `label` must be passed by keyword: gr.Image's first positional
        # parameter is `value`, so gr.Image("Hyper SDXL") would treat the
        # caption as an image source instead of a label.
        image_hyper = gr.Image(label="Hyper SDXL")
    # Output order must match run_comparison's return tuple
    # (turbo, lightning, hyper).
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )

demo.launch()