# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Tiankai Hang ([email protected])
# --------------------------------------------------------
from __future__ import annotations

import os
import json
import math
import random
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torchvision
from einops import rearrange
import PIL
from PIL import Image
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from dataset.utils.zip_manager import MultipleZipManager

if hasattr(Image, "Resampling"):
    # Pillow >= 9.1: module-level filters such as Image.LANCZOS are
    # deprecated (and removed in Pillow 10); prefer the Resampling enum.
    RESAMPLING_METHOD = Image.Resampling.LANCZOS
else:
    RESAMPLING_METHOD = Image.LANCZOS
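

# Note: every dataset class below shares the same preprocessing recipe:
# resize the source/edited pair to a random square resolution, rescale pixels
# from [0, 255] to [-1, 1], convert HWC -> CHW, then crop and flip the two
# images *jointly* so they stay spatially aligned. The helper below is a
# minimal illustrative sketch of that joint step; the classes inline it
# themselves rather than calling this function.
def _paired_crop_flip(
    image_0: torch.Tensor,
    image_1: torch.Tensor,
    crop_res: int,
    flip_prob: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Stack along the channel dim so both images receive identical crop
    # coordinates and the same flip decision, then split back into two.
    crop = torchvision.transforms.RandomCrop(crop_res)
    flip = torchvision.transforms.RandomHorizontalFlip(float(flip_prob))
    return flip(crop(torch.cat((image_0, image_1)))).chunk(2)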


class FilteredIP2PDataset(Dataset):
    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
        zip_start_index: int = 0,
        zip_end_index: int = 30,
        instruct: bool = False,
        max_num_images: int | None = None,
        sample_weight: float = 1.0,
        reverse_version: bool = False,
        **kwargs,
    ):
        assert split in ("train", "val", "test")
        # use isclose rather than == so float splits like (0.9, 0.05, 0.05)
        # are not rejected by rounding error
        assert math.isclose(sum(splits), 1.0)
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.instruct = instruct

        zip_list = []
        for i in range(zip_start_index, zip_end_index):
            name = f"shard-{i:02d}.zip"
            zip_list.append(os.path.join(self.path, name))
        self.image_dataset = MultipleZipManager(zip_list, 'image', sync=True)  # sync=True is faster
        with open(Path(self.path, "seeds.json")) as f:
            self.seeds = json.load(f)

        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]

        idx_0 = math.floor(split_0 * len(self.seeds))
        idx_1 = math.floor(split_1 * len(self.seeds))
        self.seeds = self.seeds[idx_0:idx_1]
        if max_num_images is not None and max_num_images > 0:
            self.seeds = self.seeds[:min(max_num_images, len(self.seeds))]

        # flatten seeds
        self.seeds = [(name, seed) for name, seeds in self.seeds for seed in seeds]
        self.sample_weight = sample_weight
        # the list of filtered sample ids is downloaded once and cached locally
        while True:
            try:
                with open('filtered_ids_ip2p.json') as json_file:
                    filtered_ids = json.load(json_file)
                break
            except FileNotFoundError:
                # download the json file listing the ids that passed filtering
                if reverse_version:
                    os.system('wget https://github.com/TiankaiHang/storage/releases/download/readout/filtered_ids_ip2p.json')
                else:
                    os.system("wget https://github.com/TiankaiHang/storage/releases/download/readout/filtered-ip2p-thres5.5-0.5.json -O filtered_ids_ip2p.json")
        print("seeds:", len(self.seeds))

        # keep only the seeds whose "name/seed" id passed the filtering;
        # np.isin is much faster than a Python-level membership test here
        _seeds = [f"{a}/{b}" for a, b in self.seeds]
        self.seeds = np.array(self.seeds)
        _seeds = np.array(_seeds)
        self.seeds = self.seeds[np.isin(_seeds, filtered_ids)]
        self.seeds = self.seeds.tolist()

        self.return_add_kwargs = kwargs.get("return_add_kwargs", False)

    def __len__(self) -> int:
        # sample_weight > 1 oversamples the dataset; < 1 subsamples it
        return int(len(self.seeds) * self.sample_weight)

    def __getitem__(self, i: int) -> dict[str, Any]:
        if self.sample_weight >= 1:
            # oversampling: wrap the index around the real dataset length
            i = i % len(self.seeds)
        else:
            # subsampling: map i to a random index within its stride window
            remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
            i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
        name, seed = self.seeds[i]

        prompt_name = name + "/prompt.json"
        if not self.image_dataset.managers[self.image_dataset.mapping[prompt_name]]._init:
            self.image_dataset.managers[self.image_dataset.mapping[prompt_name]].initialize(close=False)
        byteflow = self.image_dataset.managers[self.image_dataset.mapping[prompt_name]].zip_fd.read(prompt_name)
        texts = json.loads(byteflow.decode('utf-8'))
        prompt = texts["edit"]
        if self.instruct:
            prompt = "Image Editing: " + prompt
        text_input = texts["input"]
        text_output = texts["output"]

        image_0 = self.image_dataset.get(name + f"/{seed}_0.jpg")
        image_1 = self.image_dataset.get(name + f"/{seed}_1.jpg")

        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), RESAMPLING_METHOD)
        image_1 = image_1.resize((resize_res, resize_res), RESAMPLING_METHOD)

        # scale to [-1, 1] and convert HWC -> CHW
        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")

        # crop and flip jointly so the source/edited pair stays aligned
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
        if self.return_add_kwargs:
            add_kwargs = dict(
                name=name,
                seed=seed,
                text_input=text_input,
                text_output=text_output,
            )
        else:
            add_kwargs = {}
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt), **add_kwargs)


class GIERDataset(Dataset):
    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
        zip_start_index: int = 0,
        zip_end_index: int = 30,
        sample_weight: float = 1.0,
        instruct: bool = False,
    ):
        assert split in ("train", "val", "test")
        assert math.isclose(sum(splits), 1.0)
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.instruct = instruct

        # load the meta info from the json file
        with open(Path(self.path, "GIER_new.json")) as json_file:
            self.meta = json.load(json_file)
        print(f"Loaded {len(self.meta)} images from json file")

        input_does_not_exist = []
        output_does_not_exist = []
        # filter out images whose input or output file does not exist,
        # caching the result so the scan only runs once
        if not os.path.exists(os.path.join(self.path, "filtered_meta_new.pt")):
            filtered_meta = []
            for i in tqdm(range(len(self.meta))):
                input_path = os.path.join(self.path, "warped", self.meta[i]["input"])
                output_path = os.path.join(self.path, "warped", self.meta[i]["output"])
                if not os.path.exists(input_path):
                    input_path = os.path.join(self.path, "images", self.meta[i]["input"])
                    if not os.path.exists(input_path):
                        input_does_not_exist.append(input_path)
                if not os.path.exists(output_path):
                    output_path = os.path.join(self.path, "images", self.meta[i]["output"])
                    if not os.path.exists(output_path):
                        output_does_not_exist.append(output_path)
                if os.path.exists(input_path) and os.path.exists(output_path):
                    filtered_meta.append(
                        dict(
                            input=input_path,
                            output=output_path,
                            prompts=self.meta[i]["prompts"],
                        )
                    )
                else:
                    print(f"\n {input_path} or {output_path} does not exist")
            torch.save(filtered_meta, os.path.join(self.path, "filtered_meta_new.pt"))
        else:
            filtered_meta = torch.load(os.path.join(self.path, "filtered_meta_new.pt"), map_location="cpu")
        self.meta = filtered_meta
        print(f"Filtered down to {len(self.meta)} images")
        # remap absolute paths from the original storage layout onto self.path
        for i in range(len(self.meta)):
            self.meta[i]['input'] = self.meta[i]['input'].replace('/mnt/external/datasets/GIER_editing_data/', self.path)
            self.meta[i]['output'] = self.meta[i]['output'].replace('/mnt/external/datasets/GIER_editing_data/', self.path)

        # record any missing files for later inspection
        with open(Path(self.path, "input_does_not_exist.txt"), "w") as f:
            for item in input_does_not_exist:
                f.write("%s\n" % item)
        with open(Path(self.path, "output_does_not_exist.txt"), "w") as f:
            for item in output_does_not_exist:
                f.write("%s\n" % item)
        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]
        idx_0 = math.floor(split_0 * len(self.meta))
        idx_1 = math.floor(split_1 * len(self.meta))
        self.meta = self.meta[idx_0:idx_1]
        self.sample_weight = sample_weight
        print('original GIER', len(self.meta))

    def __len__(self) -> int:
        return int(len(self.meta) * self.sample_weight)

    def __getitem__(self, i: int) -> dict[str, Any]:
        if self.sample_weight >= 1:
            i = i % len(self.meta)
        else:
            i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
        # each sample carries several candidate prompts; pick one at random
        prompt = random.choice(self.meta[i]["prompts"])
        try:
            image_0 = Image.open(self.meta[i]["input"]).convert("RGB")
            image_1 = Image.open(self.meta[i]["output"]).convert("RGB")
        except PIL.UnidentifiedImageError:
            print(f"\n {self.meta[i]['input']} or {self.meta[i]['output']} is not a valid image")
            # fall back to a random valid sample
            i = random.randint(0, len(self.meta) - 1)
            return self.__getitem__(i)
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), RESAMPLING_METHOD)
        image_1 = image_1.resize((resize_res, resize_res), RESAMPLING_METHOD)

        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)

        if self.instruct:
            prompt = "Image Editing: " + prompt
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))


class GQAInpaintDataset(Dataset):
    r"""
    Should download and unzip the data first:
    ```
    mkdir -p ../datasets
    cd ../datasets

    # if file exists, then skip
    if [ ! -f "gqa-inpaint.zip" ]; then
        sudo azcopy copy "https://bingdatawu2.blob.core.windows.net/genrecog/private/t-thang/gqa-inpaint.zip${TOKEN}" .
        unzip gqa-inpaint.zip -d gqa-inpaint > /dev/null
    fi

    if [ ! -f "images.zip" ]; then
        sudo azcopy copy "https://bingdatawu2.blob.core.windows.net/genrecog/private/t-thang/images.zip${TOKEN}" .
        unzip images.zip > /dev/null
    fi
    ```
    """
    def __init__(self, **kwargs):
        self.path = kwargs.get("path", "../datasets/gqa-inpaint")
        self.instruct = kwargs.get("instruct", False)
        # load the meta info, e.g. ../datasets/gqa-inpaint/meta_info.json
        with open(self.path + "/meta_info.json", "r") as f:
            self.meta_info = json.load(f)
        self.min_resize_res = kwargs.get("min_resize_res", 256)
        self.max_resize_res = kwargs.get("max_resize_res", 256)
        self.crop_res = kwargs.get("crop_res", 256)
        self.flip_prob = kwargs.get("flip_prob", 0.5)

    def __len__(self):
        return len(self.meta_info)

    def __getitem__(self, i):
        item = self.meta_info[i]
        # remap the stored paths onto self.path
        src_img = Image.open(item["source_image_path"].replace("../datasets", self.path)).convert("RGB")
        tgt_img = Image.open(item["target_image_path"].replace("../datasets/gqa-inpaint", self.path)).convert("RGB")

        image_0 = src_img
        image_1 = tgt_img
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), RESAMPLING_METHOD)
        image_1 = image_1.resize((resize_res, resize_res), RESAMPLING_METHOD)

        instruction = item["instruction"]
        if self.instruct:
            instruction = "Image Editing: " + instruction

        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=instruction))


class MagicBrushDataset(Dataset):
    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
        zip_start_index: int = 0,
        zip_end_index: int = 30,
        len_dataset: int = -1,
        instruct: bool = False,
        sample_weight: float = 1.0,
    ):
        assert split in ("train", "val", "test")
        assert math.isclose(sum(splits), 1.0)
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.instruct = instruct
        self.sample_weight = sample_weight

        self.meta_path = os.path.join(self.path, "magic_train.json")
        with open(self.meta_path, "r") as f:
            self.meta = json.load(f)

    def __len__(self) -> int:
        return int(len(self.meta) * self.sample_weight)

    def __getitem__(self, i: int) -> dict[str, Any]:
        if self.sample_weight >= 1:
            i = i % len(self.meta)
        else:
            i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
        item = self.meta[i]
        try:
            image_0 = Image.open(os.path.join(self.path, item["input"])).convert("RGB")
            image_1 = Image.open(os.path.join(self.path, item["edited"])).convert("RGB")
        except (PIL.UnidentifiedImageError, FileNotFoundError):
            print(f"\n {self.path}/{item['input']} or {self.path}/{item['edited']} is not a valid image")
            i = random.randint(0, len(self.meta) - 1)
            return self.__getitem__(i)
        prompt = item["instruction"]
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), RESAMPLING_METHOD)
        image_1 = image_1.resize((resize_res, resize_res), RESAMPLING_METHOD)

        if self.instruct:
            prompt = "Image Editing: " + prompt

        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))


class IEIWDataset(Dataset):
    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
        zip_start_index: int = 0,
        zip_end_index: int = 30,
        sample_weight: float = 1.0,
        instruct: bool = False,
    ):
        assert split in ("train", "val", "test")
        assert math.isclose(sum(splits), 1.0)
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.instruct = instruct

        self.meta_path = os.path.join(self.path, "meta_infov1.json")
        with open(self.meta_path, "r") as f:
            self.meta = json.load(f)
        self.sample_weight = sample_weight
        print('original synthetic', len(self.meta))

    def __len__(self) -> int:
        return int(len(self.meta) * self.sample_weight)

    def __getitem__(self, i: int) -> dict[str, Any]:
        if self.sample_weight >= 1:
            i = i % len(self.meta)
        else:
            i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
        item = self.meta[i]
        # remap absolute paths from the original environment onto self.path
        item['input'] = item['input'].replace('/mnt/external/tmp/2023/06/11/', self.path)
        item['edited'] = item['edited'].replace('/mnt/external/tmp/2023/06/11/', self.path)
        try:
            image_0 = Image.open(item["input"]).convert("RGB")
            image_1 = Image.open(item["edited"]).convert("RGB")
        except (PIL.UnidentifiedImageError, FileNotFoundError):
            print(f"\n {item['input']} or {item['edited']} is not a valid image")
            i = random.randint(0, len(self.meta) - 1)
            return self.__getitem__(i)
        prompt = item["instruction"]
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), RESAMPLING_METHOD)
        image_1 = image_1.resize((resize_res, resize_res), RESAMPLING_METHOD)

        if self.instruct:
            prompt = "Image Editing: " + prompt

        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
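

if __name__ == "__main__":
    # Minimal smoke test; a sketch only, assuming a local MagicBrush copy at
    # the (hypothetical) path below with the magic_train.json layout expected
    # by MagicBrushDataset. Each sample is a dict with the edited target image
    # under "edited" and the conditioning inputs under "edit": the source
    # image ("c_concat") and the instruction text ("c_crossattn").
    dataset = MagicBrushDataset(path="../datasets/magicbrush")
    sample = dataset[0]
    print(sample["edit"]["c_crossattn"])     # instruction string
    print(sample["edited"].shape)            # e.g. torch.Size([3, 256, 256])
    print(sample["edit"]["c_concat"].shape)  # same shape as the edited image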