import torch
from PIL import Image
from diffusers import DDIMScheduler
from accelerate.utils import set_seed
from torchvision.transforms.functional import to_pil_image, to_tensor, resize

from pipeline_sd import ADPipeline
from pipeline_sdxl import ADPipeline as ADXLPipeline
from utils import Controller

import spaces


class Runner:
    """Lazy-loads the SD1.5 and SDXL pipelines and exposes the demo tasks:
    style transfer, style-guided text-to-image generation, and texture
    synthesis."""

    def __init__(self):
        self.sd15 = None  # ADPipeline (SD1.5), created on first request
        self.sdxl = None  # ADXLPipeline (SDXL), created on first request
        self.loss_fn = torch.nn.L1Loss(reduction="mean")

    def load_pipeline(self, model_path_or_name):
        # Load each pipeline at most once; the UNet doubles as the classifier.
        if 'xl' in model_path_or_name and self.sdxl is None:
            scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
            self.sdxl = ADXLPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
            self.sdxl.classifier = self.sdxl.unet
        elif 'xl' not in model_path_or_name and self.sd15 is None:
            scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
            self.sd15 = ADPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
            self.sd15.classifier = self.sd15.unet

    def preprocess(self, image: Image.Image, height=None, width=None):
        # Resize the shorter side to 512, then snap the target size down to a
        # multiple of 64 so it matches the UNet's downsampling factor.
        image = resize(image, size=512)
        if width is None or height is None:
            width, height = image.size
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        image = image.resize((new_width, new_height), Image.BICUBIC)
        return to_tensor(image).unsqueeze(0)

    @spaces.GPU
    def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
        self.load_pipeline(model)

        content_image = self.preprocess(content_image)
        style_image = self.preprocess(style_image, height=512, width=512)

        height, width = content_image.shape[-2:]
        set_seed(seed)
        controller = Controller(self_layers=(10, 16))
        # Style transfer runs on the SD1.5 pipeline.
        result = self.sd15.optimize(
            lr=lr,
            batch_size=1,
            iters=1,
            width=width,
            height=height,
            weight=content_weight,
            controller=controller,
            style_image=style_image,
            content_image=content_image,
            mixed_precision=mixed_precision,
            num_inference_steps=num_steps,
            enable_gradient_checkpoint=False,
        )
        output_image = to_pil_image(result[0].float())
        del result
        torch.cuda.empty_cache()
        return [output_image]

    @spaces.GPU
    def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
        self.load_pipeline(model)

        use_xl = 'xl' in model
        # SDXL samples at 1024x1024, SD1.5 at 512x512; the incoming
        # height/width arguments are overridden accordingly.
        height, width = (1024, 1024) if use_xl else (512, 512)
        style_image = self.preprocess(style_image, height=height, width=width)

        set_seed(seed)
        # The controller hooks a different self-attention layer range per architecture.
        self_layers = (64, 70) if use_xl else (10, 16)
        controller = Controller(self_layers=self_layers)

        pipeline = self.sdxl if use_xl else self.sd15
        images = pipeline.sample(
            controller=controller,
            iters=iterations,
            lr=lr,
            adain=is_adain,
            height=height,
            width=width,
            mixed_precision=mixed_precision,
            style_image=style_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            enable_gradient_checkpoint=False
        )
        output_images = [to_pil_image(image.float()) for image in images]

        del images
        torch.cuda.empty_cache()
        return output_images

    @spaces.GPU
    def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model):
        self.load_pipeline(model)

        texture_image = self.preprocess(texture_image, height=512, width=512)

        set_seed(seed)
        controller = Controller(self_layers=(10, 16))

        # 'Sampling' generates the texture in a single pass; 'MultiDiffusion'
        # stitches overlapping views (see stride/view_batch_size below).
        if synthesis_way == 'Sampling':
            results = self.sd15.sample(
                lr=lr,
                adain=False,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        elif synthesis_way == 'MultiDiffusion':   
            results = self.sd15.panorama(
                lr=lr,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                stride=8,
                view_batch_size=8,
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        else:
            raise ValueError(f"Unknown synthesis_way: {synthesis_way!r}")
        
        output_images = [to_pil_image(image.float()) for image in results]
        del results
        torch.cuda.empty_cache()
        return output_images
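

# Minimal usage sketch (not part of the original Space): driving the Runner
# directly, outside Gradio. The checkpoint id, image paths, and hyperparameter
# values below are illustrative assumptions, not values from this repository.
if __name__ == "__main__":
    runner = Runner()
    content = Image.open("content.jpg").convert("RGB")  # hypothetical path
    style = Image.open("style.jpg").convert("RGB")      # hypothetical path
    outputs = runner.run_style_transfer(
        content_image=content,
        style_image=style,
        seed=0,
        num_steps=200,           # assumed step count
        lr=0.05,                 # assumed learning rate
        content_weight=0.25,     # assumed content weight
        mixed_precision="no",    # assumed flag value accepted by the pipeline
        model="runwayml/stable-diffusion-v1-5",  # assumed SD1.5 checkpoint id
    )
    outputs[0].save("stylized.png")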