'''
Copyright 2022 The International Digital Economy Academy (IDEA). CCNL team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@File    : train.py
@Time    : 2022/11/09 22:27
@Author  : Gan Ruyi
@Version : 1.0
@Contact : [email protected]
@License : (C)Copyright 2022-2023, CCNL-IDEA
'''
|
import argparse
import hashlib
import itertools
import os
from pathlib import Path

import torch
from torch.nn import functional as F
from tqdm.auto import tqdm
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from transformers import BertTokenizer, BertModel, CLIPTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel

from fengshen.data.dreambooth_datasets.dreambooth_datasets import (
    PromptDataset,
    DreamBoothDataset,
    add_data_args,
)
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.models.model_utils import (
    add_module_args,
    configure_optimizers,
    get_total_steps,
)
from fengshen.utils.universal_checkpoint import UniversalCheckpoint

|
class StableDiffusionDreamBooth(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('Taiyi Stable Diffusion Module')
        parser.add_argument('--train_text_encoder', action='store_true', default=False)
        # NOTE: with action='store_true' and default=True, this flag is always
        # True, so the UNet is trained unconditionally.
        parser.add_argument('--train_unet', action='store_true', default=True)
        return parent_parser

    def __init__(self, args):
        super().__init__()
        # The Chinese Taiyi checkpoint ships a BERT text encoder; other Stable
        # Diffusion checkpoints use the standard CLIP text encoder.
        if 'Taiyi-Stable-Diffusion-1B-Chinese-v0.1' in args.model_path:
            self.tokenizer = BertTokenizer.from_pretrained(
                args.model_path, subfolder="tokenizer")
            self.text_encoder = BertModel.from_pretrained(
                args.model_path, subfolder="text_encoder")
        else:
            self.tokenizer = CLIPTokenizer.from_pretrained(
                args.model_path, subfolder="tokenizer")
            self.text_encoder = CLIPTextModel.from_pretrained(
                args.model_path, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(
            args.model_path, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            args.model_path, subfolder="unet")
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            args.model_path, subfolder="scheduler")

        # The VAE is never fine-tuned; freeze only the submodules that are not
        # selected for training.
        self.vae.requires_grad_(False)
        if not args.train_text_encoder:
            self.text_encoder.requires_grad_(False)
        if not args.train_unet:
            self.unet.requires_grad_(False)

        self.save_hyperparameters(args)
|
    def generate_extra_data(self):
        # Pre-generate class (prior-preservation) images with the pretrained
        # pipeline until class_data_dir holds num_class_images samples.
        global_rank = self.global_rank
        # NOTE: indexing device_ids by global_rank assumes single-node training.
        device = self.trainer.device_ids[global_rank]
        print('generate on device {} of global_rank {}'.format(device, global_rank))
        class_images_dir = Path(self.hparams.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < self.hparams.num_class_images:
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.hparams.model_path,
                safety_checker=None,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = self.hparams.num_class_images - cur_class_images
            print(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(self.hparams.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=self.hparams.sample_batch_size)

            pipeline.to(device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=global_rank != 0
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    # Hash the image bytes so repeated runs never collide on names.
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            # Release the sampling pipeline before training starts.
            del pipeline
|
    def setup(self, stage) -> None:
        if self.hparams.with_prior_preservation:
            self.generate_extra_data()
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        # Optimize only the submodules selected for fine-tuning.
        model_params = []
        if self.hparams.train_unet and self.hparams.train_text_encoder:
            model_params = itertools.chain(self.unet.parameters(), self.text_encoder.parameters())
        elif self.hparams.train_unet:
            model_params = self.unet.parameters()
        elif self.hparams.train_text_encoder:
            model_params = self.text_encoder.parameters()
        return configure_optimizers(self, model_params=model_params)
|
    def training_step(self, batch, batch_idx):
        if self.hparams.train_text_encoder:
            self.text_encoder.train()
        if self.hparams.train_unet:
            self.unet.train()

        # Encode the images into VAE latents and apply the Stable Diffusion
        # latent scaling factor.
        latents = self.vae.encode(batch["pixel_values"]).latent_dist.sample()
        latents = latents * 0.18215

        # Sample noise and a random diffusion timestep for each image.
        noise = torch.randn_like(latents).to(dtype=self.unet.dtype)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Forward diffusion: noise the latents according to each timestep.
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        noisy_latents = noisy_latents.to(dtype=self.unet.dtype)

        encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]

        # Predict the noise residual.
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if self.hparams.with_prior_preservation:
            # collate_fn concatenates instance and class examples along the
            # batch dimension, so split the predictions back into both halves.
            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
            noise, noise_prior = torch.chunk(noise, 2, dim=0)

            # Instance loss.
            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

            # Prior-preservation loss on the generated class images.
            prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="mean")

            loss = loss + self.hparams.prior_loss_weight * prior_loss
        else:
            loss = F.mse_loss(noise_pred, noise, reduction="mean")
        self.log("train_loss", loss.item(), on_epoch=False, prog_bar=True, logger=True)

        if self.trainer.global_rank == 0:
            if (self.global_step + 1) % 5000 == 0:
                print('saving model...')
                pipeline = StableDiffusionPipeline.from_pretrained(
                    self.hparams.model_path, unet=self.unet,
                    text_encoder=self.text_encoder, tokenizer=self.tokenizer,
                )
                # Include the step in the name so saves within one epoch do not
                # overwrite each other.
                pipeline.save_pretrained(os.path.join(
                    self.hparams.default_root_dir,
                    f'hf_out_{self.trainer.current_epoch}_{self.global_step}'))

        return {"loss": loss}
|
|
|
    def on_train_end(self) -> None:
        # Export the fine-tuned weights as a standard diffusers pipeline.
        if self.trainer.global_rank == 0:
            print('saving model...')
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.hparams.model_path, unet=self.unet,
                text_encoder=self.text_encoder, tokenizer=self.tokenizer,
            )
            pipeline.save_pretrained(os.path.join(
                self.hparams.default_root_dir, f'hf_out_{self.trainer.current_epoch}'))

    def on_load_checkpoint(self, checkpoint) -> None:
        # Restore step/sample counters so a resumed run continues its schedule.
        # _batches_that_stepped is a pytorch_lightning internal attribute.
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
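
# Example (illustrative): loading a pipeline exported by on_train_end for
# inference; the directory name and prompt below are placeholders.
#
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained('./hf_out_0').to('cuda')
#   image = pipe('a photo of sks dog').images[0]
#   image.save('sample.png')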
|
|
|
|
|
if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser = add_module_args(args_parser)
    args_parser = add_data_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = StableDiffusionDreamBooth.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args = args_parser.parse_args()

    model = StableDiffusionDreamBooth(args)

    tokenizer = model.tokenizer
    datasets = DreamBoothDataset(
        instance_data_dir=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        tokenizer=tokenizer,
        class_data_dir=args.class_data_dir,
        class_prompt=args.class_prompt,
        size=512,
        center_crop=args.center_crop,
    )

    datasets = {'train': datasets}
|
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # With prior preservation, append the class examples so each batch is
        # [instance examples; class examples] along the batch dimension.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch
|
    datamodule = UniversalDataModule(
        tokenizer=tokenizer, collate_fn=collate_fn, args=args, datasets=datasets)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = UniversalCheckpoint(args)

    trainer = Trainer.from_argparse_args(
        args, callbacks=[lr_monitor, checkpoint_callback])

    trainer.fit(model, datamodule, ckpt_path=args.load_ckpt_path)
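
    # Example launch (illustrative): the data/model flags are defined by
    # add_data_args/add_module_args in fengshen, and the values below are
    # placeholders, not tested defaults.
    #
    #   python train.py \
    #       --model_path IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 \
    #       --instance_data_dir ./instance_images \
    #       --instance_prompt '<placeholder instance prompt>' \
    #       --class_data_dir ./class_images \
    #       --class_prompt '<placeholder class prompt>' \
    #       --with_prior_preservation \
    #       --prior_loss_weight 1.0 \
    #       --num_class_images 200 \
    #       --sample_batch_size 2 \
    #       --max_epochs 1 --gpus 1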
|
|