import os

import torch
from PIL import Image
from diffusers import DDIMScheduler
from accelerate.utils import set_seed
from torchvision.transforms.functional import to_pil_image, to_tensor, resize

from pipeline_sd import ADPipeline
from pipeline_sdxl import ADPipeline as ADXLPipeline
from utils import Controller

import spaces


class Runner:
    """Lazily loads attention-distillation pipelines (SD 1.5 and SDXL) and
    exposes three GPU entry points: style transfer, style-guided
    text-to-image generation, and texture synthesis.
    """

    def __init__(self):
        # Pipelines are loaded on first use (see load_pipeline).
        self.sd15 = None
        self.sdxl = None
        self.loss_fn = torch.nn.L1Loss(reduction="mean")

    def load_pipeline(self, model_path_or_name):
        """Load and cache the pipeline matching ``model_path_or_name``.

        A name containing ``'xl'`` selects the SDXL pipeline; anything else
        selects SD 1.5. Each pipeline is loaded at most once.

        Bug fix: the original condition ``if 'xl' in name and self.sdxl is
        None: ... elif self.sd15 is None: ...`` meant that requesting an XL
        model while the XL pipeline was already cached fell through to the
        ``elif`` and loaded the SD 1.5 pipeline from the XL checkpoint path.
        The branches are now properly exclusive.
        """
        if 'xl' in model_path_or_name:
            if self.sdxl is None:
                scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
                self.sdxl = ADXLPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
                # The classifier used for attention distillation is the UNet itself.
                self.sdxl.classifier = self.sdxl.unet
        elif self.sd15 is None:
            scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
            self.sd15 = ADPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
            self.sd15.classifier = self.sd15.unet

    def preprocecss(self, image: Image.Image, height=None, width=None):
        """Resize ``image`` and return it as a (1, C, H, W) float tensor in [0, 1].

        The image is first resized so its shortest side is 512, then resized
        to the target (or its own) dimensions rounded DOWN to multiples of 64
        — the UNet requires spatial dims divisible by 64.

        NOTE(review): the method name keeps the original (misspelled) public
        spelling so external callers are unaffected.
        """
        image = resize(image, size=512)
        if width is None or height is None:
            width, height = image.size
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        size = (new_width, new_height)
        image = image.resize(size, Image.BICUBIC)
        return to_tensor(image).unsqueeze(0)

    @spaces.GPU
    def run_style_transfer(self, content_image, style_image, seed, num_steps, lr,
                           content_weight, mixed_precision, model, **kwargs):
        """Transfer the style of ``style_image`` onto ``content_image``.

        Returns a single-element list with the resulting PIL image.

        NOTE(review): this path always uses ``self.sd15``; passing an
        'xl' model name would leave ``self.sd15`` as None and raise
        AttributeError — confirm intended (matches original behavior).
        """
        self.load_pipeline(model)
        content_image = self.preprocecss(content_image)
        # The style reference is fixed at 512x512 regardless of content size.
        style_image = self.preprocecss(style_image, height=512, width=512)
        height, width = content_image.shape[-2:]

        set_seed(seed)
        controller = Controller(self_layers=(10, 16))
        result = self.sd15.optimize(
            lr=lr,
            batch_size=1,
            iters=1,
            width=width,
            height=height,
            weight=content_weight,
            controller=controller,
            style_image=style_image,
            content_image=content_image,
            mixed_precision=mixed_precision,
            num_inference_steps=num_steps,
            enable_gradient_checkpoint=False,
        )
        output_image = to_pil_image(result[0].float())
        # Free GPU memory held by the result tensor before returning.
        del result
        torch.cuda.empty_cache()
        return [output_image]

    @spaces.GPU
    def run_style_t2i_generation(self, style_image, prompt, negative_prompt,
                                 guidance_scale, height, width, seed, num_steps,
                                 iterations, lr, num_images_per_prompt,
                                 mixed_precision, is_adain, model):
        """Generate images from ``prompt`` in the style of ``style_image``.

        Uses SDXL at 1024x1024 when the model name contains 'xl',
        otherwise SD 1.5 at 512x512 (the ``height``/``width`` arguments are
        overridden accordingly — matches original behavior).
        Returns a list of PIL images.
        """
        self.load_pipeline(model)
        use_xl = 'xl' in model
        # Resolution and attention layers are fixed per model family.
        height, width = (1024, 1024) if use_xl else (512, 512)
        style_image = self.preprocecss(style_image, height=height, width=width)

        set_seed(seed)
        self_layers = (64, 70) if use_xl else (10, 16)
        controller = Controller(self_layers=self_layers)
        pipeline = self.sdxl if use_xl else self.sd15
        images = pipeline.sample(
            controller=controller,
            iters=iterations,
            lr=lr,
            adain=is_adain,
            height=height,
            width=width,
            mixed_precision=mixed_precision,
            style_image=style_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            enable_gradient_checkpoint=False,
        )
        output_images = [to_pil_image(image.float()) for image in images]
        del images
        torch.cuda.empty_cache()
        return output_images

    @spaces.GPU
    def run_texture_synthesis(self, texture_image, height, width, seed, num_steps,
                              iterations, lr, mixed_precision,
                              num_images_per_prompt, synthesis_way, model):
        """Synthesize textures from ``texture_image``.

        ``synthesis_way`` selects the strategy: 'Sampling' or
        'MultiDiffusion' (panorama). Returns a list of PIL images.

        NOTE(review): always uses ``self.sd15``; an 'xl' model name would
        raise AttributeError here — confirm intended (matches original).
        """
        self.load_pipeline(model)
        texture_image = self.preprocecss(texture_image, height=512, width=512)

        set_seed(seed)
        controller = Controller(self_layers=(10, 16))

        if synthesis_way == 'Sampling':
            results = self.sd15.sample(
                lr=lr,
                adain=False,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        elif synthesis_way == 'MultiDiffusion':
            results = self.sd15.panorama(
                lr=lr,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                stride=8,
                view_batch_size=8,
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        else:
            # Same exception type as before, now with a diagnostic message.
            raise ValueError(f"Unknown synthesis_way: {synthesis_way!r}")

        output_images = [to_pil_image(image.float()) for image in results]
        del results
        torch.cuda.empty_cache()
        return output_images