File size: 863 Bytes
24ec8bb
4e717d6
 
8c2e0d0
4e717d6
9544e60
674e245
69667cb
4e717d6
 
bf738a2
69667cb
4f27510
24ec8bb
4e717d6
bf738a2
69667cb
24ec8bb
bf738a2
 
 
 
 
 
9544e60
 
da628cb
 
 
bf738a2
da628cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from optimization import optimize_pipeline_


# Hub repository of the FLUX.1-dev text-to-image weights.
MODEL_ID = 'black-forest-labs/FLUX.1-dev'

# Load once at import time in bfloat16, then move every component onto the GPU.
# DiffusionPipeline.to() returns the pipeline itself, so rebinding is equivalent
# to chaining the call.
pipeline = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
pipeline = pipeline.to('cuda')

# Project-local optimization pass. The trailing underscore suggests it mutates the
# pipeline in place; the meaning of the "prompt" argument is defined in
# `optimization` — NOTE(review): confirm there.
optimize_pipeline_(pipeline, "prompt")


@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Render one image for *prompt* with the module-level FLUX pipeline.

    Args:
        prompt: Text prompt forwarded unchanged to the pipeline.
        progress: Gradio progress tracker. With ``track_tqdm=True``, Gradio
            follows tqdm progress bars created during the call. It is not
            referenced directly in the body.

    Returns:
        A ``(image, elapsed)`` tuple. ``image`` is the first entry of the
        pipeline output's ``images``. ``elapsed`` is the generation time
        formatted like ``'12.34s'``.
    """
    # Fresh generator with a fixed seed on every call, so a given prompt
    # reproduces the same image.
    generator = torch.Generator(device='cuda').manual_seed(42)

    # perf_counter is monotonic and high-resolution. Unlike datetime.now(), it
    # cannot be skewed by system clock adjustments in the middle of a measurement.
    start = time.perf_counter()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = time.perf_counter() - start

    return output.images[0], f'{elapsed:.2f}s'


# Web UI around generate_image. generate_image returns a 2-tuple
# (image, duration string), so the Interface needs one output component per
# returned value. A single gr.Image() would receive the whole tuple and fail.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=[gr.Image(), gr.Text(label="Duration")],
    examples=["A cat playing with a ball of yarn"],
    # Do not pre-run the example at startup; each run needs GPU time.
    cache_examples=False,
)
demo.launch()