LittleApple-fp16's picture
Upload 88 files
4f8ad24
raw
history blame
2.77 kB
import glob
import os
import pathlib
import random
import re
from typing import Iterator
from PIL import UnidentifiedImageError
from imgutils.data import load_image
from .base import RootDataSource
from ..model import ImageItem
class LocalSource(RootDataSource):
    """Data source yielding every loadable image found under a local directory.

    :param directory: Root directory to scan.
    :param recursive: If ``True`` (default), walk all subdirectories with
        :func:`os.walk`; otherwise only the top level of ``directory`` is listed.
    :param shuffle: If ``True``, the collected file list is shuffled with
        :func:`random.shuffle` before any image is loaded.
    """

    def __init__(self, directory: str, recursive: bool = True, shuffle: bool = False):
        self.directory = directory
        self.recursive = recursive
        self.shuffle = shuffle

    @staticmethod
    def _group_name(directory: str) -> str:
        # Collapse every run of non-word characters / underscores into one '_'
        # and trim the ends, turning a directory path into an identifier-like id.
        return re.sub(r'[\W_]+', '_', directory).strip('_')

    def _iter_files(self):
        """Yield ``(file_path, group_name)`` for each regular file to consider.

        The group name is derived from the directory that contains the file.
        Non-regular entries (subdirectories when not recursive, broken symlinks,
        FIFOs, ...) are skipped: ``ImageItem.load_from_image`` would presumably
        fail on them with an error other than ``UnidentifiedImageError`` (which
        is the only error ``_iter`` tolerates), or block in the case of a FIFO.
        """
        if self.recursive:
            for directory, _, files in os.walk(self.directory):
                group_name = self._group_name(directory)
                for file in files:
                    path = os.path.join(directory, file)
                    if os.path.isfile(path):
                        yield path, group_name
        else:
            group_name = self._group_name(self.directory)
            for file in os.listdir(self.directory):
                path = os.path.join(self.directory, file)
                if os.path.isfile(path):
                    yield path, group_name

    def _actual_iter_files(self):
        """Yield the entries of :meth:`_iter_files`, shuffled when ``self.shuffle`` is set."""
        files = list(self._iter_files())
        if self.shuffle:
            random.shuffle(files)
        yield from files

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an ``ImageItem`` for every file that can be opened as an image.

        Files raising ``UnidentifiedImageError`` (text files, metadata files, ...)
        are skipped silently; any other error propagates.
        """
        # Use _actual_iter_files (not _iter_files) so that ``shuffle`` takes effect.
        for file, group_name in self._actual_iter_files():
            try:
                origin_item = ImageItem.load_from_image(file)
                # Force the image data to be decoded now (presumably a PIL image,
                # which opens lazily) so failures surface inside this try.
                origin_item.image.load()
            except UnidentifiedImageError:
                continue

            # Metadata already attached to the loaded item wins; path/group_id/
            # filename are only filled in when that metadata is empty.
            meta = origin_item.meta or {
                'path': os.path.abspath(file),
                'group_id': group_name,
                'filename': os.path.basename(file),
            }
            yield ImageItem(origin_item.image, meta)
class LocalTISource(RootDataSource):
    """Data source for a flat directory of images with optional ``<name>.txt`` tag files.

    For each image ``<name>.<ext>``, a sibling ``<name>.txt`` (if present) is read
    as a comma-separated tag list; every tag gets weight ``1.0``. Only the top
    level of ``directory`` is scanned.

    :param directory: Directory containing the images and their tag files.
    """

    def __init__(self, directory: str):
        self.directory = directory

    @staticmethod
    def _parse_tags(text: str) -> dict:
        """Parse comma-separated tag text into ``{tag: 1.0}``.

        Leading/trailing whitespace (e.g. the final newline most editors write)
        is stripped, and empty items produced by trailing or doubled commas are
        dropped, so neither ``'tag\\n'`` nor ``''`` ends up as a tag.
        """
        return {word: 1.0 for word in re.split(r'\s*,\s*', text.strip()) if word}

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an ``ImageItem`` for every loadable image, with its tags in ``meta``.

        Files that raise ``UnidentifiedImageError`` (including the ``.txt`` tag
        files themselves) are skipped.
        """
        group_name = re.sub(r'[\W_]+', '_', self.directory).strip('_')
        # Escape the directory so '[', '*', '?' in its path are matched literally;
        # note glob's '*' does not match names starting with '.'.
        pattern = os.path.join(glob.escape(self.directory), '*')
        for f in glob.glob(pattern):
            if not os.path.isfile(f):
                continue

            try:
                image = load_image(f)
            except UnidentifiedImageError:
                continue

            id_ = os.path.splitext(os.path.basename(f))[0]
            txt_file = os.path.join(self.directory, f'{id_}.txt')
            if os.path.exists(txt_file):
                full_text = pathlib.Path(txt_file).read_text(encoding='utf-8')
                tags = self._parse_tags(full_text)
            else:
                tags = {}

            meta = {
                'path': os.path.abspath(f),
                'group_id': group_name,
                'filename': os.path.basename(f),
                'tags': tags,
            }
            yield ImageItem(image, meta)