from packaging import version |
from PIL import Image |
from torchvision import transforms |
import os |
import PIL |
from torch.utils.data import Dataset |
import torchvision |
import numpy as np |
import torch |
import random |
import albumentations as A |
import copy |
import cv2 |
import pandas as pd |
# Caption templates; "{}" is filled with the placeholder token or an object
# category name. The order is kept stable because entries are drawn with
# random.choice, so reordering would change seeded runs.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]
# Map interpolation names to Pillow resampling filters. From Pillow 9.1.0 on
# the filters are read from ``PIL.Image.Resampling``; older releases only
# expose them as module-level constants on ``PIL.Image``. The dataset classes
# below look filters up through this table.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
def is_image(file):
    """Return True when *file* ends with a JPEG or PNG extension.

    The comparison is case-insensitive. Only the extension is inspected, so a
    name that merely contains "png" or "jpg" somewhere (for example
    ``png_notes.txt``) is not treated as an image.
    """
    return file.lower().endswith((".jpg", ".jpeg", ".png"))
class CustomDatasetWithBG(Dataset):
    """Dataset of images in one folder, each paired with a ``<stem>_bg.png`` mask.

    Every sample is a dict with the formatted prompt (``text``), its token ids
    (``input_ids``), the 1-based word position of the placeholder (``index``),
    the resized image scaled with ``x / 127.5 - 1`` (``pixel_values``), two
    224x224 normalized views (``pixel_values_clip`` for the whole image,
    ``pixel_values_obj`` for the image multiplied by the mask) and the mask
    resampled to 128x128 (``pixel_values_seg``).
    """

    def __init__(
            self,
            data_root,
            tokenizer,
            size=512,
            interpolation="bicubic",
            placeholder_token="*",
            template="a photo of a {}",
    ):
        """Index the images under *data_root*.

        Args:
            data_root: Directory holding the images and their masks.
            tokenizer: Callable tokenizer exposing ``model_max_length``; its
                result must provide ``input_ids``.
            size: Side length (pixels) of the square ``pixel_values`` output.
            interpolation: One of "linear", "bilinear", "bicubic", "lanczos".
            placeholder_token: String substituted into *template*.
            template: Prompt with a single ``{}`` slot.

        Raises:
            KeyError: If *interpolation* is not one of the supported names.
        """
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token

        # Masks live next to the images, so every file whose name contains
        # "bg" is skipped when collecting the images themselves.
        self.image_paths = sorted(
            os.path.join(self.data_root, file_name)
            for file_name in os.listdir(self.data_root)
            if is_image(file_name) and 'bg' not in file_name
        )

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.template = template

    def __len__(self):
        """Return the number of indexed images."""
        return self._length

    @staticmethod
    def _mask_path(image_path):
        """Return the path of the ``_bg.png`` mask stored beside *image_path*."""
        # splitext only touches the final extension, so mixed-case extensions
        # and dots inside directory names are handled correctly.
        return os.path.splitext(image_path)[0] + '_bg.png'

    def get_tensor_clip(self, normalize=True, toTensor=True):
        """Build the transform used for the 224x224 reference views.

        Args:
            normalize: Append a ``Normalize`` step with the fixed per-channel
                mean/std below (presumably CLIP image statistics, judging by
                the method name).
            toTensor: Prepend a ``ToTensor`` step.

        Returns:
            A ``torchvision.transforms.Compose`` of the selected steps.
        """
        transform_list = []
        if toTensor:
            transform_list.append(torchvision.transforms.ToTensor())
        if normalize:
            transform_list.append(torchvision.transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)))
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        """Resize an HxWxC array to ``size`` x ``size`` and return it as CxHxW.

        Values are converted to float32 and mapped with ``x / 127.5 - 1``.
        """
        img = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = np.array(img).astype(np.float32)
        img = img / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def __getitem__(self, i):
        """Return the sample dict for index *i* (wrapped modulo the image count)."""
        example = {}

        placeholder_string = self.placeholder_token
        text = self.template.format(placeholder_string)
        example["text"] = text

        # 1-based position of the last word equal to the placeholder; stays 0
        # when the placeholder does not appear as a standalone word.
        placeholder_index = 0
        for idx, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = idx + 1
        example["index"] = torch.tensor(placeholder_index)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        image_path = self.image_paths[i % self.num_images]
        # Context managers make sure the file handles are released promptly.
        with Image.open(image_path) as image:
            if not image.mode == "RGB":
                image = image.convert("RGB")
            image_np = np.array(image)
        with Image.open(self._mask_path(image_path)) as mask_image:
            mask = np.array(mask_image)

        # Binarize the mask: any non-zero value counts as foreground.
        mask = np.where(mask > 0, 1, 0)
        if mask.ndim == 2:
            # A single-channel mask cannot broadcast against an HxWx3 image;
            # replicate it so it behaves like a three-channel mask.
            mask = np.repeat(mask[:, :, None], 3, axis=2)

        object_tensor = image_np * mask
        example["pixel_values"] = self.process(image_np)

        clip_transform = self.get_tensor_clip()
        ref_object_tensor = Image.fromarray(object_tensor.astype('uint8')).resize(
            (224, 224), resample=self.interpolation)
        ref_image_tenser = Image.fromarray(image_np.astype('uint8')).resize(
            (224, 224), resample=self.interpolation)
        example["pixel_values_obj"] = clip_transform(ref_object_tensor)
        example["pixel_values_clip"] = clip_transform(ref_image_tenser)

        ref_seg_tensor = Image.fromarray(mask.astype('uint8') * 255)
        ref_seg_tensor = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
        # Nearest-neighbour resampling keeps the mask binary.
        example["pixel_values_seg"] = torch.nn.functional.interpolate(
            ref_seg_tensor.unsqueeze(0), size=(128, 128), mode='nearest').squeeze(0)

        return example
