|
import os |
|
import torch |
|
import argparse |
|
from pytorch_lightning import ( |
|
LightningModule, |
|
Trainer, |
|
) |
|
from pytorch_lightning.callbacks import ( |
|
LearningRateMonitor, |
|
) |
|
from fengshen.data.universal_datamodule import UniversalDataModule |
|
from fengshen.models.model_utils import ( |
|
add_module_args, |
|
configure_optimizers, |
|
get_total_steps, |
|
) |
|
from fengshen.utils.universal_checkpoint import UniversalCheckpoint |
|
from transformers import BertTokenizer, BertModel |
|
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
|
from torch.nn import functional as F |
|
from fengshen.data.taiyi_stable_diffusion_datasets.taiyi_datasets import add_data_args, load_data |
|
from torchvision import transforms |
|
from PIL import Image |
|
from torch.utils.data._utils.collate import default_collate |
|
|
|
|
|
class Collator():
    """Batch collator for Taiyi Stable Diffusion training.

    For each sample, loads the image from ``img_path``, applies the standard
    SD preprocessing transforms, tokenizes the caption, and returns a
    default-collated batch with ``pixel_values`` and ``input_ids``.
    """

    def __init__(self, args, tokenizer):
        """
        Args:
            args: parsed CLI namespace; uses ``resolution`` and ``center_crop``.
            tokenizer: a HuggingFace tokenizer (BertTokenizer in this script).
        """
        # Standard SD preprocessing: resize -> (center|random) crop ->
        # tensor in [0, 1] -> normalize to [-1, 1].
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(
                    args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.tokenizer = tokenizer

    def __call__(self, inputs):
        """Collate a list of ``{'img_path': str, 'caption': str}`` dicts.

        Returns:
            dict with batched ``pixel_values`` (float tensor) and
            ``input_ids`` (long tensor), via ``default_collate``.
        """
        examples = []
        # Pad captions to the longest caption in this batch (measured in raw
        # characters, used here as a proxy for token count), capped at 512.
        max_length = min(max(len(i['caption']) for i in inputs), 512)
        for i in inputs:
            example = {}
            # Context manager ensures the file handle is closed promptly;
            # a bare Image.open leaks descriptors under heavy dataloading.
            with Image.open(i['img_path']) as instance_image:
                if not instance_image.mode == "RGB":
                    instance_image = instance_image.convert("RGB")
                example["pixel_values"] = self.image_transforms(instance_image)
            example["input_ids"] = self.tokenizer(
                i['caption'],
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            )['input_ids'][0]
            examples.append(example)
        return default_collate(examples)
|
|
|
|
|
class StableDiffusion(LightningModule):
    """LightningModule that fine-tunes a Stable Diffusion UNet/text-encoder.

    Loads tokenizer, text encoder (BERT), VAE and UNet from
    ``args.model_path`` subfolders. The VAE is always frozen; the UNet and
    text encoder can each be frozen via CLI flags.
    """

    @staticmethod
    def add_module_specific_args(parent_parser):
        """Register this module's CLI flags on *parent_parser* and return it."""
        parser = parent_parser.add_argument_group('Taiyi Stable Diffusion Module')
        parser.add_argument('--freeze_unet', action='store_true', default=False)
        parser.add_argument('--freeze_text_encoder', action='store_true', default=False)
        return parent_parser

    def __init__(self, args):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(
            args.model_path, subfolder="tokenizer")
        self.text_encoder = BertModel.from_pretrained(
            args.model_path, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(
            args.model_path, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            args.model_path, subfolder="unet")

        # Explicitly disable xformers memory-efficient attention.
        self.unet.set_use_memory_efficient_attention_xformers(False)

        # Standard SD v1 noise schedule.
        self.noise_scheduler = DDPMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
        )

        # The VAE is never trained in this script.
        for param in self.vae.parameters():
            param.requires_grad = False

        if args.freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        if args.freeze_unet:
            for param in self.unet.parameters():
                param.requires_grad = False

        # Persist args so checkpoints/hooks can read them via self.hparams.
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}' .format(self.total_steps))

    def configure_optimizers(self):
        # Delegates to the shared fengshen optimizer/scheduler factory.
        return configure_optimizers(self)

    def training_step(self, batch, batch_idx):
        self.text_encoder.train()

        # Encode images to latent space; 0.18215 is the SD latent scaling factor.
        latents = self.vae.encode(batch["pixel_values"]).latent_dist.sample()
        latents = latents * 0.18215

        # Sample noise directly on the latents' device (avoids a CPU->GPU copy),
        # cast to the UNet's dtype (e.g. fp16 under mixed precision).
        noise = torch.randn(latents.shape, device=latents.device)
        noise = noise.to(dtype=self.unet.dtype)
        bsz = latents.shape[0]

        # One random diffusion timestep per sample.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Forward diffusion: add noise according to each sampled timestep.
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        noisy_latents = noisy_latents.to(dtype=self.unet.dtype)

        # Text conditioning from the BERT encoder (last hidden state).
        encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]

        # Predict the added noise and regress it with per-sample MSE.
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        self.log("train_loss", loss.item(), on_epoch=False, prog_bar=True, logger=True)

        # One-off memory report early in training, rank 0 only.
        if self.trainer.global_rank == 0 and self.global_step == 100:
            from fengshen.utils.utils import report_memory
            report_memory('stable diffusion')

        return {"loss": loss}

    def on_save_checkpoint(self, checkpoint) -> None:
        """Export a HuggingFace pipeline alongside the Lightning checkpoint."""
        if self.trainer.global_rank == 0:
            print('saving model...')
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.hparams.model_path,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                vae=self.vae,  # reuse the in-memory (frozen) VAE instead of reloading
                unet=self.unet)
            # Use self.hparams rather than the module-level `args` global so
            # this hook also works when the class is used outside this script.
            pipeline.save_pretrained(os.path.join(
                self.hparams.default_root_dir,
                f'hf_out_{self.trainer.current_epoch}_{self.trainer.global_step}'))

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore step/sample counters so resumed runs continue correctly."""
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        # Keep Lightning's internal stepped-batch counter in sync with the
        # restored global step (otherwise LR schedules/logging restart at 0).
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
|
|
|
|
|
if __name__ == '__main__':
    # Build the CLI by letting each component register its own options.
    args_parser = argparse.ArgumentParser()
    for register in (add_module_args,
                     add_data_args,
                     UniversalDataModule.add_data_specific_args,
                     Trainer.add_argparse_args,
                     StableDiffusion.add_module_specific_args,
                     UniversalCheckpoint.add_argparse_args):
        args_parser = register(args_parser)
    args = args_parser.parse_args()

    # Trainer with per-step LR logging and the shared universal checkpointer.
    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[
            LearningRateMonitor(logging_interval='step'),
            UniversalCheckpoint(args),
        ])

    model = StableDiffusion(args)
    datasets = load_data(args, global_rank=trainer.global_rank)

    data_module = UniversalDataModule(
        tokenizer=model.tokenizer,
        collate_fn=Collator(args, model.tokenizer),
        args=args,
        datasets=datasets)

    trainer.fit(model, data_module, ckpt_path=args.load_ckpt_path)
|
|