|
from datasets import load_dataset |
|
from diffusers import StableDiffusionPipeline |
|
import torch |
|
import os |
|
import json |
|
from PIL import Image |
|
import argparse |
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for prompt-based image generation.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            existing command-line usage is unchanged.

    Returns:
        argparse.Namespace with ``output_dir``, ``model_path``, ``seed``,
        ``prompt``, ``mode`` and ``num_train_images``.

    Raises:
        SystemExit: on invalid or missing arguments (argparse behaviour).
    """
    parser = argparse.ArgumentParser(description="Generate images from I2P dataset")

    # output_dir and prompt are used unconditionally by the generation code
    # (directory creation and the pipeline call), so require them up front
    # instead of failing only after the model has been loaded.
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint", default="CompVis/stable-diffusion-v1-4")
    parser.add_argument("--seed", type=int, help="Seed for random number generator", default=0)
    parser.add_argument("--prompt", type=str, required=True, help="Prompt for image generation")
    parser.add_argument("--mode", type=str, help="Mode for image generation", choices=["train", "test"], default="train")
    # The same count drives both modes, despite the historical option name.
    parser.add_argument("--num_train_images", type=int, help="Number of images to generate (used in both train and test mode)", default=1000)

    return parser.parse_args(argv)
|
|
|
def _sanitize_prompt(prompt):
    """Return *prompt* with path separators replaced by '_' for use as a file-name stem."""
    # A '/' in the prompt would otherwise be interpreted as a (nonexistent)
    # sub-directory by os.path.join and make image.save fail.
    for sep in ("/", os.sep, os.altsep):
        if sep:
            prompt = prompt.replace(sep, "_")
    return prompt


def _generate_image(pipe, prompt, generator):
    """Run *pipe* on *prompt* until it returns an image not flagged as NSFW.

    ``nsfw_content_detected`` may be a list (one flag per image) or a falsy
    value; only the first image's flag is inspected. A falsy flag (including
    ``None``) accepts the image. NOTE(review): the pipeline below is loaded
    with ``safety_checker=None``, in which case diffusers presumably reports
    ``None`` here, so the first image is always accepted -- confirm if the
    retry is meant to do anything.
    """
    while True:
        output = pipe(prompt=prompt, generator=generator)
        nsfw = output.nsfw_content_detected
        if isinstance(nsfw, list):
            nsfw = nsfw[0]
        if not nsfw:
            return output.images[0]


def _generate_and_save(pipe, prompt, num_images, output_dir, subdir, generator):
    """Generate *num_images* images for *prompt* and save them as PNG files.

    Files are written to ``output_dir/subdir`` (or directly to *output_dir*
    when *subdir* is empty) as ``<prompt>_<i>.png``. Returns the saved
    file names relative to *output_dir* ('/'-separated), in generation order.
    """
    stem = _sanitize_prompt(prompt)
    image_dir = os.path.join(output_dir, subdir) if subdir else output_dir
    os.makedirs(image_dir, exist_ok=True)

    relative_names = []
    for i in range(num_images):
        file_name = f"{stem}_{i}.png"
        image = _generate_image(pipe, prompt, generator)
        image.save(os.path.join(image_dir, file_name))
        relative_names.append(f"{subdir}/{file_name}" if subdir else file_name)
    return relative_names


if __name__ == "__main__":
    args = parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only dependable on CUDA; many CPU kernels have no
    # float16 implementation, so use float32 when running on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32

    os.makedirs(args.output_dir, exist_ok=True)
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path, safety_checker=None, torch_dtype=dtype
    ).to(device)

    # --seed used to be parsed but ignored. Seed one generator and reuse it
    # for every call so the whole run is reproducible.
    generator = torch.Generator(device=device).manual_seed(args.seed)

    if args.mode == "test":
        # Test mode: images go straight into output_dir, no metadata file.
        _generate_and_save(pipe, args.prompt, args.num_train_images, args.output_dir, "", generator)
    elif args.mode == "train":
        # Train mode: images go into output_dir/train and each one gets a
        # JSON-lines metadata record pairing its relative path with the prompt.
        file_names = _generate_and_save(
            pipe, args.prompt, args.num_train_images, args.output_dir, "train", generator
        )
        metadata = [{"file_name": name, "text": args.prompt} for name in file_names]
        with open(os.path.join(args.output_dir, "metadata.jsonl"), "w", encoding="utf-8") as f:
            f.writelines(json.dumps(m) + "\n" for m in metadata)
|
|
|
|
|
|