|
import os |
|
import numpy as np |
|
import PIL |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
|
|
import random |
|
|
|
imagenet_templates_small = [ |
|
'a painting in the style of {}', |
|
'a rendering in the style of {}', |
|
'a cropped painting in the style of {}', |
|
'the painting in the style of {}', |
|
'a clean painting in the style of {}', |
|
'a dirty painting in the style of {}', |
|
'a dark painting in the style of {}', |
|
'a picture in the style of {}', |
|
'a cool painting in the style of {}', |
|
'a close-up painting in the style of {}', |
|
'a bright painting in the style of {}', |
|
'a cropped painting in the style of {}', |
|
'a good painting in the style of {}', |
|
'a close-up painting in the style of {}', |
|
'a rendition in the style of {}', |
|
'a nice painting in the style of {}', |
|
'a small painting in the style of {}', |
|
'a weird painting in the style of {}', |
|
'a large painting in the style of {}', |
|
] |
|
|
|
imagenet_dual_templates_small = [ |
|
'a painting in the style of {} with {}', |
|
'a rendering in the style of {} with {}', |
|
'a cropped painting in the style of {} with {}', |
|
'the painting in the style of {} with {}', |
|
'a clean painting in the style of {} with {}', |
|
'a dirty painting in the style of {} with {}', |
|
'a dark painting in the style of {} with {}', |
|
'a cool painting in the style of {} with {}', |
|
'a close-up painting in the style of {} with {}', |
|
'a bright painting in the style of {} with {}', |
|
'a cropped painting in the style of {} with {}', |
|
'a good painting in the style of {} with {}', |
|
'a painting of one {} in the style of {}', |
|
'a nice painting in the style of {} with {}', |
|
'a small painting in the style of {} with {}', |
|
'a weird painting in the style of {} with {}', |
|
'a large painting in the style of {} with {}', |
|
] |
|
|
|
per_img_token_list = [ |
|
'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื', 'ื ', 'ืก', 'ืข', 'ืค', 'ืฆ', 'ืง', 'ืจ', 'ืฉ', 'ืช', |
|
] |
|
|
|
class PersonalizedBase(Dataset): |
|
def __init__(self, |
|
data_root, |
|
size=None, |
|
repeats=100, |
|
interpolation="bicubic", |
|
flip_p=0.5, |
|
set="train", |
|
placeholder_token="*", |
|
per_image_tokens=False, |
|
center_crop=False, |
|
): |
|
|
|
self.data_root = data_root |
|
|
|
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] |
|
|
|
|
|
self.num_images = len(self.image_paths) |
|
self._length = self.num_images |
|
|
|
self.placeholder_token = placeholder_token |
|
|
|
self.per_image_tokens = per_image_tokens |
|
self.center_crop = center_crop |
|
|
|
if per_image_tokens: |
|
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." |
|
|
|
if set == "train": |
|
self._length = self.num_images * repeats |
|
|
|
self.size = size |
|
self.interpolation = {"linear": PIL.Image.LINEAR, |
|
"bilinear": PIL.Image.BILINEAR, |
|
"bicubic": PIL.Image.BICUBIC, |
|
"lanczos": PIL.Image.LANCZOS, |
|
}[interpolation] |
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p) |
|
|
|
def __len__(self): |
|
return self._length |
|
|
|
def __getitem__(self, i): |
|
example = {} |
|
image = Image.open(self.image_paths[i % self.num_images]) |
|
|
|
image = image.convert('RGBA') |
|
new_image = Image.new('RGBA', image.size, 'WHITE') |
|
new_image.paste(image, (0, 0), image) |
|
image = new_image.convert('RGB') |
|
|
|
templates = [ |
|
'a {} portrait of {}', |
|
'an {} image of {}', |
|
'a {} pretty picture of {}', |
|
'a {} clip art picture of {}', |
|
'an {} illustration of {}', |
|
'a {} 3D render of {}', |
|
'a {} {}', |
|
] |
|
|
|
filename = os.path.basename(self.image_paths[i % self.num_images]) |
|
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').split('-') |
|
filename_tokens = [token for token in filename_tokens if token.isalpha()] |
|
|
|
text = random.choice(templates).format(' '.join(filename_tokens), self.placeholder_token) |
|
print(text) |
|
|
|
example["caption"] = text |
|
|
|
|
|
img = np.array(image).astype(np.uint8) |
|
|
|
if self.center_crop: |
|
crop = min(img.shape[0], img.shape[1]) |
|
h, w, = img.shape[0], img.shape[1] |
|
img = img[(h - crop) // 2:(h + crop) // 2, |
|
(w - crop) // 2:(w + crop) // 2] |
|
|
|
image = Image.fromarray(img) |
|
if self.size is not None: |
|
image = image.resize((self.size, self.size), resample=self.interpolation) |
|
|
|
image = self.flip(image) |
|
image = np.array(image).astype(np.uint8) |
|
example["image"] = (image / 127.5 - 1.0).astype(np.float32) |
|
return example |