"""
"""
from datetime import datetime

# Upgrade to a PyTorch nightly (CUDA 12.6 wheels) before anything imports torch
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 torch torchvision spaces')

# CUDA toolkit install
t0 = datetime.now()
from utils.cuda_toolkit import install_cuda_toolkit
install_cuda_toolkit()
# -(t0 - (t0 := datetime.now())) equals datetime.now() - t0: it prints the
# time elapsed since the previous t0 and resets t0 in one expression (walrus).
print('install_cuda_toolkit', -(t0 - (t0 := datetime.now())))

# Actual app.py

import gradio as gr
import spaces
import torch
import torch._inductor
from diffusers import FluxPipeline

from utils.zerogpu import aoti_compile  # absolute import: app.py runs as a script, not a package


pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
print('FluxPipeline.from_pretrained', -(t0 - (t0 := datetime.now())))
package_path = 'pipeline.pt2'  # '.pt2' is the AOTInductor package format; not used further in this file


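# Export and AOT-compile the Flux transformer on GPU. duration=1500 requests
# a long ZeroGPU slot (in seconds), since max-autotune compilation is slow.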
@spaces.GPU(duration=1500)
def compile_transformer():

    def _example_tensor(*shape):
        return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)

    # FLUX.1-schnell is timestep-distilled (no guidance embedder), which also
    # implies the shorter 256-token text sequence length.
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
    seq_length = 256 if is_timestep_distilled else 512

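    # Example inputs for torch.export: 4096 image tokens of width 64
    # correspond to a 1024x1024 generation (64x64 packed 2x2 latent patches).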
    transformer_kwargs = {
        'hidden_states': _example_tensor(1, 4096, 64),
        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
        'pooled_projections': _example_tensor(1, 768),
        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
        'txt_ids': _example_tensor(seq_length, 3),
        'img_ids': _example_tensor(4096, 3),
        'joint_attention_kwargs': {},
        'return_dict': False,
    }

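    # Aggressive Inductor tuning flags; max_autotune and coordinate-descent
    # search dominate compile time but improve steady-state kernel speed.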
    inductor_configs = {
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }

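    # Trace the transformer into an ExportedProgram, then compile it ahead of
    # time with the local aoti_compile helper from utils/zerogpu.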
    exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)

    return aoti_compile(exported, inductor_configs)


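# Swap in the compiled callable, restoring the original config object, which
# downstream pipeline code still reads but the compiled module does not carry.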
transformer_config = pipeline.transformer.config
pipeline.transformer = compile_transformer()
pipeline.transformer.config = transformer_config

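# Each request runs four 4-step generations; schnell is distilled for
# few-step sampling, so num_inference_steps=4 is the intended setting.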
@spaces.GPU
def _generate_image(prompt: str, t0: datetime):
    print('@spaces.GPU', -(t0 - (t0 := datetime.now())))
    images = []
    for _ in range(4):
        images += pipeline(prompt, num_inference_steps=4).images
        print('pipeline', -(t0 - (t0 := datetime.now())))
    return images


def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):  # track_tqdm forwards pipeline tqdm bars to the Gradio progress UI
    return _generate_image(prompt, datetime.now())

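# show_error=True surfaces worker exceptions in the browser instead of failing silently.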
gr.Interface(generate_image, gr.Text(), gr.Gallery()).launch(show_error=True)