# Spaces:
# Runtime error
# Runtime error
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Binxin Yang ([email protected])
# --------------------------------------------------------
from __future__ import annotations | |
import json | |
import math | |
from pathlib import Path | |
from typing import Any | |
import numpy as np | |
import torch | |
import torchvision | |
from einops import rearrange | |
from PIL import Image | |
from torch.utils.data import Dataset | |
import cv2 | |
import os | |
import random | |
import copy | |
from glob import glob | |
class COCOStuffDataset(Dataset):
    """COCO-Stuff segmentation samples rendered as instruction-editing pairs.

    Every item pairs an RGB image with a copy in which the pixels of one
    randomly chosen class are blended with a randomly chosen colour, together
    with a text prompt that names that colour and class.

    Files read by this class:
        ``<path>/images/<split>/*.jpg``            input images
        ``<path>/annotations/<split>/<stem>.png``  label maps, same stem
        ``<path>/labels.txt``                      ``<int id>: <class name>``
        ``dataset/prompt/prompt_seg.txt``          one prompt template per
            line, formatted with ``color=`` and ``object=`` keywords
        ``dataset/prompt/color_list_train_small.txt``  space-separated
            fields; field 0 is the colour name, field 3 is ``"R,G,B"``

    The two ``dataset/prompt`` paths are relative to the current working
    directory.
    """

    # Annotation values 0..NUM_CLASSES-1 are candidate classes.
    NUM_CLASSES = 182
    # Annotation value that is never selected as a target class.
    IGNORE_LABEL = 255
    # Placeholder index used when an image has no usable label; it matches no
    # pixel, so the "edited" image stays identical to the input.
    _NO_LABEL_INDEX = 200
    VALID_SPLITS = ("train2017", "val2017")
    # Short names accepted for convenience (the constructor default is
    # "train", which is not itself a directory name).
    _SPLIT_ALIASES = {"train": "train2017", "val": "val2017"}

    def __init__(
        self,
        path: str,
        path_edit: str = "None",
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        crop_res: int = 256,
        flip_prob: float = 0.0,
        transparency: float = 0,
        batch_size: int = 10,
        empty_percentage: float = 0,
    ):
        """Index the images of one split and load prompt/colour/label tables.

        Args:
            path: Dataset root directory.
            path_edit: Stored on the instance; not otherwise used here.
            split: ``"train2017"`` or ``"val2017"``; ``"train"`` and ``"val"``
                are accepted as aliases for them.
            splits: Three fractions that must sum to 1; only validated.
            crop_res: Side length of the square output images.
            flip_prob: Probability of a horizontal flip of the output pair.
            transparency: Weight of the original pixel when blending the
                colour into the masked region (0 means solid colour).
            batch_size: Stored on the instance; not otherwise used here.
            empty_percentage: Probability of replacing the chosen class with
                one that does not occur in the image.

        Raises:
            ValueError: If ``split`` is unknown or ``splits`` does not sum
                to 1.
            FileNotFoundError: If the split directory contains no ``.jpg``.
        """
        super().__init__()

        split = self._SPLIT_ALIASES.get(split, split)
        if split not in self.VALID_SPLITS:
            raise ValueError("Invalid split name: {}".format(split))
        # Tolerant comparison: fractions such as 0.7/0.2/0.1 do not
        # necessarily add up to exactly 1.0 in floating point.
        if not math.isclose(sum(splits), 1.0):
            raise ValueError("splits must sum to 1, got {}".format(splits))

        self.split = split
        self.path = path
        self.path_edit = path_edit
        self.batch_size = batch_size
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.empty_percentage = empty_percentage
        self.transparency = transparency

        image_dir = os.path.join(self.path, "images", self.split)
        image_paths = sorted(glob(os.path.join(image_dir, "*.jpg")))
        if not image_paths:
            raise FileNotFoundError("{} has no image".format(image_dir))
        # Keep only the stem; basename/splitext is separator-agnostic and
        # removes just the final extension.
        self.files = [
            os.path.splitext(os.path.basename(p))[0] for p in image_paths
        ]

        seg_diverse_prompt_path = 'dataset/prompt/prompt_seg.txt'
        with open(seg_diverse_prompt_path) as f:
            self.seg_diverse_prompt_list = [line.strip('\n') for line in f]

        color_list_file_path = 'dataset/prompt/color_list_train_small.txt'
        self.color_list = []
        with open(color_list_file_path) as f:
            for line in f:
                fields = line.strip('\n').split(" ")
                # Lines with a single field (e.g. blank lines) are skipped.
                if len(fields) > 1:
                    self.color_list.append([fields[k] for k in range(4)])

        coco_label_list_path = os.path.join(self.path, 'labels.txt')
        self.label_dict = {}
        with open(coco_label_list_path) as f:
            for line in f:
                entry = line.strip('\n')
                if not entry.strip():
                    continue  # tolerate blank lines
                parts = entry.split(": ")
                self.label_dict[int(parts[0])] = parts[1]

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.files)

    @classmethod
    def _absent_class_ids(cls, present) -> list[int]:
        """Return the class ids in ``range(NUM_CLASSES)`` not in ``present``."""
        present_ids = {int(value) for value in present}
        return [c for c in range(cls.NUM_CLASSES) if c not in present_ids]

    @staticmethod
    def _to_model_tensor(array) -> torch.Tensor:
        """Map an ``(h, w, c)`` uint8 array to a ``(c, h, w)`` float tensor in [-1, 1]."""
        # np.array() copies, so a read-only source array is never handed to torch.
        scaled = 2 * torch.tensor(np.array(array)).float() / 255 - 1
        return rearrange(scaled, "h w c -> c h w")

    def _augmentation_new(self, image, label):
        """Square-crop ``image``/``label`` together and resize to ``crop_res``.

        The crop keeps the full shorter side and a random window along the
        longer side. The image is resized with LANCZOS and the label map with
        NEAREST so that label values are not interpolated.

        Returns:
            ``(image, label)`` as a uint8 and an int64 array respectively.
        """
        h, w = label.shape
        if h > w:
            start_h = random.randint(0, h - w)
            image = image[start_h:start_h + w]
            label = label[start_h:start_h + w]
        elif h < w:
            start_w = random.randint(0, w - h)
            image = image[:, start_w:start_w + h]
            label = label[:, start_w:start_w + h]

        size = (self.crop_res, self.crop_res)
        image = Image.fromarray(image).resize(size, resample=Image.Resampling.LANCZOS)
        image = np.asarray(image, dtype=np.uint8)
        label = Image.fromarray(label).resize(size, resample=Image.Resampling.NEAREST)
        label = np.asarray(label, dtype=np.int64)
        return image, label

    def __getitem__(self, i: int) -> dict[str, Any]:
        """Build the editing sample for the ``i``-th image.

        Returns:
            ``dict(edited=..., edit=dict(c_concat=..., c_crossattn=...))``
            where ``edited`` is the recoloured image tensor, ``c_concat`` the
            input image tensor (both ``(c, h, w)`` floats in [-1, 1]) and
            ``c_crossattn`` the prompt string.
        """
        image_id = self.files[i]
        img_path = os.path.join(self.path, "images", self.split, image_id + ".jpg")
        mask_path = os.path.join(self.path, "annotations", self.split, image_id + ".png")

        with Image.open(mask_path) as label_file:
            label = np.asarray(label_file.convert("L"))
        with Image.open(img_path) as image_file:
            image = np.asarray(image_file.convert("RGB"))
        image, label = self._augmentation_new(image, label)

        label_list = list(np.unique(label))
        # Computed by filtering rather than by removing items from a list
        # while iterating over it, which would skip elements.
        label_list_rest = self._absent_class_ids(label_list)
        if self.IGNORE_LABEL in label_list:
            label_list.remove(self.IGNORE_LABEL)

        has_label = len(label_list) != 0
        if has_label:
            label_idx = random.choice(label_list)
            if random.uniform(0, 1) < self.empty_percentage and label_list_rest:
                # Ask for a class that is not in the image: the mask is then
                # empty and the target equals the input.
                label_idx = random.choice(label_list_rest)
            # The name lookup uses the annotation value plus one.
            class_name = self.label_dict[int(label_idx) + 1]
            prompt = random.choice(self.seg_diverse_prompt_list)
            color = random.choice(self.color_list)
            color_name = color[0]
            prompt = prompt.format(color=color_name.lower(), object=class_name.lower())
            red, green, blue = (int(component) for component in color[3].split(","))
        else:
            label_idx = self._NO_LABEL_INDEX
            prompt = "leave the picture as it is."

        mask = (label == label_idx)

        image_1 = image.copy()
        if has_label:
            keep = self.transparency
            for channel, value in enumerate((red, green, blue)):
                plane = image_1[:, :, channel]  # view: writes land in image_1
                plane[mask] = keep * plane[mask] + (1 - keep) * value

        image_0 = self._to_model_tensor(image)
        image_1 = self._to_model_tensor(image_1)

        # Both images are stacked so that they receive the same random flip.
        # The inputs are already crop_res x crop_res, so the crop selects the
        # whole frame.
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)

        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))