Spaces:
Runtime error
Runtime error
import glob | |
import os | |
import pathlib | |
import random | |
import re | |
from typing import Iterator | |
from PIL import UnidentifiedImageError | |
from imgutils.data import load_image | |
from .base import RootDataSource | |
from ..model import ImageItem | |
class LocalSource(RootDataSource):
    """Data source that yields the images found in a local directory.

    Files that cannot be identified as images are skipped silently.

    :param directory: Directory to read files from.
    :param recursive: When true, walk ``directory`` and all of its
        subdirectories; otherwise only look at its direct entries.
    :param shuffle: When true, yield the files in random order instead of
        the order in which the filesystem lists them.
    """

    def __init__(self, directory: str, recursive: bool = True, shuffle: bool = False):
        self.directory = directory
        self.recursive = recursive
        self.shuffle = shuffle

    @staticmethod
    def _make_group_name(directory: str) -> str:
        """Collapse every run of non-word characters in a path into one ``_``."""
        return re.sub(r'[\W_]+', '_', directory).strip('_')

    def _iter_files(self):
        """Yield ``(file_path, group_name)`` pairs in filesystem order.

        The group name is derived from the directory that directly contains
        the file, so files of one directory share one group.
        """
        if self.recursive:
            for directory, _, files in os.walk(self.directory):
                group_name = self._make_group_name(directory)
                for file in files:
                    yield os.path.join(directory, file), group_name
        else:
            group_name = self._make_group_name(self.directory)
            for file in os.listdir(self.directory):
                path = os.path.join(self.directory, file)
                # os.listdir also returns subdirectories; handing one of them
                # to the image loader raises an OS error instead of
                # UnidentifiedImageError, which would abort the iteration.
                if os.path.isfile(path):
                    yield path, group_name

    def _actual_iter_files(self):
        """Yield the pairs of :meth:`_iter_files`, shuffled when requested."""
        if not self.shuffle:
            # Stay lazy when no shuffling is needed.
            yield from self._iter_files()
            return

        # Shuffling needs the complete listing up front.
        lst = list(self._iter_files())
        random.shuffle(lst)
        yield from lst

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an :class:`ImageItem` for every loadable image file."""
        # Go through _actual_iter_files so that the ``shuffle`` option is
        # honoured; iterating _iter_files directly would ignore it.
        for file, group_name in self._actual_iter_files():
            try:
                origin_item = ImageItem.load_from_image(file)
                # Force the image data to be read now, so that an unreadable
                # file is detected here and not later by a consumer.
                origin_item.image.load()
            except UnidentifiedImageError:
                continue

            # Keep the metadata returned by the loader when there is any;
            # otherwise describe the item by where it was found.
            meta = origin_item.meta or {
                'path': os.path.abspath(file),
                'group_id': group_name,
                'filename': os.path.basename(file),
            }
            yield ImageItem(origin_item.image, meta)
class LocalTISource(RootDataSource):
    """Data source for a flat directory of images with ``.txt`` tag files.

    Every regular file directly inside ``directory`` that can be loaded as an
    image is yielded. When a file ``<name>.txt`` with the same base name
    exists next to the image, its comma-separated content is used as the
    item's tags, each with a score of ``1.0``.

    :param directory: Directory containing the images and tag files.
    """

    def __init__(self, directory: str):
        self.directory = directory

    @staticmethod
    def _parse_tags(text: str) -> dict:
        """Parse comma-separated tag text into a ``{tag: 1.0}`` mapping.

        Whitespace around each tag (including a trailing newline at the end
        of the file) is removed and empty entries are dropped, so an empty
        file or a trailing comma does not produce a ``''`` tag.
        """
        words = (word.strip() for word in text.split(','))
        return {word: 1.0 for word in words if word}

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an :class:`ImageItem` with tags for every image file."""
        group_name = re.sub(r'[\W_]+', '_', self.directory).strip('_')
        # Escape the directory so that characters such as '[' or '*' in its
        # name are matched literally instead of being treated as a pattern.
        pattern = os.path.join(glob.escape(self.directory), '*')
        for f in glob.glob(pattern):
            if not os.path.isfile(f):
                continue

            try:
                image = load_image(f)
            except UnidentifiedImageError:
                # Not an image (this includes the .txt tag files themselves).
                continue

            id_ = os.path.splitext(os.path.basename(f))[0]
            txt_file = os.path.join(self.directory, f'{id_}.txt')
            if os.path.isfile(txt_file):
                full_text = pathlib.Path(txt_file).read_text(encoding='utf-8')
                tags = self._parse_tags(full_text)
            else:
                tags = {}

            meta = {
                'path': os.path.abspath(f),
                'group_id': group_name,
                'filename': os.path.basename(f),
                'tags': tags,
            }
            yield ImageItem(image, meta)