class OpenImagesDataset(Dataset):
    """Dataset of bounding-box crops described by an annotation CSV.

    Boxes are read from ``<data_root>/annotations/*-annotations-bbox.csv`` and
    images from ``<data_root>/images/<set>/<ImageID>.jpg``. Each sample holds a
    randomly templated prompt (``text``, ``input_ids``, ``index``), the padded
    crop resized to ``size`` (``pixel_values``) and an augmented 224x224
    normalized view of the same crop (``pixel_values_clip``).
    """

    def __init__(
            self,
            data_root,
            tokenizer,
            size=512,
            interpolation="bicubic",
            set="train",
            placeholder_token="*",
    ):
        """Load and filter the bounding-box annotations.

        Args:
            data_root: Root folder containing ``annotations`` and ``images``.
            tokenizer: Callable tokenizer exposing ``model_max_length``; its
                result must provide ``input_ids``.
            size: Side length (pixels) of the square ``pixel_values`` output.
            interpolation: One of "linear", "bilinear", "bicubic", "lanczos".
            set: "train" or "validation"; any other value selects the test
                annotation file.
            placeholder_token: String inserted into the prompt templates.

        Raises:
            KeyError: If *interpolation* is not one of the supported names.
        """
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.set_type = set

        self.random_trans = A.Compose([
            A.Resize(height=224, width=224),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=20),
            A.Blur(p=0.3),
            A.ElasticTransform(p=0.3)
        ])

        self.bbox_path_list = []
        if set == "train":
            bboxs_path = os.path.join(data_root, 'annotations', 'oidv6-train-annotations-bbox.csv')
        elif set == "validation":
            bboxs_path = os.path.join(data_root, 'annotations', 'validation-annotations-bbox.csv')
        else:
            bboxs_path = os.path.join(data_root, 'annotations', 'test-annotations-bbox.csv')

        df_val_bbox = pd.read_csv(bboxs_path)
        used_columns = ['XMin', 'XMax', 'YMin', 'YMax', 'LabelName', 'ImageID']

        bbox_full = []
        # sort=False yields the labels in order of first appearance, which
        # keeps the sample order grouped by label as the file lists them.
        for _, label_rows in df_val_bbox.groupby('LabelName', sort=False):
            for x_min, x_max, y_min, y_max, label, image_id in label_rows[used_columns].values.tolist():
                # Box coordinates are used as fractions of the image size
                # (they are multiplied by width/height in __getitem__), so
                # this is the fraction of the image the box covers.
                area = (x_max - x_min) * (y_max - y_min)
                if area > 0.8 or area < 0.02:
                    continue
                bbox_full.append([x_min, x_max, y_min, y_max, label, image_id])

        self.bboxs_full = bbox_full
        self.num_images = len(bbox_full)

        print('{}: total {} images ...'.format(set, self.num_images))

        self._length = self.num_images

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_templates_small

    def __len__(self):
        """Return the number of retained bounding boxes."""
        return self._length

    def get_tensor_clip(self, normalize=True, toTensor=True):
        """Build the transform used for the 224x224 reference views.

        Args:
            normalize: Append a ``Normalize`` step with the fixed per-channel
                mean/std below (presumably CLIP image statistics, judging by
                the method name).
            toTensor: Prepend a ``ToTensor`` step.

        Returns:
            A ``torchvision.transforms.Compose`` of the selected steps.
        """
        transform_list = []
        if toTensor:
            transform_list.append(torchvision.transforms.ToTensor())
        if normalize:
            transform_list.append(torchvision.transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)))
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        """Resize an HxWxC image to ``size`` x ``size`` and return it as CxHxW.

        Values are converted to float32 and mapped with ``x / 127.5 - 1``.
        """
        img = np.array(image)
        img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = np.array(img).astype(np.float32)
        img = img / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def obtain_text(self, add_caption, object_category=None):
        """Create a prompt from a random template and tokenize it.

        The template's first word (its leading article, "a" or "the") is
        replaced by *add_caption*.

        Args:
            add_caption: Text that opens the prompt, e.g. "a" or "a far away".
            object_category: Name to insert instead of the placeholder token.

        Returns:
            ``(input_ids, index, text)`` where ``index`` is a tensor with the
            1-based word position of the inserted name (0 when it does not
            occur as a single word) and ``text`` is the final prompt.
        """
        if object_category is None:
            placeholder_string = self.placeholder_token
        else:
            placeholder_string = object_category

        template_text = random.choice(self.templates).format(placeholder_string)
        # Drop the whole first word rather than only its first character, so
        # "the photo ..." becomes "a photo ..." and not "ahe photo ...".
        _, separator, remainder = template_text.partition(' ')
        text = f"{add_caption}{separator}{remainder}"

        placeholder_index = 0
        for idx, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = idx + 1
        index = torch.tensor(placeholder_index)

        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]
        return input_ids, index, text

    def __getitem__(self, i):
        """Return the sample dict for index *i* (wrapped modulo the box count).

        Any failure while loading or cropping the image is appended to
        ``error.txt`` and replaced by all-zero image tensors, so one bad
        sample does not stop iteration.
        """
        example = {}

        input_ids, index, text = self.obtain_text('a')
        example["input_ids"] = input_ids
        example["index"] = index
        example["text"] = text

        # Copy so that scaling the coordinates does not alter the stored row.
        bbox_sample = copy.copy(self.bboxs_full[i % self.num_images])

        file_name = bbox_sample[-1] + '.jpg'
        img_path = os.path.join(self.data_root, 'images', self.set_type, file_name)

        try:
            with Image.open(img_path) as img_file:
                img_p_np = np.array(img_file.convert("RGB"))
            height, width = int(img_p_np.shape[0]), int(img_p_np.shape[1])

            # Scale the fractional coordinates to pixels.
            bbox_sample[0] *= width
            bbox_sample[1] *= width
            bbox_sample[2] *= height
            bbox_sample[3] *= height

            # Grow the box by up to 10 px per side without leaving the image.
            left = int(bbox_sample[0] - min(10, bbox_sample[0] - 0))
            right = int(bbox_sample[1] + min(10, width - bbox_sample[1]))
            top = int(bbox_sample[2] - min(10, bbox_sample[2] - 0))
            bottom = int(bbox_sample[3] + min(10, height - bbox_sample[3]))

            image_tensor = img_p_np[top:bottom, left:right, :]
            example["pixel_values"] = self.process(image_tensor)

            augmented = self.random_trans(image=image_tensor)
            ref_image_tensor = Image.fromarray(augmented["image"])
            example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tensor)
        except Exception as e:
            # Deliberate best-effort fallback. The placeholder follows
            # self.size so it matches the shape of successfully loaded samples.
            example["pixel_values"] = torch.zeros((3, self.size, self.size))
            example["pixel_values_clip"] = torch.zeros((3, 224, 224))
            with open('error.txt', 'a+') as f:
                f.write(str(e) + '\n')

        return example
