# Author: kevinlu4588
# Commit e225ed6: Training gradient ascent models and sanity checking car erasure
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
import torch
import os
import json
from PIL import Image
import argparse
def parse_args(argv=None):
    """Parse command-line options for prompt-based image generation.

    Args:
        argv: Argument list to parse, excluding the program name. When None
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``output_dir``, ``model_path``, ``seed``,
        ``prompt``, ``mode`` and ``num_train_images``.
    """
    parser = argparse.ArgumentParser(description="Generate images from I2P dataset")
    # output_dir and prompt have no usable default: without them the script
    # would only fail later inside os.makedirs / the pipeline call, so argparse
    # rejects a missing value up front with a clear usage message.
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint", default="CompVis/stable-diffusion-v1-4")
    parser.add_argument("--seed", type=int, help="Seed for random number generator", default=0)
    parser.add_argument("--prompt", type=str, required=True, help="Prompt for image generation")
    parser.add_argument("--mode", type=str, help="Mode for image generation", choices=["train", "test"], default="train")
    # Used as the image count in both "train" and "test" modes.
    parser.add_argument("--num_train_images", type=int, help="Number of images to generate for training", default=1000)
    return parser.parse_args(argv)
def _build_pipeline(model_path, device):
    """Load a StableDiffusionPipeline from ``model_path`` with no safety checker.

    Half precision is used only on CUDA; many float16 ops are not implemented
    on CPU, so the CPU fallback loads float32 weights instead.
    """
    dtype = torch.float16 if device == "cuda" else torch.float32
    return StableDiffusionPipeline.from_pretrained(
        model_path, safety_checker=None, torch_dtype=dtype
    ).to(device)


def _file_stem(prompt, index):
    """Return ``"<prompt>_<index>"`` with path separators replaced by ``_``.

    Without this, a prompt containing "/" would be treated as a subdirectory
    that does not exist and the image save would fail. Prompts without
    separators produce exactly the same name as before.
    """
    stem = prompt
    for sep in (os.sep, os.altsep):
        if sep:
            stem = stem.replace(sep, "_")
    return f"{stem}_{index}"


def _generate_one(pipe, prompt, generator):
    """Generate one image for ``prompt``, regenerating while it is flagged NSFW.

    NOTE(review): the pipeline is loaded with safety_checker=None, in which
    case nsfw_content_detected is presumably None and the first image is
    accepted. If a safety checker is ever enabled, this loop has no retry cap
    and only ends once an unflagged image is produced.
    """
    while True:
        output = pipe(prompt=prompt, generator=generator)
        nsfw = output.nsfw_content_detected
        # The flag may come back as a per-image list; only one image is requested.
        if isinstance(nsfw, list):
            nsfw = nsfw[0]
        if not nsfw:
            return output.images[0]


def main():
    """Generate ``--num_train_images`` images for ``--prompt``.

    In "test" mode the PNGs are written directly into ``--output_dir``.
    In "train" mode they are written to ``<output_dir>/train`` and a
    ``metadata.jsonl`` file is written to ``--output_dir``. Each line of that
    file maps a ``file_name`` (relative to output_dir) to its prompt ``text``.
    """
    args = parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train = args.mode == "train"

    os.makedirs(args.output_dir, exist_ok=True)
    image_dir = os.path.join(args.output_dir, "train") if train else args.output_dir
    os.makedirs(image_dir, exist_ok=True)

    pipe = _build_pipeline(args.model_path, device)
    # --seed used to be parsed but ignored. One generator seeded once makes
    # the whole run reproducible, and because its state advances on every
    # call, each image still gets a distinct draw.
    generator = torch.Generator(device=device).manual_seed(args.seed)

    metadata = []
    for i in range(args.num_train_images):
        image = _generate_one(pipe, args.prompt, generator)
        file_name = _file_stem(args.prompt, i) + ".png"
        image.save(os.path.join(image_dir, file_name))
        if train:
            metadata.append({"file_name": f"train/{file_name}", "text": args.prompt})

    if train:
        with open(os.path.join(args.output_dir, 'metadata.jsonl'), 'w') as f:
            for m in metadata:
                f.write(json.dumps(m) + "\n")


if __name__ == "__main__":
    main()