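# Fine-tuning script for Taiyi Stable Diffusion, built on PyTorch Lightning
# and the Fengshenbang toolkit. A BERT text encoder conditions a UNet
# denoiser; images are mapped to latents by a frozen VAE and noised with a
# DDPM schedule, and the model is trained to predict that noise (the standard
# latent-diffusion training objective).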
import os
import torch
import argparse
from pytorch_lightning import (
LightningModule,
Trainer,
)
from pytorch_lightning.callbacks import (
LearningRateMonitor,
)
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.models.model_utils import (
add_module_args,
configure_optimizers,
get_total_steps,
)
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from transformers import BertTokenizer, BertModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from torch.nn import functional as F
from fengshen.data.taiyi_stable_diffusion_datasets.taiyi_datasets import add_data_args, load_data
from torchvision import transforms
from PIL import Image
from torch.utils.data._utils.collate import default_collate
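
# Collator assembles model-ready batches from raw dataset records (each with
# an 'img_path' and a 'caption'): images are resized/cropped to
# `args.resolution` and normalized to [-1, 1]; captions are tokenized and
# padded to the longest caption in the batch.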
class Collator:
    def __init__(self, args, tokenizer):
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(
                    args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.tokenizer = tokenizer

    def __call__(self, inputs):
        examples = []
        # Pad to the longest caption in this batch, capped at the model's
        # 512-token limit (caption length in characters is used as a proxy).
        max_length = min(max([len(i['caption']) for i in inputs]), 512)
        for i in inputs:
            example = {}
            instance_image = Image.open(i['img_path'])
            if instance_image.mode != "RGB":
                instance_image = instance_image.convert("RGB")
            example["pixel_values"] = self.image_transforms(instance_image)
            example["input_ids"] = self.tokenizer(
                i['caption'],
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            )['input_ids'][0]
            examples.append(example)
        return default_collate(examples)
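
# LightningModule wrapping the four Stable Diffusion components (tokenizer,
# text encoder, VAE, UNet). The VAE is always frozen; the text encoder and
# UNet can each be frozen via --freeze_text_encoder / --freeze_unet.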
class StableDiffusion(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('Taiyi Stable Diffusion Module')
        parser.add_argument('--freeze_unet', action='store_true', default=False)
        parser.add_argument('--freeze_text_encoder', action='store_true', default=False)
        return parent_parser

    def __init__(self, args):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(
            args.model_path, subfolder="tokenizer")
        self.text_encoder = BertModel.from_pretrained(
            args.model_path, subfolder="text_encoder")  # load from taiyi_finetune-v0
        self.vae = AutoencoderKL.from_pretrained(
            args.model_path, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            args.model_path, subfolder="unet")
        # TODO: xformers combined with deepspeed actually slowed things down
        # (to be confirmed), so memory-efficient attention is disabled here.
        self.unet.set_use_memory_efficient_attention_xformers(False)
        self.noise_scheduler = DDPMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
        )
        # The VAE is never trained; optionally freeze the other components too.
        for param in self.vae.parameters():
            param.requires_grad = False
        if args.freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
        if args.freeze_unet:
            for param in self.unet.parameters():
                param.requires_grad = False
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        return configure_optimizers(self)

    def training_step(self, batch, batch_idx):
        self.text_encoder.train()
        # Encode images into the VAE latent space and apply the Stable
        # Diffusion v1 latent scaling factor.
        latents = self.vae.encode(batch["pixel_values"]).latent_dist.sample()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents).to(dtype=self.unet.dtype)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        noisy_latents = noisy_latents.to(dtype=self.unet.dtype)
        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]
        # Predict the noise residual and regress it against the true noise
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        self.log("train_loss", loss.item(), on_epoch=False, prog_bar=True, logger=True)
        if self.trainer.global_rank == 0 and self.global_step == 100:
            # Report GPU memory usage once, early in training.
            from fengshen.utils.utils import report_memory
            report_memory('stable diffusion')
        return {"loss": loss}
    def on_save_checkpoint(self, checkpoint) -> None:
        if self.trainer.global_rank == 0:
            print('saving model...')
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.hparams.model_path,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                unet=self.unet)
            pipeline.save_pretrained(os.path.join(
                self.hparams.default_root_dir,
                f'hf_out_{self.trainer.current_epoch}_{self.trainer.global_step}'))

    def on_load_checkpoint(self, checkpoint) -> None:
        # Compatibility with older Lightning versions, which reset the step
        # count to 0 when resuming from a checkpoint.
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
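
# Command-line entry point: each component registers its own argument group on
# a shared parser, then a Trainer built from those args drives training.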
if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser = add_module_args(args_parser)
    args_parser = add_data_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = StableDiffusion.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args = args_parser.parse_args()

    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = UniversalCheckpoint(args)
    trainer = Trainer.from_argparse_args(args,
                                         callbacks=[
                                             lr_monitor,
                                             checkpoint_callback])

    model = StableDiffusion(args)
    tokenizer = model.tokenizer
    datasets = load_data(args, global_rank=trainer.global_rank)
    collate_fn = Collator(args, tokenizer)
    datamodule = UniversalDataModule(
        tokenizer=tokenizer, collate_fn=collate_fn, args=args, datasets=datasets)
    trainer.fit(model, datamodule, ckpt_path=args.load_ckpt_path)