class OpenImagesDatasetWithMask(OpenImagesDataset):
    """Dataset of object crops with segmentation masks.

    Records come from the ``*-annotations-object-segmentation.csv`` file and
    a pickled ``*_bbox_dict.npy`` dictionary keyed by mask file name. Each
    sample holds the augmented crop (``pixel_values``), its occupancy mask at
    1/8 resolution (``mask_values``), the prompt (``text``, ``input_ids``,
    ``index``) and three 224x224 views: the crop (``pixel_values_clip``), the
    masked object (``pixel_values_obj``) and the mask (``pixel_values_seg``).
    """

    def __init__(self,
                 data_root,
                 tokenizer,
                 size=512,
                 interpolation="bicubic",
                 set="train",
                 placeholder_token="*"):
        """Load and filter the segmentation annotations.

        The parent initializer is not called; this class reads different
        annotation files and builds its own record list.

        Args:
            data_root: Root folder containing ``annotations``, ``images`` and
                ``segs``.
            tokenizer: Callable tokenizer exposing ``model_max_length``; its
                result must provide ``input_ids``.
            size: Side length (pixels) of the square ``pixel_values`` output.
            interpolation: One of "linear", "bilinear", "bicubic", "lanczos".
            set: "train" or "validation"; any other value selects the test
                files.
            placeholder_token: String inserted into the prompt templates.

        Raises:
            KeyError: If *interpolation* is not supported, or a retained box
                has a label missing from the class-description file.
        """
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.set = set
        # Same value under the attribute name the parent class uses.
        self.set_type = set

        class_anno_path = os.path.join(data_root, 'annotations', 'oidv6-class-descriptions.csv')
        anno_files = pd.read_csv(class_anno_path)
        # Label -> display name, taking the first row for each label. Built
        # once instead of running a group lookup for every retained box.
        label_to_name = (
            anno_files.drop_duplicates(subset='LabelName')
            .set_index('LabelName')['DisplayName']
            .to_dict()
        )

        if set == "train":
            bboxs_path = os.path.join(data_root, 'annotations', 'train-annotations-object-segmentation.csv')
            dict_path = os.path.join(data_root, 'segs', 'train_bbox_dict.npy')
        elif set == "validation":
            bboxs_path = os.path.join(data_root, 'annotations', 'validation-annotations-object-segmentation.csv')
            dict_path = os.path.join(data_root, 'segs', 'validation_bbox_dict.npy')
        else:
            bboxs_path = os.path.join(data_root, 'annotations', 'test-annotations-object-segmentation.csv')
            dict_path = os.path.join(data_root, 'segs', 'test_bbox_dict.npy')

        # SECURITY NOTE: allow_pickle=True executes pickle data from the file;
        # only load dictionaries that come from a trusted source.
        bbox_dict = np.load(dict_path, allow_pickle=True).item()

        df_val_bbox = pd.read_csv(bboxs_path)

        bboxes_full = []
        # sort=False yields the labels in order of first appearance.
        for _, label_rows in df_val_bbox.groupby('LabelName', sort=False):
            for label, mask_name in label_rows[['LabelName', 'MaskPath']].values.tolist():
                if mask_name not in bbox_dict:
                    continue
                bbox_data = bbox_dict[mask_name]
                # Entries 1..4 are the crop bounds later used to slice the
                # image: [1]:[2] on the first axis, [3]:[4] on the second.
                extent_first = bbox_data[2] - bbox_data[1]
                extent_second = bbox_data[4] - bbox_data[3]
                if extent_first < 100 or extent_second < 100:
                    continue
                # Skip boxes where one side is less than half the other.
                if extent_first / extent_second < 0.5 or extent_second / extent_first < 0.5:
                    continue
                bboxes_full.append([mask_name, bbox_data[1], bbox_data[2], bbox_data[3], bbox_data[4],
                                    label_to_name[label]])

        self.bboxes_full = bboxes_full
        self.num_images = len(bboxes_full)

        print('{}: total {} images ...'.format(set, self.num_images))

        self._length = self.num_images
        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_templates_small

    def __len__(self):
        """Return the number of retained records."""
        return self._length

    def _resize_and_scale(self, pil_image, side):
        """Resize *pil_image* to ``side`` x ``side`` and map it with ``x / 127.5 - 1``."""
        resized = pil_image.resize((side, side), resample=self.interpolation)
        resized = np.array(resized).astype(np.uint8)
        return (resized / 127.5 - 1.0).astype(np.float32)

    def custom_aug(self, instance_image):
        """Randomly shrink, enlarge or plainly resize the crop.

        A random even scale is drawn. Below ``0.6 * size`` the resized crop is
        pasted at a random position on a zero canvas; above ``size`` a random
        ``size``-sized window is cut from the enlarged crop; otherwise the
        crop is resized to ``size``.

        Args:
            instance_image: HxWx3 uint8 array.

        Returns:
            ``(image, mask, caption)``: the image as a CxHxW float32 tensor,
            a ``1 x size//8 x size//8`` mask marking where the crop lies, and
            the caption prefix matching the chosen augmentation.
        """
        instance_image = Image.fromarray(instance_image)
        size = self.size

        # Two out of three times pick a scale up to the full size, otherwise
        # a scale of 1.2x-1.4x for a zoomed-in view.
        if np.random.randint(0, 3) < 2:
            random_scale = np.random.randint(size // 3, size + 1)
        else:
            random_scale = np.random.randint(int(1.2 * size), int(1.4 * size))

        # Keep the scale even so it splits into two equal halves.
        if random_scale % 2 == 1:
            random_scale += 1

        mask_side = size // 8

        if random_scale < 0.6 * size:
            add_to_caption = str(np.random.choice(["a far away", "very small"]))
            half = random_scale // 2
            cx = np.random.randint(half, size - half + 1)
            cy = np.random.randint(half, size - half + 1)

            patch = self._resize_and_scale(instance_image, random_scale)
            augmented = np.zeros((size, size, 3), dtype=np.float32)
            augmented[cx - half: cx + half, cy - half: cy + half, :] = patch

            # Mark the pasted region on the 1/8 grid, shrunk by one cell on
            # every side.
            mask = np.zeros((mask_side, mask_side))
            mask[(cx - half) // 8 + 1: (cx + half) // 8 - 1,
                 (cy - half) // 8 + 1: (cy + half) // 8 - 1] = 1.
        elif random_scale > size:
            add_to_caption = str(np.random.choice(["zoomed in", "close up"]))
            half_size = size // 2
            cx = np.random.randint(half_size, random_scale - half_size + 1)
            cy = np.random.randint(half_size, random_scale - half_size + 1)

            enlarged = self._resize_and_scale(instance_image, random_scale)
            augmented = enlarged[cx - half_size: cx + half_size, cy - half_size: cy + half_size, :]
            mask = np.ones((mask_side, mask_side))
        else:
            add_to_caption = "a"
            augmented = self._resize_and_scale(instance_image, size)
            mask = np.ones((mask_side, mask_side))

        return (torch.from_numpy(augmented).permute(2, 0, 1),
                torch.from_numpy(mask[:, :, None]).permute(2, 0, 1),
                add_to_caption)

    def aug_cv2(self, img, seg):
        """Apply the same random crop, rotation and shift to image and mask.

        Each of the three steps is applied with probability 1/2. The image is
        resampled bicubically and the mask with nearest-neighbour sampling.

        Args:
            img: HxWx3 uint8 array.
            seg: HxWx3 uint8 array aligned with *img*.

        Returns:
            ``(img, seg)`` as PIL images.
        """
        img_auged = np.array(img).copy()
        seg_auged = np.array(seg).copy()

        # Upscale to 224-256 px, then cut a random 224x224 window.
        if random.choice([0, 1]) == 0:
            new_size = random.randint(224, 256)
            img_auged = cv2.resize(img_auged, (new_size, new_size), interpolation=cv2.INTER_CUBIC)
            seg_auged = cv2.resize(seg_auged, (new_size, new_size), interpolation=cv2.INTER_NEAREST)
            start_x, start_y = random.randint(0, new_size - 224), random.randint(0, new_size - 224)
            img_auged = img_auged[start_x:start_x + 224, start_y:start_y + 224, :]
            seg_auged = seg_auged[start_x:start_x + 224, start_y:start_y + 224, :]

        h, w = img_auged.shape[:2]

        # Rotate by up to +/-30 degrees around (112, 112).
        if random.choice([0, 1]) == 0:
            angle = random.randint(-30, 30)
            M = cv2.getRotationMatrix2D((112, 112), angle, 1)
            img_auged = cv2.warpAffine(img_auged, M, (w, h), flags=cv2.INTER_CUBIC)
            seg_auged = cv2.warpAffine(seg_auged, M, (w, h), flags=cv2.INTER_NEAREST)

        # Shift by up to +/-60 px along each axis.
        if random.choice([0, 1]) == 0:
            trans_x = random.randint(-60, 60)
            trans_y = random.randint(-60, 60)
            H = np.float32([[1, 0, trans_x],
                            [0, 1, trans_y]])
            img_auged = cv2.warpAffine(img_auged, H, (w, h), flags=cv2.INTER_CUBIC)
            seg_auged = cv2.warpAffine(seg_auged, H, (w, h), flags=cv2.INTER_NEAREST)

        return Image.fromarray(img_auged), Image.fromarray(seg_auged)

    def __getitem__(self, i):
        """Return the sample dict for index *i* (wrapped modulo the record count).

        Any failure while loading or augmenting is appended to ``error.txt``
        and replaced by all-zero tensors plus a generic prompt, so one bad
        sample does not stop iteration.
        """
        example = {}

        record = self.bboxes_full[i % self.num_images]
        seg_name = record[0]
        # The image id is the part of the mask file name before the first "_".
        file_name = seg_name.split('_')[0] + '.jpg'
        img_path = os.path.join(self.data_root, 'images', self.set, file_name)
        seg_path = os.path.join(self.data_root, 'segs', self.set, seg_name)

        try:
            bbox_sample = record[1:]
            img_p_np = cv2.imread(img_path)
            img_p_np = cv2.cvtColor(img_p_np, cv2.COLOR_BGR2RGB)
            seg_p_np = cv2.imread(seg_path).astype('float')
            # Bring the mask to the image resolution; cv2 expects (w, h).
            seg_p_np = cv2.resize(seg_p_np, img_p_np.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)

            # Grow the box by 10-19 px; the lower bounds are clamped at 0 and
            # the upper bounds are clipped implicitly by the slicing below.
            bbox_pad = copy.copy(bbox_sample)
            pad_size = random.randrange(10, 20)
            bbox_pad[0] = int(bbox_pad[0] - min(pad_size, bbox_pad[0] - 0))
            bbox_pad[1] = int(bbox_pad[1] + pad_size)
            bbox_pad[2] = int(bbox_pad[2] - min(pad_size, bbox_pad[2] - 0))
            bbox_pad[3] = int(bbox_pad[3] + pad_size)

            image_tensor = img_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
            seg_tensor = seg_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]

            augged_image, augged_mask, add_caption = self.custom_aug(image_tensor)
            input_ids, index, text = self.obtain_text(add_caption)

            example["pixel_values"] = augged_image
            example["mask_values"] = augged_mask
            example["input_ids"] = input_ids
            example["index"] = index
            example["text"] = text

            # Zero out everything outside the mask.
            object_tensor = image_tensor * (seg_tensor / 255)
            ref_object_tensor = cv2.resize(object_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
            ref_image_tenser = cv2.resize(image_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
            ref_seg_tensor = cv2.resize(seg_tensor, (224, 224), interpolation=cv2.INTER_NEAREST)

            ref_object_tensor, ref_seg_tensor = self.aug_cv2(ref_object_tensor.astype('uint8'),
                                                             ref_seg_tensor.astype('uint8'))
            example["pixel_values_clip"] = self.get_tensor_clip()(Image.fromarray(ref_image_tenser))
            example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
            example["pixel_values_seg"] = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
        except Exception as e:
            # Deliberate best-effort fallback. The placeholders mirror the
            # keys, shapes and dtypes of a successful sample so that samples
            # can still be batched together.
            example["pixel_values"] = torch.zeros((3, self.size, self.size))
            example["mask_values"] = torch.zeros((1, self.size // 8, self.size // 8), dtype=torch.float64)
            example["pixel_values_obj"] = torch.zeros((3, 224, 224))
            example["pixel_values_clip"] = torch.zeros((3, 224, 224))
            example["pixel_values_seg"] = torch.zeros((3, 224, 224))

            input_ids, index, text = self.obtain_text("a")
            example["input_ids"] = input_ids
            example["index"] = index
            example["text"] = text

            with open('error.txt', 'a+') as f:
                f.write(str(e) + '\n')

        return example