# NOTE(review): the three lines below are web-scrape residue (author avatar
# alt-text, commit message, commit hash) from the file-viewer page, not code.
# Commented out so the module is importable; original text preserved:
# ccchenzc's picture
# Code Improvement.
# 0db8171
import torch
from PIL import Image
from diffusers import DDIMScheduler
from accelerate.utils import set_seed
from torchvision.transforms.functional import to_pil_image, to_tensor, resize
from pipeline_sd import ADPipeline
from pipeline_sdxl import ADPipeline as ADXLPipeline
from utils import Controller
import os
import spaces
class Runner:
    """Gradio entry points for AttentionDistillation-style pipelines.

    Wraps an SD1.5 pipeline (style transfer, texture synthesis, T2I) and an
    SDXL pipeline (T2I only). Pipelines are loaded lazily on first use and
    cached for the lifetime of this object.
    """

    def __init__(self):
        # Lazily-loaded pipeline caches, one per architecture.
        self.sd15 = None
        self.sdxl = None
        # L1 objective used by the attention-distillation optimization.
        self.loss_fn = torch.nn.L1Loss(reduction="mean")

    def load_pipeline(self, model_path_or_name):
        """Load and cache the pipeline matching *model_path_or_name*.

        SDXL checkpoints are detected by the substring ``'xl'`` in the name.
        Already-cached pipelines are reused; nothing is reloaded.
        """
        is_xl = 'xl' in model_path_or_name
        if is_xl and self.sdxl is None:
            scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
            self.sdxl = ADXLPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
            # AD uses the denoising UNet itself as the feature "classifier".
            self.sdxl.classifier = self.sdxl.unet
        elif not is_xl and self.sd15 is None:
            # BUGFIX: the original bare `elif self.sd15 is None` could load an
            # SDXL checkpoint path into the SD1.5 pipeline whenever the SDXL
            # pipeline was already cached. The `not is_xl` guard prevents that.
            scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
            self.sd15 = ADPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
            self.sd15.classifier = self.sd15.unet

    def preprocess(self, image: Image.Image, height=None, width=None):
        """Resize *image* and return it as a (1, C, H, W) tensor in [0, 1].

        The image's shortest side is first scaled to 512, then the target
        size (either the explicit *width*/*height* or the rescaled image's
        own size) is snapped down to a multiple of 64, as the diffusion
        UNet requires spatial dims divisible by 64.
        """
        image = resize(image, size=512)
        if width is None or height is None:
            width, height = image.size  # PIL reports (width, height)
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        image = image.resize((new_width, new_height), Image.BICUBIC)
        return to_tensor(image).unsqueeze(0)

    # Backward-compatible alias for the original (misspelled) method name.
    preprocecss = preprocess

    def _require_sd15(self, model):
        """Raise a clear error when an operation needs the SD1.5 pipeline.

        BUGFIX: previously an SDXL model name led to an opaque
        ``AttributeError: 'NoneType' object has no attribute ...``.
        """
        if self.sd15 is None:
            raise ValueError(
                f"This operation requires an SD1.5 checkpoint; got {model!r}"
            )

    @spaces.GPU
    def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
        """Transfer the style of *style_image* onto *content_image*.

        Returns a single-element list containing the result as a PIL image
        (Gradio gallery convention). Implemented on the SD1.5 pipeline only.
        """
        self.load_pipeline(model)
        self._require_sd15(model)
        content_image = self.preprocess(content_image)
        style_image = self.preprocess(style_image, height=512, width=512)
        # Output resolution follows the (preprocessed) content image.
        height, width = content_image.shape[-2:]
        set_seed(seed)
        # Self-attention layers 10-16 are distilled for SD1.5.
        controller = Controller(self_layers=(10, 16))
        result = self.sd15.optimize(
            lr=lr,
            batch_size=1,
            iters=1,
            width=width,
            height=height,
            weight=content_weight,
            controller=controller,
            style_image=style_image,
            content_image=content_image,
            mixed_precision=mixed_precision,
            num_inference_steps=num_steps,
            enable_gradient_checkpoint=False,
        )
        output_image = to_pil_image(result[0].float())
        # Drop the tensor before returning to keep GPU memory headroom.
        del result
        torch.cuda.empty_cache()
        return [output_image]

    @spaces.GPU
    def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
        """Generate images from *prompt* in the style of *style_image*.

        Dispatches to SDXL or SD1.5 depending on *model*. Note that the
        *height*/*width* arguments are intentionally overridden by the
        architecture's native resolution (1024 for XL, 512 for SD1.5).
        Returns a list of PIL images.
        """
        self.load_pipeline(model)
        use_xl = 'xl' in model
        height, width = (1024, 1024) if use_xl else (512, 512)
        style_image = self.preprocess(style_image, height=height, width=width)
        set_seed(seed)
        # Distilled self-attention layer range differs per architecture.
        self_layers = (64, 70) if use_xl else (10, 16)
        controller = Controller(self_layers=self_layers)
        pipeline = self.sdxl if use_xl else self.sd15
        images = pipeline.sample(
            controller=controller,
            iters=iterations,
            lr=lr,
            adain=is_adain,
            height=height,
            width=width,
            mixed_precision=mixed_precision,
            style_image=style_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            enable_gradient_checkpoint=False,
        )
        output_images = [to_pil_image(image.float()) for image in images]
        del images
        torch.cuda.empty_cache()
        return output_images

    @spaces.GPU
    def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model):
        """Synthesize new textures resembling *texture_image*.

        *synthesis_way* selects the backend: 'Sampling' (plain sampling) or
        'MultiDiffusion' (panorama-style tiled sampling); any other value
        raises ``ValueError``. SD1.5 only. Returns a list of PIL images.
        """
        self.load_pipeline(model)
        self._require_sd15(model)
        texture_image = self.preprocess(texture_image, height=512, width=512)
        set_seed(seed)
        controller = Controller(self_layers=(10, 16))
        if synthesis_way == 'Sampling':
            results = self.sd15.sample(
                lr=lr,
                adain=False,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        elif synthesis_way == 'MultiDiffusion':
            results = self.sd15.panorama(
                lr=lr,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                stride=8,
                view_batch_size=8,
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        else:
            # BUGFIX: original raised a bare ValueError with no message.
            raise ValueError(
                f"Unknown synthesis_way {synthesis_way!r}; "
                "expected 'Sampling' or 'MultiDiffusion'"
            )
        output_images = [to_pil_image(image.float()) for image in results]
        del results
        torch.cuda.empty_cache()
        return output_images