'''
@File : ReFL.py
@Time : 2023/05/01 19:36:00
@Author : Jiazheng Xu
@Contact : [email protected]
@Description: ReFL Algorithm.
* Based on diffusers code base
* https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
'''
import argparse
import logging
import math
import os
import random
from pathlib import Path
import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import ImageReward as RM
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.16.0.dev0")
logger = get_logger(__name__, log_level="INFO")
DATASET_NAME_MAPPING = {
"refl": ("image", "text"),
}


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--grad_scale", type=float, default=1e-3, help="Scale divided for grad loss value."
)
parser.add_argument(
"--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--validation_prompts",
type=str,
default=None,
nargs="+",
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
)
parser.add_argument(
"--output_dir",
type=str,
default="checkpoint/refl",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=100,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=4,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
"--non_ema_revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
" remote repository specified with --pretrained_model_name_or_path."
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=100,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--validation_epochs",
type=int,
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-refl",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# default to using the same revision for the non-ema model if not specified
if args.non_ema_revision is None:
args.non_ema_revision = args.revision
return args


class Trainer(object):
def __init__(self, pretrained_model_name_or_path, train_data_dir, args):
self.pretrained_model_name_or_path = pretrained_model_name_or_path
self.train_data_dir = train_data_dir
# Sanity checks
if args.dataset_name is None and self.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
if args.non_ema_revision is not None:
deprecate(
"non_ema_revision!=None",
"0.15.0",
message=(
"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
" use `--variant=non_ema` instead."
),
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
self.accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
project_config=accelerator_project_config,
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(self.accelerator.state, main_process_only=False)
if self.accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if self.accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
self.repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load scheduler, tokenizer and models.
self.noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
self.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
self.text_encoder = CLIPTextModel.from_pretrained(
self.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae",
revision=args.revision)
self.unet = UNet2DConditionModel.from_pretrained(
self.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
self.reward_model = RM.load("ImageReward-v1.0", device=self.accelerator.device)
        # Freeze the VAE, text encoder, and reward model; only the UNet is trained.
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)
self.reward_model.requires_grad_(False)
# Create EMA for the unet.
if args.use_ema:
self.ema_unet = UNet2DConditionModel.from_pretrained(
self.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
self.ema_unet = EMAModel(self.ema_unet.parameters(), model_cls=UNet2DConditionModel,
model_config=self.ema_unet.config)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
                    logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
self.unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if args.use_ema:
self.ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet"))
                # make sure to pop the weight so that the corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
self.ema_unet.load_state_dict(load_model.state_dict())
self.ema_unet.to(self.accelerator.device)
del load_model
for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
                # load the diffusers-style UNet weights into the model
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
self.accelerator.register_save_state_pre_hook(save_model_hook)
self.accelerator.register_load_state_pre_hook(load_model_hook)
if args.gradient_checkpointing:
self.unet.enable_gradient_checkpointing()
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes
)
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
self.optimizer = optimizer_cls(
self.unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
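        # Only the UNet parameters are optimized; the VAE, text encoder, and reward model
        # are frozen above and never receive gradient updates.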
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
data_files = {}
data_files["train"] = self.train_data_dir
dataset = load_dataset(
"json",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
return_tensors="pt"
)
return inputs.input_ids
def preprocess_train(examples):
examples["input_ids"] = tokenize_captions(examples)
examples["rm_input_ids"] = self.reward_model.blip.tokenizer(examples[caption_column], padding='max_length',
truncation=True, max_length=35,
return_tensors="pt").input_ids
examples["rm_attention_mask"] = self.reward_model.blip.tokenizer(examples[caption_column],
padding='max_length', truncation=True,
max_length=35,
return_tensors="pt").attention_mask
return examples
with self.accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
self.train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
input_ids = torch.stack([example["input_ids"] for example in examples])
rm_input_ids = torch.stack([example["rm_input_ids"] for example in examples])
rm_attention_mask = torch.stack([example["rm_attention_mask"] for example in examples])
input_ids = input_ids.view(-1, input_ids.shape[-1])
rm_input_ids = rm_input_ids.view(-1, rm_input_ids.shape[-1])
rm_attention_mask = rm_attention_mask.view(-1, rm_attention_mask.shape[-1])
return {"input_ids": input_ids, "rm_input_ids": rm_input_ids, "rm_attention_mask": rm_attention_mask}
# DataLoaders creation:
self.train_dataloader = torch.utils.data.DataLoader(
self.train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
overrode_max_train_steps = True
self.lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `self.accelerator`.
self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler
)
if args.use_ema:
self.ema_unet.to(self.accelerator.device)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
self.weight_dtype = torch.float32
if self.accelerator.mixed_precision == "fp16":
self.weight_dtype = torch.float16
elif self.accelerator.mixed_precision == "bf16":
self.weight_dtype = torch.bfloat16
        # Move the text encoder, VAE, and reward model to the accelerator device and cast them to self.weight_dtype
self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
self.vae.to(self.accelerator.device, dtype=self.weight_dtype)
self.reward_model.to(self.accelerator.device, dtype=self.weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
        # The trackers initialize automatically on the main process.
if self.accelerator.is_main_process:
tracker_config = dict(vars(args))
tracker_config.pop("validation_prompts")
self.accelerator.init_trackers(args.tracker_project_name, tracker_config)

    def train(self, args):
# Train!
total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(self.train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
self.accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
self.accelerator.print(f"Resuming from checkpoint {path}")
self.accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // self.num_update_steps_per_epoch
resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps),
disable=not self.accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
self.unet.train()
train_loss = 0.0
for step, batch in enumerate(self.train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with self.accelerator.accumulate(self.unet):
encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]
latents = torch.randn((args.train_batch_size, 4, 64, 64), device=self.accelerator.device)
self.noise_scheduler.set_timesteps(40, device=self.accelerator.device)
timesteps = self.noise_scheduler.timesteps
mid_timestep = random.randint(30, 39)
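                    # ReFL core step: start from pure Gaussian noise (hard-coded shape
                    # (B, 4, 64, 64), i.e. 512x512 images for SD v1 with 4 latent channels and
                    # 8x VAE downsampling) and run the first `mid_timestep` of the 40 scheduler
                    # steps (a random cut-off in [30, 39]) without gradients; only the single
                    # step after this loop is differentiated through.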
for i, t in enumerate(timesteps[:mid_timestep]):
with torch.no_grad():
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=encoder_hidden_states,
).sample
latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
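                    # Take one more denoising step with gradients enabled so that the reward
                    # signal can backpropagate into the UNet parameters.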
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input,
timesteps[mid_timestep])
noise_pred = self.unet(
latent_model_input,
timesteps[mid_timestep],
encoder_hidden_states=encoder_hidden_states,
).sample
pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep],
latents).pred_original_sample.to(self.weight_dtype)
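                    # Undo the VAE latent scaling (divide by vae.config.scaling_factor) and
                    # decode the predicted x0 into an image, mapped from [-1, 1] to [0, 1].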
pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
                    # Preprocess the decoded image for the reward model
def _transform():
return Compose([
Resize(224, interpolation=BICUBIC),
CenterCrop(224),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
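                    # The decoded images are already float tensors in [0, 1], so only resizing,
                    # center-cropping to 224, and the standard CLIP normalization statistics are
                    # applied (no ToTensor is needed) before scoring.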
rm_preprocess = _transform()
image = rm_preprocess(image).to(self.accelerator.device)
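                    # Score the generated images with ImageReward (`score_gard` is the
                    # differentiable scoring method exposed by the reward model) and apply the
                    # ReFL objective: relu(-reward + 2) penalizes samples whose reward falls
                    # below 2, scaled down by `--grad_scale` (default 1e-3).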
rewards = self.reward_model.score_gard(batch["rm_input_ids"], batch["rm_attention_mask"], image)
loss = F.relu(-rewards + 2)
loss = loss.mean() * args.grad_scale
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
self.accelerator.backward(loss)
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(self.unet.parameters(), args.max_grad_norm)
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
                # Check whether the accelerator has performed an optimization step behind the scenes
if self.accelerator.sync_gradients:
if args.use_ema:
self.ema_unet.step(self.unet.parameters())
progress_bar.update(1)
global_step += 1
self.accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if self.accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
self.accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if self.accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
self.ema_unet.store(self.unet.parameters())
self.ema_unet.copy_to(self.unet.parameters())
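                    # NOTE: validation sampling/logging with `args.validation_prompts` would run
                    # here; as written, this block only swaps the EMA weights in and back out.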
if args.use_ema:
# Switch back to the original UNet parameters.
self.ema_unet.restore(self.unet.parameters())
# Create the pipeline using the trained modules and save it.
self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process:
self.unet = self.accelerator.unwrap_model(self.unet)
if args.use_ema:
self.ema_unet.copy_to(self.unet.parameters())
pipeline = StableDiffusionPipeline.from_pretrained(
self.pretrained_model_name_or_path,
text_encoder=self.text_encoder,
vae=self.vae,
unet=self.unet,
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
upload_folder(
repo_id=self.repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
self.accelerator.end_training()
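

# --- Minimal usage sketch (not part of the original training script) ---
# This assumes a Stable Diffusion v1-style base checkpoint and a JSON data file providing the
# columns expected by `--image_column`/`--caption_column` (defaults "image"/"text"); it is
# loaded via `load_dataset("json", ...)` when `--dataset_name` is not given. The model id and
# data path below are illustrative placeholders, not values defined by this file.
if __name__ == "__main__":
    args = parse_args()
    trainer = Trainer(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # placeholder SD v1.x checkpoint
        train_data_dir="data/refl_data.json",  # placeholder prompt file; replace with your own
        args=args,
    )
    trainer.train(args)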