diff --git a/.bashrc b/.bashrc new file mode 100644 index 0000000000000000000000000000000000000000..61f790fa85a48f729300c2f4fa5bc2d75cedb7d2 --- /dev/null +++ b/.bashrc @@ -0,0 +1 @@ +export PATH=$HOME/.local/bin:$PATH diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c7d9f3332a950355d5a77d85000f05e6f45435ea 100644 --- a/.gitattributes +++ b/.gitattributes @@ -25,7 +25,6 @@ *.safetensors filter=lfs diff=lfs merge=lfs -text saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text *.tflite filter=lfs diff=lfs merge=lfs -text *.tgz filter=lfs diff=lfs merge=lfs -text *.wasm filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5ceb3864c2911029f0a6010fadab352e4b8e2d07 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +venv diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d9dffd4cb89c189af912179ab9963ee7295893eb --- /dev/null +++ b/Dockerfile @@ -0,0 +1,36 @@ +FROM python:3.8.1 + +WORKDIR /code + +COPY ./requirements.txt /code/requirements.txt + +RUN apt-get update && \ + apt-get install -y sudo tmux wget curl htop make tree && \ + apt-get install -y iputils-ping telnet && \ + apt-get install -y git git-lfs && \ + apt-get install -y libgl1-mesa-glx + +RUN --mount=type=secret,id=PASSWORD,mode=0444,required=true \ + useradd -m -u 1000 user && \ + echo "user:$(cat /run/secrets/PASSWORD)" | chpasswd && \ + adduser user sudo + +RUN pip install -U pip pysocks +RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt + +USER user +ENV HOME=/home/user +ENV PATH=$HOME/.local/bin:$PATH +ENV SHELL=/bin/bash + +WORKDIR $HOME + +COPY --chown=user . 
$HOME/app + +COPY .bashrc $HOME/.bashrc_append +RUN cat $HOME/.bashrc_append >> $HOME/.bashrc && \ + rm $HOME/.bashrc_append + +EXPOSE 7860 +ENTRYPOINT [] +CMD ["/bin/bash", "./app/run.sh"] diff --git a/README.md b/README.md index e409a133dc366dadefc89ce12226a6690e6ab063..751b86ee69701fc6060834cea622c829796ee93f 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ --- -title: AppleJupyter -emoji: 🐠 -colorFrom: pink -colorTo: yellow +title: JupyterLab +emoji: 💹 +colorFrom: blue +colorTo: red sdk: docker pinned: false -license: apache-2.0 +license: mit +app_port: 7860 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/cyberharem/__init__.py b/cyberharem/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cyberharem/__pycache__/__init__.cpython-310.pyc b/cyberharem/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99c4d8d5bc0e8db1779724aa4de10c602c8ce41a Binary files /dev/null and b/cyberharem/__pycache__/__init__.cpython-310.pyc differ diff --git a/cyberharem/config/__init__.py b/cyberharem/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cyberharem/config/meta.py b/cyberharem/config/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..6c924065b321f1c78332f2b75599ff3a6b8899aa --- /dev/null +++ b/cyberharem/config/meta.py @@ -0,0 +1,19 @@ +""" +Overview: + Meta information for gchar package. +""" + +#: Title of this project (should be `gchar`). +__TITLE__ = 'cyberharem' + +#: Version of this project. +__VERSION__ = '0.0.1' + +#: Short description of the project, will be included in ``setup.py``. +__DESCRIPTION__ = 'Cyber Harem of All the Waifus in Games, Mua~' + +#: Author of this project. +__AUTHOR__ = 'narugo1992' + +#: Email of the authors'. 
@cli.command('retag', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Regenerate tags for given work directory.')
@click.option('-w', '--workdir', 'workdir', type=click.Path(file_okay=False, exists=True), required=True,
              help='Work directory for experiment.', show_default=True)
@click.option('-n', '--name', 'name', type=str, default=None,
              help='Name of the character.', show_default=True)
def retag(workdir, name: Optional[str] = None):
    """
    Regenerate the recommended tag files for an existing work directory.

    :param workdir: Existing experiment directory (validated by click).
    :param name: Character name. When omitted (or empty), it is derived from
        the step name found in ``workdir`` by dropping everything from the
        last underscore onwards.
    """
    logging.try_init_root(logging.INFO)

    # Imported lazily, as in the original command body.
    from ..publish.steps import find_steps_in_workdir

    # The work directory is always inspected, even when a name was given.
    step_name, _ = find_steps_in_workdir(workdir)
    if not name:
        # Same as joining every '_'-separated piece except the last one.
        name = step_name.rpartition('_')[0]

    logging.info(f'Regenerate tags for {name!r}, on {workdir!r}.')
    save_recommended_tags(name, workdir=workdir)
    logging.info('Success!')
def get_source(source) -> BaseDataSource:
    """
    Normalize ``source`` into a data source object.

    :param source: A character name (``str``), a ``Character`` object, or an
        already constructed ``BaseDataSource``.
    :return: A ``GcharAutoSource`` built with ``main_sources_count=5`` for
        names/characters, or ``source`` itself when it is already a data source.
    :raises TypeError: If ``source`` is none of the supported types.
    """
    # Names and character objects are wrapped into an automatic gchar source.
    if isinstance(source, (str, Character)):
        return GcharAutoSource(source, main_sources_count=5)

    # Ready-made data sources are passed through untouched.
    if isinstance(source, BaseDataSource):
        return source

    raise TypeError(f'Unknown source type - {source!r}.')
def actions_parse(actions: Union[int, Tuple[int, int], List[BaseAction]], bg_color: str = 'white'):
    """
    Turn a compact post-processing description into a list of actions.

    :param actions: One of:
        * a ``list`` - returned as-is (same object);
        * a ``(width, height)`` tuple - becomes a single ``PaddingAlignAction``
          using ``bg_color``;
        * an ``int`` - becomes a single ``AlignMinSizeAction``.
    :param bg_color: Background color forwarded to ``PaddingAlignAction``.
    :return: List of actions.
    :raises TypeError: If ``actions`` is of any other type.
    """
    if isinstance(actions, list):
        return actions

    if isinstance(actions, tuple):
        target_width, target_height = actions
        return [PaddingAlignAction((target_width, target_height), bg_color)]

    if isinstance(actions, int):
        return [AlignMinSizeAction(actions)]

    raise TypeError(f'Unknown post action type - {actions!r}.')
dataset.'), + # '512x512': ('native', (512, 512), '512x512 aligned dataset.'), + '512x704': ('native', (512, 704), '512x704 aligned dataset.'), + # '640x640': ('native', (640, 640), '640x640 aligned dataset.'), + '640x880': ('native', (640, 880), '640x880 aligned dataset.'), + 'stage3-640': ('stage3', 640, '3-stage cropped dataset with the shorter side not exceeding 640 pixels.'), + 'stage3-800': ('stage3', 800, '3-stage cropped dataset with the shorter side not exceeding 800 pixels.'), + 'stage3-p512-640': ('stage3', [MinAreaFilterAction(512), AlignMinSizeAction(640)], + '3-stage cropped dataset with the area not less than 512x512 pixels.'), + # 'stage3-1200': ('stage3', 1200, '3-stage cropped dataset with the shorter side not exceeding 1200 pixels.'), + 'stage3-eyes-640': ('stage3-eyes', 640, '3-stage cropped (with eye-focus) dataset ' + 'with the shorter side not exceeding 640 pixels.'), + 'stage3-eyes-800': ('stage3-eyes', 800, '3-stage cropped (with eye-focus) dataset ' + 'with the shorter side not exceeding 800 pixels.'), +} + +DATASET_PVERSION = 'v1.4' + + +def crawl_dataset_to_huggingface( + source: Union[str, Character, BaseDataSource], repository: Optional[str] = None, + name: Optional[str] = None, limit: Optional[int] = 1000, min_images: int = 10, + no_r18: bool = False, bg_color: str = 'white', drop_multi: bool = True, skip_preprocess: bool = False, + no_monochrome_check: bool = False, + repo_type: str = 'dataset', revision: str = 'main', path_in_repo: str = '.', private: bool = False, +): + if isinstance(source, (str, Character)): + if isinstance(source, str): + source = get_character(source) + name = f'{source.enname} ({source.__official_name__})' + + if not repository: + repository = f'AppleHarem/{get_ch_name(source)}' + + else: + if name is None: + raise ValueError('Name must be specified when source is not str or character.') + + if not repository: + repository = f'AppleHarem/{get_alphabet_name(name)}' + + origin_source = get_main_source(source, 
no_r18, bg_color, no_monochrome_check, drop_multi, skip_preprocess) + with TemporaryDirectory() as td: + # save origin directory + origin_dir = os.path.join(td, 'origin') + os.makedirs(origin_dir, exist_ok=True) + if limit is not None: + origin_source = origin_source[:limit] + with task_ctx('origin'): + origin_source.export(SaveExporter(origin_dir)) + + img_count = len(glob.glob(os.path.join(origin_dir, '*.png'))) + if img_count < min_images: + logging.warn(f'Only {plural_word(img_count, "image")} found for {name} which is too few, ' + f'skip post-processing and uploading.') + return + + source_dir = os.path.join(td, 'source') + os.makedirs(source_dir, exist_ok=True) + for sname, actions in _SOURCES.items(): + with task_ctx(f'source/{sname}'): + LocalSource(origin_dir).attach(*actions).export(SaveExporter(os.path.join(source_dir, sname))) + + processed_dir = os.path.join(td, 'processed') + os.makedirs(processed_dir, exist_ok=True) + archive_dir = os.path.join(td, 'archives') + os.makedirs(archive_dir, exist_ok=True) + + files_to_upload: List[Tuple[str, str]] = [] + resolutions = _DEFAULT_RESOLUTIONS + + columns = ['Name', 'Images', 'Download', 'Description'] + rows = [] + for rname, (sname, actions, description) in resolutions.items(): + actions = actions_parse(actions, bg_color) + + ox = LocalSource(os.path.join(source_dir, sname)) + current_processed_dir = os.path.join(processed_dir, rname) + with task_ctx(f'archive/{rname}'): + if not rname.startswith('raw'): # raw is preserved for exporting json data + ox.attach(*actions).export(TextualInversionExporter(current_processed_dir)) + else: + ox.attach(*actions).export(SaveExporter(current_processed_dir)) + current_img_cnt = len(glob.glob(os.path.join(current_processed_dir, '*.png'))) + + zip_file = os.path.join(archive_dir, f'dataset-{rname}.zip') + with zipfile.ZipFile(zip_file, mode='w') as zf: + for directory, _, files in os.walk(current_processed_dir): + for file in files: + file_path = os.path.join(directory, 
file) + rel_file_path = os.path.relpath(file_path, current_processed_dir) + zf.write( + file_path, + '/'.join(rel_file_path.split(os.sep)) + ) + + rows.append(( + rname, + current_img_cnt, + f'[Download]({os.path.basename(zip_file)})', + description, + )) + + files_to_upload.append((zip_file, os.path.basename(zip_file))) + + meta_file = os.path.join(td, 'meta.json') + with open(meta_file, 'w', encoding='utf-8') as mf: + json.dump({ + 'name': name, + 'version': DATASET_PVERSION, + }, mf, indent=4, sort_keys=True, ensure_ascii=False) + files_to_upload.append((meta_file, 'meta.json')) + + readme_file = os.path.join(td, 'README.md') + with open(readme_file, 'w', encoding='utf-8') as rf: + print(f'---', file=rf) + print(f'license: mit', file=rf) + print(f'task_categories:', file=rf) + print(f'- text-to-image', file=rf) + print(f'tags:', file=rf) + print(f'- art', file=rf) + print(f'- not-for-all-audiences', file=rf) + print(f'size_categories:', file=rf) + print(f'- {number_to_tag(img_count)}', file=rf) + print(f'---', file=rf) + print(f'', file=rf) + + print(f'# Dataset of {name}', file=rf) + print(f'', file=rf) + + print(f'This is the dataset of {name}, ' + f'containing {plural_word(img_count, "images")} and their tags.', file=rf) + print(f'', file=rf) + + print(f'Images are crawled from many sites (e.g. danbooru, pixiv, zerochan ...), ' + f'the auto-crawling system is powered by [DeepGHS Team](https://github.com/deepghs)' + f'([huggingface organization](https://huggingface.co/deepghs)). 
def remake_dataset_to_huggingface(
        repository: Optional[str] = None, limit: Optional[int] = 200, min_images: int = 10,
        no_r18: bool = False, bg_color: str = 'white', drop_multi: bool = True,
        repo_type: str = 'dataset', revision: str = 'main', path_in_repo: str = '.',
):
    """
    Rebuild all derived archives of an already published dataset repository.

    The ``dataset-raw.zip`` archive of ``repository`` is downloaded and
    extracted, then fed back into :func:`crawl_dataset_to_huggingface` as a
    local source with preprocessing skipped, so every derived archive, the
    ``meta.json`` and the README are regenerated and committed again.

    :param repository: Repository id to remake. Required despite the ``None``
        default, which is kept for interface compatibility.
    :param limit: Maximum number of images to keep, ``None`` for no limit.
    :param min_images: Minimum number of images required for publishing.
    :param no_r18: Forwarded to :func:`crawl_dataset_to_huggingface`.
    :param bg_color: Background color used by padding/alignment actions.
    :param drop_multi: Forwarded to :func:`crawl_dataset_to_huggingface`.
    :param repo_type: Repository type used for the publishing commit.
    :param revision: Revision used for the publishing commit.
    :param path_in_repo: Directory inside the repository to publish into.
    :return: Whatever :func:`crawl_dataset_to_huggingface` returns.
    :raises ValueError: If ``repository`` is empty or ``None``.
    """
    if not repository:
        raise ValueError('Repository must be specified for remaking dataset.')

    hf_fs = get_hf_fs()
    with TemporaryDirectory() as td:
        # NOTE: the raw archive and meta.json are always read from the dataset
        # namespace, independent of ``repo_type`` (unchanged from the original).
        zip_file = os.path.join(td, 'dataset-raw.zip')
        download_file(hf_hub_url(repository, 'dataset-raw.zip', repo_type='dataset'), zip_file)

        source_dir = os.path.join(td, 'source')
        os.makedirs(source_dir, exist_ok=True)
        with zipfile.ZipFile(zip_file, 'r') as zf:
            zf.extractall(source_dir)

        source = LocalSource(source_dir)

        # Prefer the display name stored at publish time, fall back to the
        # last path segment of the repository id.
        name = None
        meta_path = f'datasets/{repository}/meta.json'
        if hf_fs.exists(meta_path):
            name = json.loads(hf_fs.read_text(meta_path)).get('name')
        name = name or repository.split('/')[-1]

        # Keyword arguments are used on purpose: the callee has a
        # ``no_monochrome_check`` parameter between ``skip_preprocess`` and
        # ``repo_type``, so the former positional call shifted ``repo_type``,
        # ``revision`` and ``path_in_repo`` one slot to the left.
        return crawl_dataset_to_huggingface(
            source,
            repository=repository,
            name=name,
            limit=limit,
            min_images=min_images,
            no_r18=no_r18,
            bg_color=bg_color,
            drop_multi=drop_multi,
            skip_preprocess=True,
            repo_type=repo_type,
            revision=revision,
            path_in_repo=path_in_repo,
        )
def _free_pos_words(generic_words, name, core_tags):
    """
    Build the prompt pieces for the unconstrained ("free") preview.

    :param generic_words: Leading generic prompt words.
    :param name: Trigger word, added with weight ``1.15``.
    :param core_tags: Mapping of tag to score; tags are appended in
        descending score order.
    :return: Tuple of ``(positive words, negative words, seed, is_sfw)``,
        where seed is ``None`` and ``is_sfw`` is ``True``.
    """
    # Highest-scored tags first; the sort is stable, so ties keep mapping order.
    ranked_tags = sorted(core_tags, key=lambda tag: -core_tags[tag])

    positive = list(generic_words)
    positive.append((name, 1.15))
    positive.extend(ranked_tags)
    return positive, generic_neg_words, None, True
detailed background', 1.2), + (name, 1.15), + ('standing', 1.1), + 'looking at viewer', + ('bikini', 1.3), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + 'light smile', + ], generic_neg_words, 758691538, True + + +def _nude_pos_words(generic_words, name, core_tags): + return [ + 'nsfw', + *generic_words, + ('lying on bed', 1.1), + ('extremely detailed background', 1.2), + ('nude', 1.4), + ('spread legs', 1.1), + ('arms up', 1.1), + 'mature', + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + 'nipples', + ('pussy', 1.15), + ('pussy juice', 1.3), + 'looking at viewer', + ('embarrassed', 1.1), + 'endured face', + 'feet out of frame', + ], generic_neg_words, 465191133, False + + +def _nude_bondage_words(generic_words, name, core_tags): + return [ + 'nsfw', + *generic_words, + ('simple background', 1.1), + ('standing', 1.15), + ('nude', 1.4), + ('bondage', 1.3), + 'completely nude', + 'mature', + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + 'nipples', + ('pussy', 1.15), + ('pussy juice', 1.3), + 'looking at viewer', + ('embarrassed', 1.1), + ], generic_neg_words, 758691538, False + + +def _nude_stand_words(generic_words, name, core_tags): + return [ + 'nsfw', + *generic_words, + ('simple background', 1.1), + ('standing', 1.15), + ('nude', 1.4), + ('completely nude', 1.2), + 'mature', + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + 'nipples', + ('pussy', 1.15), + ('pussy juice', 1.3), + 'looking at viewer', + ('embarrassed', 1.1), + ], generic_neg_words, 758691538, False + + +def _safe_maid_words(generic_words, name, core_tags): + return [ + *generic_words, + ('maid', 1.4), + ('long maid dress', 1.15), + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + ], [ + 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet', + 'skin of legs', 'bare legs', 'bare skin', 'navel', + *generic_neg_words, + ], None, 
True + + +def _safe_yukata_words(generic_words, name, core_tags): + return [ + *generic_words, + ('yukata', 1.4), + ('kimono', 1.2), + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + ], [ + 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet', + 'skin of legs', 'bare legs', 'bare skin', 'navel', + *generic_neg_words, + ], None, True + + +def _safe_miko_words(generic_words, name, core_tags): + return [ + *generic_words, + ('white kimono', 1.35), + ('red hakama', 1.35), + ('wide sleeves', 1.2), + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + ], [ + 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet', + 'skin of legs', 'bare legs', 'bare skin', 'navel', + *generic_neg_words, + ], None, True + + +def _safe_suit_words(generic_words, name, core_tags): + return [ + *generic_words, + ('black business suit', 1.4), + ('tie', 1.2), + ('sunglasses', 1.25), + ('white gloves', 1.15), + ('white shirt', 1.1), + ('black skirt', 1.15), + ('smoking', 1.2), + 'handsome', + (name, 1.15), + *[key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])], + ], [ + 'nsfw', 'sexy', 'underwear', 'bra', 'fishnet', + 'skin of legs', 'bare legs', 'bare skin', 'navel', + *generic_neg_words, + ], None, True + + +EXTRAS = [ + ('free', _free_pos_words), + ('bikini', _bikini_pos_words), + ('maid', _safe_maid_words), + ('miko', _safe_miko_words), + ('yukata', _safe_yukata_words), + ('nude', _nude_pos_words), + ('nude2', _nude_stand_words), + ('bondage', _nude_bondage_words), + ('suit', _safe_suit_words), +] + + +def save_recommended_tags(source, name: str = None, workdir: str = None, ds_size: str = '512x704'): + with load_dataset_for_character(source, ds_size) as (ch, ds_dir): + if ch is None: + if name is None: + raise ValueError(f'Name should be specified when using custom source - {source!r}.') + else: + name = name or get_ch_name(ch) + + workdir = workdir or os.path.join('runs', name) + tags_dir = os.path.join(workdir, 
def sort_draw_names(names: List[str]) -> List[str]:
    """
    Sort preview names so that numbered patterns come first.

    Names of the form ``pattern_<int>`` are placed first, ordered by their
    numeric index (so ``pattern_2`` precedes ``pattern_10``). All other names
    follow in lexicographic order. A ``pattern_`` name whose suffix is not an
    integer is treated as an ordinary name instead of raising ``ValueError``.

    :param names: Names to sort.
    :return: New sorted list; the input is not modified.
    """

    def _sort_key(name: str):
        if name.startswith('pattern_'):
            try:
                return 0, int(name.split('_')[1]), name
            except ValueError:
                # Non-numeric suffix: fall through to the lexicographic group.
                pass
        return 1, name, name

    # The leading group marker differs between the two key shapes, so an int
    # index is never compared against a str name.
    return sorted(names, key=_sort_key)
@cli.command('huggingface', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish to huggingface')
@click.option('--repository', '-r', 'repository', type=str, default=None,
              help='Repository to publish to.', show_default=True)
@click.option('--revision', '-R', 'revision', type=str, default='main',
              help='Revision for pushing the model.', show_default=True)
@click.option('--input', '-i', 'video_or_directory', type=str, required=True,
              help='Input videos.', show_default=True)
@click.option('--name', '-n', 'bangumi_name', type=str, required=True,
              help='Bangumi name', show_default=True)
@click.option('--min_size', '-s', 'min_size', type=int, default=320,
              help='Min size of image.', show_default=True)
@click.option('--no_extract', '-E', 'no_extract', is_flag=True, type=bool, default=False,
              help='No extraction from videos.', show_default=True)
def huggingface(video_or_directory: str, bangumi_name: str,
                repository: str, revision: str = 'main', min_size: int = 320, no_extract: bool = False):
    """
    Extract a bangumi's video data and publish it to a repository.

    :param video_or_directory: Input passed through to ``extract_to_huggingface``.
    :param bangumi_name: Display name of the bangumi.
    :param repository: Target repository. When empty, ``BangumiBase/<slug>`` is
        used, where the slug is the lower-cased, transliterated name with all
        non-word characters and underscores removed.
    :param revision: Revision to push to.
    :param min_size: Forwarded as ``min_size``.
    :param no_extract: Forwarded as ``no_extract``.
    """
    logging.try_init_root(logging.INFO)

    # Build an ASCII-only slug from the bangumi name.
    ascii_name = unidecode(bangumi_name.lower())
    repo_slug = re.sub(r'[\W_]+', '', ascii_name)
    if not repository:
        repository = f"BangumiBase/{repo_slug}"

    extract_to_huggingface(
        video_or_directory,
        bangumi_name,
        repository,
        revision,
        no_extract=no_extract,
        min_size=min_size,
    )
hf_client = get_hf_client()
hf_fs = get_hf_fs()


def get_animelist_info(bangumi_name) -> Tuple[Optional[str], Optional[str]]:
    """
    Look up a bangumi on MyAnimeList and return its page and poster URLs.

    The search result rows are visited in order; the first row whose detail
    page exposes a poster image wins.

    :param bangumi_name: Name used as the search query.
    :return: ``(page_url, poster_url)``, or ``(None, None)`` when no row
        yields both values.
    """
    session = get_requests_session()
    resp = srequest(
        session, 'GET', 'https://myanimelist.net/anime.php',
        params={
            'cat': 'anime',
            'q': bangumi_name,
        }
    )
    table = pq(resp.text)('.js-block-list.list table')
    for row in table('tr').items():
        bangumi_url = row('td:nth-child(1) a').attr('href')
        if not bangumi_url:
            continue

        detail_page = pq(srequest(session, 'GET', bangumi_url).text)
        post_url = detail_page("img[itemprop=image]").attr('data-src')
        if post_url:
            return bangumi_url, post_url

    # Always hand back a 2-tuple so callers can unpack unconditionally.
    return None, None


def sync_bangumi_base(repository: str = 'BangumiBase/README'):
    """
    Regenerate the BangumiBase index (README plus poster images) and commit it.

    Every dataset owned by ``BangumiBase`` that ships a ``meta.json`` becomes
    one table row. Rows are ordered by last-modified time, newest first.

    :param repository: Space repository that receives the commit.
    """
    cb_models = [item.modelId for item in hf_client.list_models(author='CyberHarem')]
    cb_datasets = [item.id for item in hf_client.list_datasets(author='CyberHarem')]

    with TemporaryDirectory() as td:
        # Collect all rows first so no file handle stays open while crawling.
        rows, total_images, total_clusters = [], 0, 0
        for item in tqdm(list(hf_client.list_datasets(author='BangumiBase'))):
            meta_path = f'datasets/{item.id}/meta.json'
            if not hf_fs.exists(meta_path):
                logging.info(f'No meta information found for {item.id!r}, skipped')
                continue

            meta = json.loads(hf_fs.read_text(meta_path))
            bangumi_name = meta['name']
            # Characters that would break the markdown link text are replaced.
            safe_bangumi_name = bangumi_name.replace('`', ' ').replace('[', '(').replace(']', ')')
            suffix = item.id.split('/')[-1]
            derived_pattern = f'CyberHarem/*_{suffix}'
            datasets_cnt = len(fnmatch.filter(cb_datasets, derived_pattern))
            models_cnt = len(fnmatch.filter(cb_models, derived_pattern))

            page_url, post_url = get_animelist_info(bangumi_name)
            if post_url:
                post_file = os.path.join(td, 'posts', f'{suffix}.jpg')
                os.makedirs(os.path.dirname(post_file), exist_ok=True)
                download_file(post_url, post_file)
                post_md = f'![{suffix}]({os.path.relpath(post_file, td)})'
            else:
                post_md = '(no post)'
            if page_url:
                post_md = f'[{post_md}]({page_url})'

            dataset_url = f'https://huggingface.co/datasets/{item.id}'
            last_modified = dateparser.parse(item.lastModified) \
                if isinstance(item.lastModified, str) else item.lastModified
            cluster_cnt = sum(1 for x in meta['ids'] if x != -1)  # -1 marks the noise bucket
            rows.append({
                'Post': post_md,
                'Bangumi': f'[{safe_bangumi_name}]({dataset_url})',
                'Last Modified': last_modified.strftime('%Y-%m-%d %H:%M'),
                'Images': meta['total'],
                'Clusters': cluster_cnt,
                'Datasets': f'[{datasets_cnt}](https://huggingface.co/CyberHarem?'
                            f'search_models=_{suffix}&search_datasets=_{suffix})',
                'Models': f'[{models_cnt}](https://huggingface.co/CyberHarem?'
                          f'search_models=_{suffix}&search_datasets=_{suffix})',
            })
            total_images += meta['total']
            total_clusters += cluster_cnt
        total_animes = len(rows)

        header_md = textwrap.dedent(f"""
            ---
            title: README
            emoji: 🌖
            colorFrom: green
            colorTo: red
            sdk: static
            pinned: false
            ---

            ## What is this?

            This is a data hub utilized by the [DeepGHS team](https://huggingface.co/deepghs) for processing
            anime series (in video format, including TV, OVA, movies, etc.).

            After downloading anime videos to our GPU cluster, we employ various computer vision algorithms to
            extract frames, crop, and **cluster them based on character features**. These processed frames are
            then uploaded here to reduce the manual sorting effort required for character images.

            The data in this repository will undergo automated secondary processing to remove noise,
            after which it will be packaged and uploaded to [CyberHarem](https://huggingface.co/CyberHarem).
            It will then be integrated into an automated pipeline for training character LoRA.

            ## Current Anime Database (constantly updated)

            Last updated on: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")},
            contains {plural_word(total_animes, "anime")}, {plural_word(total_images, "image")}
            and {plural_word(total_clusters, "cluster")} in total.
        """).strip()

        # The zero-padded 'YYYY-MM-DD HH:MM' strings sort chronologically as
        # plain text, so re-parsing them is unnecessary.
        rows.sort(key=lambda row: row['Last Modified'], reverse=True)
        table_md = pd.DataFrame(rows).to_markdown(index=False)

        # Explicit UTF-8: the front matter contains an emoji, which a
        # locale-dependent default encoding may be unable to encode.
        readme_file = os.path.join(td, 'README.md')
        with open(readme_file, 'w', encoding='utf-8') as f:
            print(header_md, file=f)
            print(table_md, file=f)

        operations = []
        for directory, _, files in os.walk(td):
            for file in files:
                filename = os.path.abspath(os.path.join(directory, file))
                operations.append(CommitOperationAdd(
                    path_in_repo=os.path.relpath(filename, td),
                    path_or_fileobj=filename,
                ))

        current_time = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')
        commit_message = f'Update lfs images, on {current_time}'
        logging.info(f'Updating lfs images to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='space',
            revision='main',
        )
def crawl_base_to_huggingface(
        source_repository: str, ch_id: Union[int, List[int]],
        name: str, repository: Optional[str] = None,
        limit: Optional[int] = 200, min_images: int = 10,
        no_r18: bool = False, bg_color: str = 'white', drop_multi: bool = True,
        repo_type: str = 'dataset', revision: str = 'main', path_in_repo: str = '.',
        skip_preprocess: bool = True, parallel: bool = True, standalone_ccip: bool = True,
        keep_cnt_ratio: bool = True,
):
    """
    Build a character dataset from one or more clusters of an image base.

    Each cluster's ``dataset.zip`` is downloaded from ``source_repository``,
    unpacked, wrapped in a ``LocalSource`` and combined into one source that
    is passed to ``crawl_dataset_to_huggingface``.

    :param source_repository: Dataset repository holding ``<id>/dataset.zip``.
    :param ch_id: A single cluster id or a list of them.
    :param name: Character name; also used to derive ``repository``.
    :param repository: Target repository. When empty it becomes
        ``CyberHarem/<name slug>_<source repo name>``.
    :param limit: Total image budget. ``None`` disables the per-cluster quota.
    :param parallel: Combine cluster sources with ``|`` instead of ``+``.
    :param standalone_ccip: Attach a ``CCIPAction`` to each cluster source.
    :param keep_cnt_ratio: Give each cluster a share of ``limit`` in
        proportion to its image count.
    :param skip_preprocess: When true, similar-image filtering and random
        file naming are attached here; the flag is also forwarded.
    :return: Whatever ``crawl_dataset_to_huggingface`` returns.

    The remaining parameters are forwarded unchanged.
    """
    ch_ids = [ch_id] if isinstance(ch_id, int) else ch_id
    if not repository:
        name_slug = re.sub(r'[\W_]+', '_', unidecode(name.lower())).strip('_').lower()
        repository = f'CyberHarem/{name_slug}_' + source_repository.split('/')[-1]
        logging.info(f'Target repository name {repository!r} will be used.')

    source = EmptySource()
    with TemporaryDirectory() as td:
        img_cnts = []
        for cid in ch_ids:
            cluster_dir = os.path.join(td, str(cid))
            os.makedirs(cluster_dir, exist_ok=True)
            url = hf_hub_url(source_repository, filename=f'{cid}/dataset.zip', repo_type='dataset')
            zip_file = os.path.join(cluster_dir, 'dataset.zip')
            download_file(url, zip_file)

            source_dir = os.path.join(cluster_dir, 'source')
            os.makedirs(source_dir, exist_ok=True)
            with zipfile.ZipFile(zip_file, 'r') as zf:
                zf.extractall(source_dir)
            img_cnts.append(len(glob.glob(os.path.join(source_dir, '*.png'))))

        total = sum(img_cnts)
        # A quota needs a numeric limit and a non-zero total; otherwise the
        # share computation would fail on None or divide by zero.
        apply_quota = keep_cnt_ratio and limit is not None and total > 0
        for cid, c_cnt in zip(ch_ids, img_cnts):
            new_source = LocalSource(os.path.join(td, str(cid), 'source'), shuffle=True)
            if standalone_ccip:
                new_source = new_source.attach(CCIPAction())
            if apply_quota:
                new_source = new_source[:int(round(c_cnt * 1.0 / total * limit))]

            source = (source | new_source) if parallel else (source + new_source)

        if skip_preprocess:
            source = source.attach(
                FilterSimilarAction('all'),
                RandomFilenameAction(ext='.png'),
            )

        # Must run inside the ``with``: the sources read from ``td`` lazily.
        return crawl_dataset_to_huggingface(
            source, repository, name,
            limit, min_images, no_r18, bg_color, drop_multi, skip_preprocess,
            repo_type, revision, path_in_repo
        )
def cluster_from_directory(src_dir, dst_dir, merge_threshold: float = 0.85, clu_min_samples: int = 5,
                           extract_from_noise: bool = True):
    """
    Cluster the PNG images in ``src_dir`` by character and copy them to ``dst_dir``.

    Steps: extract a CCIP feature per image, cluster with OPTICS on the
    pairwise difference matrix, optionally pull noise samples into their
    closest cluster, merge clusters that look like the same character, then
    copy every image into ``dst_dir/<index>`` (noise goes to ``dst_dir/-1``).

    :param src_dir: Directory containing the ``*.png`` images.
    :param dst_dir: Output directory, one sub-directory per cluster.
    :param merge_threshold: Two clusters merge when the mean of their
        pairwise "same character" flags reaches this value.
    :param clu_min_samples: ``min_samples`` passed to OPTICS.
    :param extract_from_noise: Try to assign noise samples to a cluster.
    :return: List of written cluster indices, with ``-1`` appended when
        noise images were saved.
    """
    image_files = np.array(natsorted(glob.glob(os.path.join(src_dir, '*.png'))))

    logging.info(f'Extracting feature of {plural_word(len(image_files), "images")} ...')
    images = [ccip_extract_feature(img) for img in tqdm(image_files, desc='Extract features')]
    batch_diff = ccip_batch_differences(images)
    batch_same = batch_diff <= ccip_default_threshold()

    def _metric(x, y):
        # OPTICS receives sample indices; map them back to the difference matrix.
        return batch_diff[int(x), int(y)].item()

    def _log_label_count(current_labels):
        label_cnt = {i: (current_labels == i).sum() for i in all_label_ids if (current_labels == i).sum() > 0}
        logging.info(f'Current label count: {label_cnt}')

    logging.info('Clustering ...')
    samples = np.arange(len(images)).reshape(-1, 1)
    clustering = OPTICS(min_samples=clu_min_samples, metric=_metric).fit(samples)
    labels = clustering.labels_

    max_clu_id = labels.max().item()
    n_clusters = max_clu_id + 1  # labels run from 0 to max_clu_id; -1 is noise
    all_label_ids = np.array([-1, *range(0, n_clusters)])
    logging.info(f'Cluster complete, with {plural_word(n_clusters, "cluster")}.')
    _log_label_count(labels)

    # With no cluster at all there is nothing to match noise against.
    if extract_from_noise and n_clusters > 0:
        # Masks are taken from the original labels, so they stay fixed while
        # the reassigned copy is being filled in.
        cluster_masks = [labels == i for i in range(n_clusters)]
        mask_labels = labels.copy()
        for nid in tqdm(np.where(labels == -1)[0], desc='Matching for noises'):
            avg_dists = np.array([batch_diff[nid][mask].mean() for mask in cluster_masks])
            best_id = np.argmin(avg_dists)
            if batch_same[nid][cluster_masks[best_id]].mean() >= 0.90:
                mask_labels[nid] = best_id
        labels = mask_labels
        logging.info('Noise extracting complete.')
        _log_label_count(labels)

    # Merge clusters until a full pass makes no change.
    _exist_ids = set(range(0, n_clusters))
    while True:
        _round_merged = False
        for xi in range(0, n_clusters):
            if xi not in _exist_ids:
                continue
            for yi in range(xi + 1, n_clusters):
                if yi not in _exist_ids:
                    continue

                score = (batch_same[labels == xi][:, labels == yi]).mean()
                logging.info(f'Label {xi} and {yi}\'s similarity score: {score}')
                if score >= merge_threshold:
                    labels[labels == yi] = xi
                    logging.info(f'Merging label {yi} into {xi} ...')
                    _exist_ids.remove(yi)
                    _round_merged = True

        if not _round_merged:
            break

    logging.info(f'Merge complete, remained cluster ids: {sorted(_exist_ids)}.')
    _log_label_count(labels)

    ids = []
    for i, clu_id in enumerate(tqdm(sorted(_exist_ids))):
        total = (labels == clu_id).sum()
        logging.info(f'Cluster {clu_id} will be renamed as #{i}, {plural_word(total, "image")} in total.')
        cluster_dir = os.path.join(dst_dir, str(i))
        os.makedirs(cluster_dir, exist_ok=True)
        for imgfile in image_files[labels == clu_id]:
            shutil.copyfile(imgfile, os.path.join(cluster_dir, os.path.basename(imgfile)))
        ids.append(i)

    n_total = (labels == -1).sum()
    if n_total > 0:
        logging.info(f'Save noise images, {plural_word(n_total, "image")} in total.')
        noise_dir = os.path.join(dst_dir, str(-1))
        os.makedirs(noise_dir, exist_ok=True)
        for imgfile in image_files[labels == -1]:
            shutil.copyfile(imgfile, os.path.join(noise_dir, os.path.basename(imgfile)))
        ids.append(-1)

    return ids
+ MinAreaFilterAction(400) + ) + reg_source = reg_source | new_reg_source + os.makedirs(person_dir, exist_ok=True) + with zipfile.ZipFile(os.path.join(person_dir, 'dataset.zip'), 'w') as zf: + all_person_images = glob.glob(os.path.join(clu_dir, str(id_), '*.png')) + total_image_cnt += len(all_person_images) + for file in all_person_images: + zf.write(file, os.path.basename(file)) + + for i, file in enumerate(random.sample(all_person_images, k=min(len(all_person_images), preview_count)), + start=1): + PaddingAlignAction((512, 704))(ImageItem(load_image(file))) \ + .image.save(os.path.join(person_dir, f'preview_{i}.png')) + + rel_zip_path = os.path.relpath(os.path.join(person_dir, 'dataset.zip'), dst_dir) + row = [id_ if id_ != -1 else 'noise', len(all_person_images), f'[Download]({rel_zip_path})'] + for i in range(1, preview_count + 1): + if os.path.exists(os.path.join(person_dir, f'preview_{i}.png')): + relpath = os.path.relpath(os.path.join(person_dir, f'preview_{i}.png'), dst_dir) + row.append(f'![preview {i}]({relpath})') + else: + row.append('N/A') + rows.append(row) + + with TemporaryDirectory() as td: + logging.info('Creating regular normal dataset ...') + reg_source.attach( + TaggingAction(force=False, character_threshold=1.01), + RandomFilenameAction(), + )[:regsize].export(TextualInversionExporter(td)) + + logging.info('Packing regular normal dataset ...') + reg_zip = os.path.join(dst_dir, 'regular', 'normal.zip') + os.makedirs(os.path.dirname(reg_zip), exist_ok=True) + with zipfile.ZipFile(reg_zip, 'w') as zf: + for file in glob.glob(os.path.join(td, '*')): + zf.write(file, os.path.relpath(file, td)) + + with TemporaryDirectory() as td_nobg: + logging.info('Creating regular no-background dataset ...') + LocalSource(td).attach( + BackgroundRemovalAction(), + ModeConvertAction('RGB', 'white'), + TaggingAction(force=True, character_threshold=1.01), + FileExtAction('.png'), + ).export(TextualInversionExporter(td_nobg)) + + logging.info('Packing regular 
no-background dataset ...') + reg_nobg_zip = os.path.join(dst_dir, 'regular', 'nobg.zip') + os.makedirs(os.path.dirname(reg_nobg_zip), exist_ok=True) + with zipfile.ZipFile(reg_nobg_zip, 'w') as zf: + for file in glob.glob(os.path.join(td_nobg, '*')): + zf.write(file, os.path.relpath(file, td_nobg)) + + logging.info('Packing all images ...') + all_zip = os.path.join(dst_dir, 'all.zip') + with zipfile.ZipFile(all_zip, 'w') as zf: + for file in glob.glob(os.path.join(clu_dir, '*', '*.png')): + zf.write(file, os.path.relpath(file, clu_dir)) + + logging.info('Packing raw package ...') + raw_zip = os.path.join(dst_dir, 'raw.zip') + with zipfile.ZipFile(raw_zip, 'w') as zf: + for file in glob.glob(os.path.join(clu_dir, '*', '*.png')): + zf.write(file, os.path.basename(file)) + + with open(os.path.join(dst_dir, 'meta.json'), 'w', encoding='utf-8') as f: + json.dump({ + 'name': bangumi_name, + 'ids': ids, + 'total': total_image_cnt, + }, f, indent=4, sort_keys=True, ensure_ascii=False) + + with open(os.path.join(dst_dir, 'README.md'), 'w', encoding='utf-8') as f: + print(dedent(f""" + --- + license: mit + tags: + - art + size_categories: + - {number_to_tag(total_image_cnt)} + --- + """).strip(), file=f) + print('', file=f) + + c_name = ' '.join(map(str.capitalize, re.split(r'\s+', bangumi_name))) + print(f'# Bangumi Image Base of {c_name}', file=f) + print('', file=f) + + print(f'This is the image base of bangumi {bangumi_name}, ' + f'we detected {plural_word(len(ids), "character")}, ' + f'{plural_word(total_image_cnt, "images")} in total. 
@contextmanager
def extract_from_videos(video_or_directory: str, bangumi_name: str, no_extract: bool = False,
                        min_size: int = 320, merge_threshold: float = 0.85, preview_count: int = 8):
    """
    Context manager that yields a directory holding the packaged image base.

    :param video_or_directory: A video file or a directory of videos; with
        ``no_extract`` it is read as a directory of already extracted images.
    :param bangumi_name: Name passed to ``create_project_by_result``.
    :param no_extract: Read images through ``LocalSource`` instead of
        extracting frames from videos.
    :param min_size: Argument of ``MinSizeFilterAction`` (video mode only).
    :param merge_threshold: Forwarded to ``cluster_from_directory``.
    :param preview_count: Forwarded to ``create_project_by_result``.
    :raises TypeError: If the path is neither a file nor a directory.

    All intermediate directories are temporary and are removed when the
    ``with`` block of the caller exits.
    """
    if no_extract:
        source = LocalSource(video_or_directory)
    else:
        if os.path.isfile(video_or_directory):
            source = VideoSource(video_or_directory)
        elif os.path.isdir(video_or_directory):
            source = VideoSource.from_directory(video_or_directory)
        else:
            raise TypeError(f'Unknown video - {video_or_directory!r}.')

        source = source.attach(
            NoMonochromeAction(),
            PersonSplitAction(keep_original=False, level='n'),
            FaceCountAction(1, level='n'),
            HeadCountAction(1, level='n'),
            MinSizeFilterAction(min_size),
            FilterSimilarAction('all'),
            FileOrderAction(ext='.png'),
        )

    with TemporaryDirectory() as src_dir:
        logging.info('Extract figures from videos ...')
        source.export(SaveExporter(src_dir, no_meta=True))

        with TemporaryDirectory() as clu_dir:
            logging.info(f'Clustering from {src_dir!r} to {clu_dir!r} ...')
            ids = cluster_from_directory(src_dir, clu_dir, merge_threshold)

            with TemporaryDirectory() as dst_dir:
                create_project_by_result(bangumi_name, ids, clu_dir, dst_dir, preview_count)

                yield dst_dir


def extract_to_huggingface(video_or_directory: str, bangumi_name: str,
                           repository: str, revision: str = 'main', no_extract: bool = False,
                           min_size: int = 320, merge_threshold: float = 0.85, preview_count: int = 8):
    """
    Build the image base for a bangumi and publish it to a dataset repository.

    The repository is created when missing. Files that already exist remotely
    but are not part of the new build are deleted in the same commit.

    :param video_or_directory: Input forwarded to ``extract_from_videos``.
    :param bangumi_name: Name of the bangumi, used in the commit message.
    :param repository: Target dataset repository id.
    :param revision: Revision to commit to.
    :param no_extract: Forwarded to ``extract_from_videos``.
    :param min_size: Forwarded to ``extract_from_videos``.
    :param merge_threshold: Forwarded to ``extract_from_videos``.
    :param preview_count: Forwarded to ``extract_from_videos``.
    """
    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    # Use the same ``datasets/`` root for the existence check and the listing,
    # so both address the dataset repository.
    repo_root = f'datasets/{repository}'
    if not hf_fs.exists(f'{repo_root}/.gitattributes'):
        hf_client.create_repo(repo_id=repository, repo_type='dataset', exist_ok=True)

    _exist_files = [os.path.relpath(file, repo_root) for file in hf_fs.glob(f'{repo_root}/**')]
    _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
    pre_exist_files = set()
    for i, (file, segments) in enumerate(_exist_ps):
        # An entry that prefixes the next sorted entry is a directory.
        if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
            continue
        if file != '.':
            pre_exist_files.add(file)
    # Repository configuration is never part of a build; do not delete it.
    pre_exist_files.discard('.gitattributes')

    with extract_from_videos(video_or_directory, bangumi_name, no_extract,
                             min_size, merge_threshold, preview_count) as dst_dir:
        operations = []
        for directory, _, files in os.walk(dst_dir):
            for file in files:
                # ``directory`` already starts with ``dst_dir``.
                filename = os.path.abspath(os.path.join(directory, file))
                file_in_repo = os.path.relpath(filename, dst_dir)
                operations.append(CommitOperationAdd(
                    path_in_repo=file_in_repo,
                    path_or_fileobj=filename,
                ))
                pre_exist_files.discard(file_in_repo)

        logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
        for file in sorted(pre_exist_files):
            operations.append(CommitOperationDelete(path_in_repo=file))

        current_time = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')
        commit_message = f'Publish {bangumi_name}\'s data, on {current_time}'
        logging.info(f'Publishing {bangumi_name}\'s data to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='dataset',
            revision=revision,
        )
0000000000000000000000000000000000000000..283d1b934bad296578f886b88938402ce6c585cb --- /dev/null +++ b/cyberharem/infer/__init__.py @@ -0,0 +1,3 @@ +from .civitai import publish_samples_to_civitai, civitai_review, civitai_auto_review +from .draw import draw_images, draw_with_workdir +from .export import draw_to_directory, draw_with_repo diff --git a/cyberharem/infer/__pycache__/__init__.cpython-310.pyc b/cyberharem/infer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aef8f0250b743c1bcf83fe4bbf04de699d73f05 Binary files /dev/null and b/cyberharem/infer/__pycache__/__init__.cpython-310.pyc differ diff --git a/cyberharem/infer/__pycache__/civitai.cpython-310.pyc b/cyberharem/infer/__pycache__/civitai.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..263e38e379ae3d57b907da2b80870036af924c84 Binary files /dev/null and b/cyberharem/infer/__pycache__/civitai.cpython-310.pyc differ diff --git a/cyberharem/infer/__pycache__/draw.cpython-310.pyc b/cyberharem/infer/__pycache__/draw.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..669d91e5d158dd306ec49f2af4b04e4704d0676e Binary files /dev/null and b/cyberharem/infer/__pycache__/draw.cpython-310.pyc differ diff --git a/cyberharem/infer/__pycache__/export.cpython-310.pyc b/cyberharem/infer/__pycache__/export.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0697a94e412e468cc096a107d808f4029f2f1192 Binary files /dev/null and b/cyberharem/infer/__pycache__/export.cpython-310.pyc differ diff --git a/cyberharem/infer/civitai.py b/cyberharem/infer/civitai.py new file mode 100644 index 0000000000000000000000000000000000000000..1975a29f476042e4703abe300e72594f648ab923 --- /dev/null +++ b/cyberharem/infer/civitai.py @@ -0,0 +1,384 @@ +import glob +import io +import json +import logging +import os +import re +import textwrap +from typing import Union, Optional, List + 
def publish_samples_to_civitai(images_dir, model: Union[int, str], model_version: Optional[str] = None,
                               model_creator='narugo1992', safe_only: bool = False,
                               extra_tags: Optional[List[str]] = None, post_title: str = None,
                               session_repo: str = 'narugo/civitai_session_p1'):
    """
    Upload the sample images in ``images_dir`` as a post on a civitai model version.

    Every ``<name>.png`` must have a ``<name>_info.txt`` beside it containing
    ``key: value`` lines. Images without a detected face are skipped.

    :param images_dir: Directory with the images and their info files.
    :param model: Model id or name passed to ``civitai_find_online``.
    :param model_version: Version passed to ``civitai_find_online``.
    :param model_creator: Creator passed to ``civitai_find_online``.
    :param safe_only: Put the rating scores first in the ordering key.
    :param extra_tags: Tags added to the model's own tags; ``None`` means none.
    :param post_title: Post title; defaults to ``"<model> - <version> Review"``.
    :param session_repo: Repository passed to ``get_civitai_session``.
    :return: The uploaded ``(local_file, filename, meta)`` tuples, in upload order.
    """
    resource = civitai_find_online(model, model_version, creator=model_creator)
    model_version_id = resource.version_id
    post_title = post_title or f"{resource.model_name} - {resource.version_name} Review"

    images = []
    for img_file in glob.glob(os.path.join(images_dir, '*.png')):
        img_filename = os.path.basename(img_file)
        img_name = os.path.splitext(img_filename)[0]
        local_img_file = os.path.join(images_dir, img_filename)
        local_info_file = os.path.join(images_dir, f'{img_name}_info.txt')

        info = {}
        with open(local_info_file, 'r', encoding='utf-8') as iif:
            for line in iif:
                line = line.strip()
                if line:
                    info_name, info_text = line.split(':', maxsplit=1)
                    info[info_name.strip()] = info_text.strip()

        meta = {
            'cfgScale': int(round(float(info.get('Guidance Scale')))),
            'negativePrompt': info.get('Neg Prompt'),
            'prompt': info.get('Prompt'),
            'sampler': info.get('Sample Method', "Euler a"),
            'seed': int(info.get('Seed')),
            'steps': int(info.get('Infer Steps')),
            'Size': f"{info['Width']}x{info['Height']}",
        }
        if info.get('Clip Skip'):
            meta['clipSkip'] = int(info['Clip Skip'])
        if info.get('Model'):
            meta['Model'] = info['Model']

        # Read the embedded generation parameters and release the file handle.
        with Image.open(local_img_file) as pil_img_file:
            png_info_text = pil_img_file.info.get('parameters')
        if png_info_text:
            find_hash = re.findall(r'Model hash:\s*([a-zA-Z\d]+)', png_info_text, re.IGNORECASE)
            if find_hash:
                model_hash = find_hash[0].lower()
                meta['hashes'] = {"model": model_hash}
                meta["resources"] = [
                    {
                        "hash": model_hash,
                        # 'Model' is optional above, so do not index it here.
                        "name": info.get('Model'),
                        "type": "model"
                    }
                ]
                meta["Model hash"] = model_hash

        nsfw = (info.get('Safe For Word', info.get('Safe For Work')) or '').lower() != 'yes'

        rating_score = anime_rating_score(local_img_file)
        safe_v = int(round(rating_score['safe'] * 10))
        safe_r15 = int(round(rating_score['r15'] * 10))
        safe_r18 = int(round(rating_score['r18'] * 10))
        faces = detect_faces(local_img_file)
        if not faces:
            continue
        (x0, y0, x1, y1), _, _ = faces[0]
        width, height = load_image(local_img_file).size
        face_area = abs((x1 - x0) * (y1 - y0))
        face_ratio = int(round(face_area * 1.0 / (width * height) * 50))

        # Ordering key; the payload comes last and is picked out after sorting.
        images.append((
            (-safe_v, -safe_r15, -safe_r18) if safe_only else (0,),
            -face_ratio,
            1 if nsfw else 0,
            0 if img_name.startswith('pattern_') else 1,
            img_name,
            (local_img_file, img_filename, meta)
        ))

    images = [item[-1] for item in sorted(images)]

    from ..publish.civitai import civitai_upload_images, get_civitai_session, parse_publish_at

    def _custom_pc_func(mvid):
        return {
            "json": {
                "modelVersionId": mvid,
                "title": post_title,
                "tag": None,
                "authed": True,
            },
            "meta": {
                "values": {
                    "tag": ["undefined"]
                }
            }
        }

    session = get_civitai_session(session_repo)
    post_id = civitai_upload_images(
        model_version_id, images,
        # ``extra_tags`` defaults to None, which cannot be unpacked.
        tags=[*resource.tags, *(extra_tags or [])],
        model_id=resource.model_id,
        pc_func=_custom_pc_func,
        session=session,
    )

    logging.info(f'Publishing post {post_id!r} ...')
    resp = srequest(
        session, 'POST', 'https://civitai.com/api/trpc/post.update',
        json={
            "json": {
                "id": post_id,
                "publishedAt": parse_publish_at('now'),
                "authed": True,
            },
            "meta": {
                "values": {
                    "publishedAt": ["Date"]
                }
            }
        },
        headers={'Referer': f'https://civitai.com/models/{resource.model_id}/wizard?step=4'},
    )
    resp.raise_for_status()

    return images


def civitai_review(model: Union[int, str], model_version: Optional[str] = None,
                   model_creator='narugo1992', rating: int = 5, description_md: Optional[str] = None,
                   session_repo: str = 'narugo/civitai_session_p1'):
    """
    Create or update the session user's review of a civitai model version.

    An existing review is looked up first. When the lookup answers 404 or
    fails with ``AssertionError``, a new review is created. The rating is then
    written, followed by the description when one is given.

    :param model: Model id or name passed to ``civitai_find_online``.
    :param model_version: Version passed to ``civitai_find_online``.
    :param model_creator: Creator passed to ``civitai_find_online``.
    :param rating: Rating value to store.
    :param description_md: Optional markdown text; it is dedented and
        converted to HTML before upload.
    :param session_repo: Repository passed to ``get_civitai_session``.
    """
    resource = civitai_find_online(model, model_version, creator=model_creator)

    from ..publish.civitai import get_civitai_session
    session = get_civitai_session(session_repo)
    wizard_referer = {'Referer': f'https://civitai.com/models/{resource.model_id}/wizard?step=4'}

    logging.info(f'Try find exist review of model version #{resource.version_id} ...')
    try:  # This API answered 500 at times (2023-09-14); treat that as "no review".
        resp = srequest(
            session, 'GET', 'https://civitai.com/api/trpc/resourceReview.getUserResourceReview',
            params={'input': json.dumps({"json": {"modelVersionId": resource.version_id, "authed": True}})},
            headers={
                'Referer': f'https://civitai.com/posts/create?modelId={resource.model_id}&'
                           f'modelVersionId={resource.version_id}&'
                           f'returnUrl=/models/{resource.model_id}?'
                           f'modelVersionId={resource.version_id}&reviewing=true'
            },
            raise_for_status=False
        )
    except AssertionError:
        resp = None

    if resp is None or resp.status_code == 404:
        logging.info(f'Creating review for #{resource.version_id} ...')
        resp = srequest(
            session, 'POST', 'https://civitai.com/api/trpc/resourceReview.create',
            json={
                "json": {
                    "modelVersionId": resource.version_id,
                    "modelId": resource.model_id,
                    "rating": rating,
                    "authed": True,
                }
            },
            headers=wizard_referer
        )
    resp.raise_for_status()
    review_id = resp.json()['result']['data']['json']['id']

    logging.info(f'Updating review #{review_id}\'s rating ...')
    resp = srequest(
        session, 'POST', 'https://civitai.com/api/trpc/resourceReview.update',
        json={
            "json": {
                "id": review_id,
                "rating": rating,
                "details": None,
                "authed": True,
            },
            "meta": {"values": {"details": ["undefined"]}}
        },
        headers=wizard_referer
    )
    resp.raise_for_status()

    if description_md:
        logging.info(f'Updating review #{review_id}\'s description ...')
        resp = srequest(
            session, 'POST', 'https://civitai.com/api/trpc/resourceReview.update',
            json={
                "json": {
                    "id": review_id,
                    "details": markdown2.markdown(textwrap.dedent(description_md)),
                    'rating': None,
                    "authed": True,
                },
                "meta": {"values": {"rating": ["undefined"]}}
            },
            headers=wizard_referer
        )
        resp.raise_for_status()
def _ccip_review_stats(img_feats, ds_feats):
    """
    Compare generated-image CCIP features against dataset CCIP features.

    :param img_feats: Features of the generated images (must be non-empty).
    :param ds_feats: Features of the dataset images.
    :return: Dict with ``mean_score``, ``median_score``, ``mean_loss`` and ``median_loss``.
    """
    n_imgs = len(img_feats)
    # Keep only the (generated image) x (dataset image) part of the full difference matrix.
    diffs = ccip_batch_differences([*img_feats, *ds_feats])[:n_imgs, n_imgs:]
    scores = (diffs <= ccip_default_threshold()).mean(axis=1)
    losses = diffs.mean(axis=1)
    return {
        'mean_score': scores.mean().item(),
        'median_score': np.median(scores).item(),
        'mean_loss': losses.mean().item(),
        'median_loss': np.median(losses).item(),
    }


def civitai_auto_review(repository: str, model: Optional[Union[int, str]] = None,
                        model_version: Optional[str] = None,
                        model_creator='narugo1992', step: Optional[int] = None,
                        base_models: Optional[List[str]] = None,
                        rating: Optional[int] = 5, description_md: Optional[str] = None,
                        session_repo: str = 'narugo/civitai_session_p1'):
    """
    Draw samples of a model on several base models, publish them to civitai, and
    (when ``rating`` is given) post a review summarising their CCIP scores.

    :param repository: Model repository; its last path segment is split on ``_`` and the
        final piece is taken as the game name, the rest as the character name.
    :param model: Civitai model id or title. Looked up from the names when not given.
    :param model_version: Model version passed through to the publish/review helpers.
    :param model_creator: Creator name passed through to the publish/review helpers.
    :param step: Training step to draw with, forwarded to ``draw_with_repo``.
    :param base_models: Base models to test on, ``_BASE_MODEL_LIST`` when not given.
    :param rating: Rating to post. ``None`` means images are published but no review is made.
    :param description_md: Review text. Generated from the measured scores when empty.
    :param session_repo: Repository holding the civitai session.
    """
    repo_name = repository.split('/')[-1]
    name_parts = repo_name.split('_')
    game_name = name_parts[-1]
    char_name = ' '.join(name_parts[:-1])
    model = model or try_find_title(char_name, game_name) or \
        try_get_title_from_repo(repository) or repo_name
    logging.info(f'Model name on civitai: {model!r}')

    from ..publish.export import KNOWN_MODEL_HASHES

    hf_fs = get_hf_fs()
    model_info = json.loads(hf_fs.read_text(f'{repository}/meta.json'))
    dataset_info = model_info['dataset']

    # load dataset
    ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type']
    with load_dataset_for_character(repository, size=ds_size) as (_, ds_dir):
        core_tags, _ = load_tags_from_directory(ds_dir)
        all_tags = [
            game_name, f"{game_name} {char_name}", char_name,
            'female', 'girl', 'character', 'fully-automated', 'random prompt', 'random seed',
            *map(_tag_decode, core_tags.keys()),
        ]
        # Features are extracted while the dataset directory is still available.
        ds_feats = [
            ccip_extract_feature(item.image)
            for item in tqdm(list(LocalSource(ds_dir)), desc='Extract Dataset Feature')
        ]

    all_feats = []
    model_results = []
    for base_model in (base_models or _BASE_MODEL_LIST):
        logging.info(f'Reviewing with {base_model!r} ...')
        with TemporaryDirectory() as td:
            if KNOWN_MODEL_HASHES.get(base_model):
                bm_id, bm_version_id, _ = find_version_id_by_hash(KNOWN_MODEL_HASHES[base_model])
                resource = civitai_find_online(bm_id, bm_version_id)
                m_name = f'{resource.model_name} - {resource.version_name}'
                m_url = f'https://civitai.com/models/{resource.model_id}?modelVersionId={resource.version_id}'
            else:
                m_name = base_model
                m_url = None

            draw_with_repo(repository, td, step=step, pretrained_model=base_model)
            images = publish_samples_to_civitai(
                td, model, model_version,
                model_creator=model_creator,
                extra_tags=all_tags,
                post_title=f"AI Review (Base Model: {m_name})",
                session_repo=session_repo
            )

            # Local image files are read here, while the temporary directory still exists.
            gp_feats = [
                ccip_extract_feature(local_imgfile)
                for local_imgfile, _, _ in tqdm(images, desc='Extract Images Feature')
            ]

        if not gp_feats:
            # Mean/median over an empty batch is NaN, which would end up in the review text.
            logging.warning(f'No images published for base model {base_model!r}, skipped in statistics.')
            continue

        all_feats.extend(gp_feats)
        ret = {
            'model_name': m_name,
            'model_homepage': m_url,
            'images': len(images),
            **_ccip_review_stats(gp_feats, ds_feats),
        }
        logging.info(f'Result of model: {ret!r}')
        model_results.append(ret)

    if rating is None:
        return

    logging.info('Making review ...')
    if not description_md:
        if not all_feats:
            # Nothing was measured, so there is nothing an automatic review could report.
            logging.warning('No images were generated on any base model, review skipped.')
            return

        overall = _ccip_review_stats(all_feats, ds_feats)
        with io.StringIO() as ds:
            print('Tested on the following models:', file=ds)
            print('', file=ds)

            all_total_images = 0
            for mr in model_results:
                if mr['model_homepage']:
                    mx = f'[{mr["model_name"]}]({mr["model_homepage"]})'
                else:
                    mx = mr['model_name']

                all_total_images += mr['images']
                print(
                    f'When using model {mx}, {plural_word(mr["images"], "image")} in total, '
                    f'recognition score (mean/median): {mr["mean_score"]:.3f}/{mr["median_score"]:.3f}, '
                    f'character image loss (mean/median): {mr["mean_loss"]:.4f}/{mr["median_loss"]:.4f}.',
                    file=ds
                )
                print('', file=ds)

            print(
                f'Overall, {plural_word(all_total_images, "image")} in total, '
                f'recognition score (mean/median): '
                f'{overall["mean_score"]:.3f}/{overall["median_score"]:.3f}, '
                f'character image loss (mean/median): '
                f'{overall["mean_loss"]:.4f}/{overall["median_loss"]:.4f}.',
                file=ds
            )
            print('', file=ds)

            description_md = ds.getvalue()

    try:
        civitai_review(model, model_version, model_creator, rating, description_md, session_repo)
    except Exception:
        # Dump the text so it can be posted by hand if the request failed.
        print('This is the description md:')
        print(description_md)
        raise
def draw_images(
        workdir: str, prompts: Union[str, List[str]], neg_prompts: Optional[Union[str, List[str]]] = None,
        seeds: Optional[Union[int, List[int]]] = None, emb_name: Optional[str] = None, save_cfg: bool = True,
        model_steps: int = 1000, n_repeats: int = 2, pretrained_model: str = _DEFAULT_INFER_MODEL,
        width: int = 512, height: int = 768, gscale: float = 8, infer_steps: int = 30,
        lora_alpha: float = 0.85, output_dir: str = 'output', cfg_file: str = _DEFAULT_INFER_CFG_FILE,
        clip_skip: int = 2, sample_method: str = 'DPM++ 2M Karras',
):
    """
    Render images with the checkpoint of ``model_steps`` found in ``workdir``.

    ``prompts``, ``neg_prompts`` and ``seeds`` may each be a single value or a list.
    All lists must have the same length; single values are repeated to that length.

    :param workdir: Training work directory; checkpoints are looked up in ``<workdir>/ckpts``.
    :param prompts: Prompt or list of prompts.
    :param neg_prompts: Negative prompt or list of negative prompts.
    :param seeds: Seed or list of seeds.
    :param emb_name: Embedding name, defaults to the directory name of ``workdir``.
    :param save_cfg: Forwarded to ``Visualizer.vis_to_dir``.
    :param output_dir: Directory handed to the inference config as ``output_dir``.
    :raises FileNotFoundError: No ``*-<model_steps>.pt`` file exists in ``<workdir>/ckpts``.
    :raises ValueError: Unknown ``sample_method``, or the given lists differ in length.
    """
    # normpath strips a trailing slash, which would otherwise make basename() return ''.
    emb_name = emb_name or os.path.basename(os.path.normpath(workdir))
    with TemporaryDirectory() as emb_dir:
        # Sorted so the same file is picked on every run when several match.
        src_pt_files = sorted(glob.glob(os.path.join(workdir, 'ckpts', f'*-{model_steps}.pt')))
        if not src_pt_files:
            raise FileNotFoundError(f'Embedding not found for step {model_steps}.')

        src_pt_file = src_pt_files[0]
        shutil.copyfile(src_pt_file, os.path.join(emb_dir, f'{emb_name}.pt'))

        cli_args = data_to_cli_args({
            'pretrained_model': pretrained_model,
            'N_repeats': n_repeats,

            'vae_optimize': {
                'tiling': False,
            },

            # The config receives clip_skip reduced by one.
            'clip_skip': clip_skip - 1,

            'bs': 1,
            'num': 1,

            'infer_args': {
                'width': width,
                'height': height,
                'guidance_scale': gscale,
                'num_inference_steps': infer_steps,
            },

            'exp_dir': workdir,
            'model_steps': model_steps,
            'emb_dir': emb_dir,
            'output_dir': output_dir,

            'merge': {
                'alpha': lora_alpha,
            },

            'new_components': {
                'scheduler': sample_method_to_config(sample_method),
                'vae': {
                    '_target_': 'diffusers.AutoencoderKL.from_pretrained',
                    'pretrained_model_name_or_path': 'deepghs/animefull-latest',  # path to vae model
                    'subfolder': 'vae',
                }
            }
        })
        logging.info(f'Infer based on {cfg_file!r}, with {cli_args!r}')
        cfgs = load_config_with_cli(cfg_file, args_list=cli_args)  # skip --cfg

        # Work out the common length of whichever arguments were given as lists.
        N = None
        if isinstance(prompts, list):
            N = len(prompts)
        if isinstance(neg_prompts, list):
            if N is not None and len(neg_prompts) != N:
                raise ValueError(f'Number of prompts ({len(prompts)}) and neg_prompts ({len(neg_prompts)}) not match.')
            N = len(neg_prompts)
        if isinstance(seeds, list):
            if N is not None and len(seeds) != N:
                raise ValueError(f'Number of both prompts ({N}) and seed ({len(seeds)}) not match.')
            N = len(seeds)
        if N is None:
            N = 1

        # Repeat single values so all three arguments are lists of length N.
        if not isinstance(prompts, list):
            prompts = [prompts] * N
        if not isinstance(neg_prompts, list):
            neg_prompts = [neg_prompts] * N
        if not isinstance(seeds, list):
            seeds = [seeds] * N

        viser = Visualizer(cfgs)
        viser.vis_to_dir(prompt=prompts, negative_prompt=neg_prompts, seeds=seeds,
                         save_cfg=save_cfg, **cfgs.infer_args)
def _newest_match(pattern: str) -> str:
    """Return the most recently modified file matching ``pattern``."""
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f'No output file matches {pattern!r}.')
    return max(matches, key=os.path.getmtime)


def draw_with_workdir(
        workdir: str, emb_name: str = None, save_cfg: bool = True,
        model_steps: int = 1000, n_repeats: int = 2, pretrained_model: str = _DEFAULT_INFER_MODEL,
        width: int = 512, height: int = 768, gscale: float = 8, infer_steps: int = 30,
        lora_alpha: float = 0.85, output_dir: str = None, cfg_file: str = _DEFAULT_INFER_CFG_FILE,
        clip_skip: int = 2, sample_method: str = 'DPM++ 2M Karras', model_hash: Optional[str] = None,
):
    """
    Draw one image for every prompt file in ``<workdir>/rtags/*.json``.

    Prompts are rendered in batches of at most ``_N_MAX_DRAW`` through ``draw_images``.
    The rendered files are then read back as ``<index>-*.png`` / ``<index>-*.yaml``,
    with the index starting at 1 inside each batch.

    :param workdir: Training work directory holding ``rtags`` and ``ckpts``.
    :param output_dir: Directory for rendered files; a temporary one per batch when omitted.
    :param model_hash: Model hash stored on each returned ``Drawing``.
    :return: List of ``Drawing`` objects, ordered by prompt file name.
    :raises FileNotFoundError: An expected output file is missing after drawing.
    """
    # Sorted so the returned order does not depend on directory listing order.
    entries = []
    for jfile in sorted(glob.glob(os.path.join(workdir, 'rtags', '*.json'))):
        with open(jfile, 'r', encoding='utf-8') as f:
            entries.append(json.load(f))

    retval = []
    for x in range(0, len(entries), _N_MAX_DRAW):
        batch = entries[x:x + _N_MAX_DRAW]
        prompts = [item['prompt'] for item in batch]
        neg_prompts = [item['neg_prompt'] for item in batch]
        seeds = [item['seed'] for item in batch]

        with TemporaryDirectory() as td:
            _tmp_output_dir = output_dir or td
            draw_images(
                workdir, prompts, neg_prompts, seeds,
                emb_name=emb_name, save_cfg=save_cfg, model_steps=model_steps,
                n_repeats=n_repeats, pretrained_model=pretrained_model,
                width=width, height=height, gscale=gscale, infer_steps=infer_steps,
                lora_alpha=lora_alpha, output_dir=_tmp_output_dir, cfg_file=cfg_file,
                clip_skip=clip_skip, sample_method=sample_method,
            )

            for i, item in enumerate(batch, start=1):
                # A caller-supplied output_dir is shared by all batches, so index i may
                # also match files of an earlier batch; take the newest one.
                img_file = _newest_match(os.path.join(_tmp_output_dir, f'{i}-*.png'))
                yaml_file = _newest_match(os.path.join(_tmp_output_dir, f'{i}-*.yaml'))
                with open(yaml_file, 'r', encoding='utf-8') as f:
                    # Use the seed recorded with the output instead of the requested one.
                    seed = yaml.load(f, Loader)['seed']

                img = Image.open(img_file)
                img.load()  # read pixels now, the temporary directory is removed afterwards

                retval.append(Drawing(
                    item['name'], item['prompt'], item['neg_prompt'], seed,
                    sfw=item['sfw'] and len(detect_censors(img, conf_threshold=0.45)) == 0,
                    width=width, height=height, gscale=gscale, steps=infer_steps,
                    image=img, sample_method=sample_method, clip_skip=clip_skip,
                    model=pretrained_model, model_hash=model_hash,
                ))

    return retval
def draw_with_repo(repository: str, export_dir: str, step: Optional[int] = None, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None):
    """
    Download one training step of a published model and draw its sample images.

    The raw checkpoint files of ``step`` are fetched into a temporary work directory,
    the recommended tags are regenerated there, and ``draw_to_directory`` renders
    the images into ``export_dir``.

    :param repository: Model repository, must contain ``meta.json``.
    :param export_dir: Directory receiving the images.
    :param step: Step to use, ``best_step`` from ``meta.json`` when not given.
    :raises ValueError: The repository has no ``meta.json``.
    """
    from ..publish import find_steps_in_workdir

    hf_fs = get_hf_fs()
    if not hf_fs.exists(f'{repository}/meta.json'):
        raise ValueError(f'Invalid repository or no model found - {repository!r}.')

    logging.info(f'Model repository {repository!r} found.')
    meta = json.loads(hf_fs.read_text(f'{repository}/meta.json'))
    step = step or meta['best_step']
    logging.info(f'Using step {step} ...')

    with TemporaryDirectory() as workdir:
        logging.info('Downloading models ...')
        for f in tqdm(hf_fs.glob(f'{repository}/{step}/raw/*')):
            rel_file = os.path.relpath(f, repository)
            local_file = os.path.join(workdir, 'ckpts', os.path.basename(rel_file))
            if os.path.dirname(local_file):
                os.makedirs(os.path.dirname(local_file), exist_ok=True)
            download_file(
                hf_hub_url(repository, filename=rel_file),
                local_file
            )

        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # Checkpoint name is split on '_': last piece is the game, the rest the character.
        game_name = pt_name.split('_')[-1]
        name = '_'.join(pt_name.split('_')[:-1])

        from gchar.games.dispatch.access import GAME_CHARS
        ch = GAME_CHARS[game_name].get(name) if game_name in GAME_CHARS else None
        # Fall back to the repository itself when no known game character matches.
        source = repository if ch is None else ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        # meta.json may lack dataset info; then let save_recommended_tags use its default size.
        tag_kwargs = {}
        dataset_info = meta.get('dataset') or {}
        if dataset_info.get('type'):
            tag_kwargs['ds_size'] = dataset_info['type']
        save_recommended_tags(source, name=pt_name, workdir=workdir, **tag_kwargs)

        logging.info('Drawing ...')
        draw_to_directory(
            workdir, export_dir, step,
            n_repeats, pretrained_model, clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash
        )
show_default=True) +def datasets(pattern): + hf_client = get_hf_client() + for ds in hf_client.list_datasets(author='CyberHarem'): + if fnmatch.fnmatch(ds.id, pattern): + print(ds.id) + + +if __name__ == '__main__': + cli() diff --git a/cyberharem/publish/__init__.py b/cyberharem/publish/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b173b6effcecf744df1a5515d5d1df4978583c35 --- /dev/null +++ b/cyberharem/publish/__init__.py @@ -0,0 +1,6 @@ +from .civitai import civitai_query_model_tags, civitai_upsert_model, civitai_query_vae_models, civitai_create_version, \ + civitai_upload_models, civitai_get_model_info, civitai_upload_images, civiti_publish, civitai_publish_from_hf +from .convert import convert_to_webui_lora +from .export import export_workdir +from .huggingface import deploy_to_huggingface +from .steps import find_steps_in_workdir diff --git a/cyberharem/publish/__main__.py b/cyberharem/publish/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0a1c55f15e15221ae168df6d7bd8b3627b92a2 --- /dev/null +++ b/cyberharem/publish/__main__.py @@ -0,0 +1,158 @@ +import os +from functools import partial + +import click +from ditk import logging +from gchar.generic import import_generic +from gchar.utils import GLOBAL_CONTEXT_SETTINGS +from gchar.utils import print_version as _origin_print_version +from hbutils.system import TemporaryDirectory +from huggingface_hub import hf_hub_url +from tqdm.auto import tqdm + +from cyberharem.dataset import save_recommended_tags +from cyberharem.publish import find_steps_in_workdir +from cyberharem.utils import get_hf_fs, download_file +from .civitai import civitai_publish_from_hf +from .huggingface import deploy_to_huggingface +from ..infer.draw import _DEFAULT_INFER_MODEL + +import_generic() + +print_version = partial(_origin_print_version, 'cyberharem') + + +@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish trained models') +@click.option('-v', 
'--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True) +def cli(): + pass # pragma: no cover + + +@cli.command('huggingface', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish to huggingface') +@click.option('-w', '--workdir', 'workdir', type=click.Path(file_okay=False, exists=True), required=True, + help='Work directory for experiment.', show_default=True) +@click.option('--repository', '-r', 'repository', type=str, default=None, + help='Repository to publish to.', show_default=True) +@click.option('--revision', '-R', 'revision', type=str, default='main', + help='Revision for pushing the model.', show_default=True) +@click.option('-n', '--n_repeats', 'n_repeats', type=int, default=3, + help='N Repeats for text encoder', show_default=True) +@click.option('-m', '--pretrained_model', 'pretrained_model', type=str, default=_DEFAULT_INFER_MODEL, + help='Pretrained model for preview drawing.', show_default=True) +@click.option('--width', 'width', type=int, default=512, + help='Width of images.', show_default=True) +@click.option('--height', 'height', type=int, default=768, + help='Height of images.', show_default=True) +@click.option('-C', '--clip_skip', 'clip_skip', type=int, default=2, + help='Clip skip.', show_default=True) +@click.option('-S', '--infer_steps', 'infer_steps', type=int, default=30, + help='Steps of inference.', show_default=True) +def huggingface(workdir: str, repository, revision, n_repeats, pretrained_model, + width, height, clip_skip, infer_steps): + logging.try_init_root(logging.INFO) + deploy_to_huggingface( + workdir, repository, revision, n_repeats, pretrained_model, + clip_skip, width, height, infer_steps, + ) + + +@cli.command('rehf', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Re-Publish to huggingface') +@click.option('--repository', '-r', 'repository', type=str, default=None, + help='Repository to publish to.', show_default=True) +@click.option('--revision', '-R', 'revision', type=str, 
default='main', + help='Revision for pushing the model.', show_default=True) +@click.option('-n', '--n_repeats', 'n_repeats', type=int, default=3, + help='N Repeats for text encoder', show_default=True) +@click.option('-m', '--pretrained_model', 'pretrained_model', type=str, default=_DEFAULT_INFER_MODEL, + help='Pretrained model for preview drawing.', show_default=True) +@click.option('--width', 'width', type=int, default=512, + help='Width of images.', show_default=True) +@click.option('--height', 'height', type=int, default=768, + help='Height of images.', show_default=True) +@click.option('-C', '--clip_skip', 'clip_skip', type=int, default=2, + help='Clip skip.', show_default=True) +@click.option('-S', '--infer_steps', 'infer_steps', type=int, default=30, + help='Steps of inference.', show_default=True) +def rehf(repository, revision, n_repeats, pretrained_model, + width, height, clip_skip, infer_steps): + logging.try_init_root(logging.INFO) + with TemporaryDirectory() as workdir: + logging.info(f'Downloading models for {workdir!r} ...') + hf_fs = get_hf_fs() + for f in tqdm(hf_fs.glob(f'{repository}/*/raw/*')): + rel_file = os.path.relpath(f, repository) + local_file = os.path.join(workdir, 'ckpts', os.path.basename(rel_file)) + if os.path.dirname(local_file): + os.makedirs(os.path.dirname(local_file), exist_ok=True) + download_file( + hf_hub_url(repository, filename=rel_file), + local_file + ) + + logging.info(f'Regenerating tags for {workdir!r} ...') + pt_name, _ = find_steps_in_workdir(workdir) + game_name = pt_name.split('_')[-1] + name = '_'.join(pt_name.split('_')[:-1]) + + from gchar.games.dispatch.access import GAME_CHARS + if game_name in GAME_CHARS: + ch_cls = GAME_CHARS[game_name] + ch = ch_cls.get(name) + else: + ch = None + + if ch is None: + source = repository + else: + source = ch + + logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.') + save_recommended_tags(source, name=pt_name, workdir=workdir) + logging.info('Success!') + + 
@cli.command('civitai', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish to civitai')
@click.option('--repository', '-r', 'repository', type=str, required=True,
              help='Repository to publish from.', show_default=True)
@click.option('--title', '-t', 'title', type=str, default=None,
              help='Title of the civitai model.', show_default=True)
@click.option('--steps', '-s', 'steps', type=int, default=None,
              help='Steps to deploy.', show_default=True)
@click.option('--epochs', '-e', 'epochs', type=int, default=None,
              help='Epochs to deploy.', show_default=True)
@click.option('--draft', '-d', 'draft', is_flag=True, type=bool, default=False,
              help='Only create draft without publishing.', show_default=True)
@click.option('--time', '-T', 'publish_time', type=str, default=None,
              help='Publish time, publish immediately when not given.', show_default=True)
@click.option('--allow_nsfw', '-N', 'allow_nsfw', is_flag=True, type=bool, default=False,
              help='Allow uploading nsfw images.', show_default=True)
@click.option('--version_name', '-v', 'version_name', type=str, default=None,
              help='Name of the version.', show_default=True)
@click.option('--force_create', '-F', 'force_create', is_flag=True, type=bool, default=False,
              help='Force create new model.', show_default=True)
@click.option('--no_ccip', 'no_ccip_check', is_flag=True, type=bool, default=False,
              help='No CCIP check.', show_default=True)
def civitai(repository, title, steps, epochs, draft, publish_time, allow_nsfw,
            version_name, force_create, no_ccip_check):
    # Publishes the model of a HuggingFace repository to civitai and logs the model URL.
    # All options are forwarded to civitai_publish_from_hf.
    logging.try_init_root(logging.INFO)
    model_id = civitai_publish_from_hf(
        repository, title,
        step=steps, epoch=epochs, draft=draft,
        publish_at=publish_time, allow_nsfw_images=allow_nsfw,
        version_name=version_name, force_create_model=force_create,
        no_ccip_check=no_ccip_check,
    )
    url = f'https://civitai.com/models/{model_id}'
    if draft:
        logging.info(f'Draft created, it can be seen at {url} .')
    else:
        logging.info(f'Deploy success, model now can be seen at {url} .')
new file mode 100644 index 0000000000000000000000000000000000000000..e924ec7c3f016b562173afbc89b3ec9c985f1718 Binary files /dev/null and b/cyberharem/publish/__pycache__/huggingface.cpython-310.pyc differ diff --git a/cyberharem/publish/__pycache__/steps.cpython-310.pyc b/cyberharem/publish/__pycache__/steps.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15103cc6ba4c0d3e189fb9339774ced183f859af Binary files /dev/null and b/cyberharem/publish/__pycache__/steps.cpython-310.pyc differ diff --git a/cyberharem/publish/civitai.py b/cyberharem/publish/civitai.py new file mode 100644 index 0000000000000000000000000000000000000000..b930d0c55f3d87eeac41b085f07f8eee1bbbc678 --- /dev/null +++ b/cyberharem/publish/civitai.py @@ -0,0 +1,915 @@ +import glob +import json +import logging +import math +import os.path +import re +import textwrap +import uuid +from typing import Optional, Tuple, List, Union + +import blurhash +import numpy as np +from PIL import Image +from gchar.games.base import Character +from gchar.games.dispatch.access import GAME_CHARS +from gchar.generic import import_generic +from hbutils.string import plural_word +from hbutils.system import TemporaryDirectory +from huggingface_hub import hf_hub_url +from imgutils.data import load_image +from imgutils.detect import detect_faces +from imgutils.metrics import ccip_extract_feature, ccip_batch_same +from imgutils.validate import anime_rating_score, nsfw_pred +from pycivitai import civitai_find_online +from pycivitai.client import ModelNotFound +from tqdm.auto import tqdm +from urlobject import URLObject +from waifuc.source import LocalSource + +try: + from typing import Literal +except (ModuleNotFoundError, ImportError): + from typing_extensions import Literal + +import markdown2 + +from ..dataset import load_dataset_for_character +from ..utils import get_civitai_session, srequest, get_ch_name, get_hf_fs, download_file, parse_time, \ + load_tags_from_directory, repr_tags + 
def civitai_upsert_model(
        name, description_md: str, tags: List[str],
        commercial_use: CommercialUseTyping = 'rent',
        allow_no_credit: bool = True, allow_derivatives: bool = True, allow_different_licence: bool = True,
        nsfw: bool = False, poi: bool = False, exist_model_id: Optional[int] = None,
        session=None
) -> Tuple[int, bool]:
    """
    Create a LORA model on civitai, or update it when ``exist_model_id`` is given.

    :param name: Model title.
    :param description_md: Description in markdown; it is dedented and rendered to HTML.
    :param tags: Tag names. Each is resolved through ``civitai_query_model_tags``;
        tags that resolve to the same name are sent once.
    :param commercial_use: Commercial use permission.
    :param exist_model_id: Id of an existing model to update. The update is sent with
        status ``Published``, a new model is sent with status ``Draft``.
    :param session: Civitai session, a new one is created when not given.
    :return: Tuple of the model id and the ``nsfw`` flag reported by civitai.
    """
    session = session or get_civitai_session()

    _exist_tags, tag_list = set(), []
    _meta_values = {}
    for tag in tags:
        tag_id, tag_name = civitai_query_model_tags(tag, session)
        if tag_name in _exist_tags:
            continue
        # Remember the resolved name; without this the duplicate check never triggers.
        _exist_tags.add(tag_name)
        _meta_values[f"tagsOnModels.{len(tag_list)}.id"] = ["undefined"]
        tag_list.append({'id': tag_id, 'name': tag_name})

    # lower().capitalize() turns 'rentCivit' into 'Rentcivit'.
    # NOTE(review): 'RentCivit' is presumably the spelling civitai expects - confirm against the API.
    _commercial_names = {
        'none': 'None', 'image': 'Image', 'rentcivit': 'RentCivit', 'rent': 'Rent', 'sell': 'Sell',
    }
    commercial_key = commercial_use.lower()
    allow_commercial_use = _commercial_names.get(commercial_key, commercial_key.capitalize())

    post_json = {
        "name": name,
        "description": markdown2.markdown(textwrap.dedent(description_md)),
        "type": "LORA",

        "allowCommercialUse": allow_commercial_use,  # None, Image, RentCivit, Rent, Sell
        "allowNoCredit": allow_no_credit,
        "allowDerivatives": allow_derivatives,
        "allowDifferentLicense": allow_different_licence,

        "nsfw": nsfw,
        "poi": poi,
        "tagsOnModels": tag_list,

        "authed": True,
        "status": "Draft",
        "checkpointType": None,
        "uploadType": "Created",
    }
    if exist_model_id:
        post_json['id'] = exist_model_id
        post_json["locked"] = False
        post_json["status"] = "Published"
        logging.info(f'Model {name!r}({exist_model_id}) already exist, updating its new information. '
                     f'Tags: {[item["name"] for item in tag_list]!r} ...')
    else:
        logging.info(f'Creating model {name!r}, tags: {[item["name"] for item in tag_list]!r} ...')

    # Sent with a plain session.post, as in the original code.
    resp = session.post(
        'https://civitai.com/api/trpc/model.upsert',
        json={
            "json": post_json,
            "meta": {
                "values": _meta_values,
            }
        },
        headers={'Referer': 'https://civitai.com/models/create'},
    )
    # Report an HTTP failure as such, instead of a KeyError while reading the body.
    # Assumes a requests-style response object.
    resp.raise_for_status()

    data = resp.json()['result']['data']['json']
    return data['id'], data['nsfw']
session or get_civitai_session() + + vae_id = None + if vae_name: + for vae_item in civitai_query_vae_models(session, model_id): + if _vae_model_same(vae_item['modelName'], vae_name): + vae_id = vae_item['id'] + + logging.info(f'Creating version {version_name!r} for model {model_id}, with base model {base_model!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/modelVersion.upsert', + json={ + "json": { + "modelId": model_id, + "name": version_name, + "baseModel": base_model, + "description": markdown2.markdown(textwrap.dedent(description_md)), + "steps": steps, + "epochs": epochs, + "clipSkip": clip_skip, + "vaeId": vae_id, + "trainedWords": trigger_words, + "earlyAccessTimeFrame": early_access_time, + "skipTrainedWords": bool(not trigger_words), + "authed": True, + } + }, + headers={'Referer': f'https://civitai.com/models/{model_id}/wizard?step=2'} + ) + + return resp.json()['result']['data']['json'] + + +def civitai_upload_file(local_file: str, type_: str = 'model', filename: str = None, + model_id: int = None, session=None): + session = session or get_civitai_session() + filename = filename or os.path.basename(local_file) + + logging.info(f'Creating uploading request for {filename!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/upload', + json={ + "filename": filename, + "type": type_, + "size": os.path.getsize(local_file), + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'} + ) + upload_data = resp.json() + + logging.info(f'Uploading file {local_file!r} as {filename!r} ...') + with open(local_file, 'rb') as f: + resp = srequest( + session, 'PUT', upload_data['urls'][0]['url'], data=f, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'}, + ) + etag = resp.headers['ETag'] + + logging.info(f'Completing uploading for {filename!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/upload/complete', + json={ + "bucket": 
upload_data['bucket'], + "key": upload_data['key'], + "type": type_, + "uploadId": upload_data['uploadId'], + "parts": [{"ETag": etag, "PartNumber": 1}], + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'}, + ) + resp.raise_for_status() + + return { + "url": str(URLObject(upload_data['urls'][0]['url']).without_query()), + "bucket": upload_data['bucket'], + "key": upload_data['key'], + "name": filename, + "uuid": str(uuid.uuid4()), + "sizeKB": os.path.getsize(local_file) / 1024.0, + } + + +def civitai_upload_models(model_version_id: int, model_files: List[Union[str, Tuple[str, str]]], + model_id: int = None, session=None): + session = session or get_civitai_session() + file_items = [] + for file_item in model_files: + if isinstance(file_item, str): + local_file, filename = file_item, file_item + elif isinstance(file_item, tuple): + local_file, filename = file_item + else: + raise TypeError(f'Unknown file type - {file_item!r}.') + file_items.append((local_file, filename)) + + for local_file, filename in file_items: + upload_data = civitai_upload_file(local_file, 'model', filename, model_id, session) + logging.info(f'Creating {filename!r} as model file of version {model_version_id} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/modelFile.create', + json={ + 'json': { + **upload_data, + "modelVersionId": model_version_id, + "type": "Model", + "metadata": { + "size": None, + "fp": None + }, + "authed": True + }, + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=3'}, + ) + resp.raise_for_status() + + +def civitai_get_model_info(model_id: int, session=None): + session = session or get_civitai_session() + resp = srequest( + session, 'GET', 'https://civitai.com/api/trpc/model.getById', + params={'input': json.dumps({"json": {"id": model_id, "authed": True}})}, + headers={'Referer': f'https://civitai.com/models/{model_id}/wizard?step=4'}, + ) + return 
resp.json()['result']['data']['json'] + + +def get_clamped_size(width, height, max_val, _type='all'): + if _type == 'all': + if width >= height: + _type = 'width' + elif height >= width: + _type = 'height' + + if _type == 'width' and width > max_val: + return max_val, int(round((height / width) * max_val)) + + if _type == 'height' and height > max_val: + return int(round((width / height) * max_val)), max_val + + return width, height + + +def parse_publish_at(publish_at: Optional[str] = None, keep_none: bool = True) -> Optional[str]: + try: + from zoneinfo import ZoneInfo + except (ImportError, ModuleNotFoundError): + from backports.zoneinfo import ZoneInfo + + if not keep_none and publish_at is None: + publish_at = 'now' + if publish_at is not None: + local_time = parse_time(publish_at) + publish_at = local_time.astimezone(ZoneInfo('UTC')).isoformat() + + return publish_at + + +def _post_create_func(model_version_id): + return { + "json": { + "modelVersionId": model_version_id, + "authed": True, + } + } + + +def civitai_upload_images( + model_version_id: int, image_files: List[Union[str, Tuple[str, str], Tuple[str, str, dict]]], + tags: List[str], nsfw: bool = False, model_id: int = None, pc_func=_post_create_func, session=None +): + session = session or get_civitai_session() + + image_items = [] + for image_item in image_files: + if isinstance(image_item, str): + local_file, filename, meta = image_item, image_item, {} + elif isinstance(image_item, tuple): + if len(image_item) == 2: + (local_file, filename), meta = image_item, {} + elif len(image_item) == 3: + local_file, filename, meta = image_item + else: + raise ValueError(f'Invalid image file format - {image_item!r}.') + else: + raise TypeError(f'Invalid image file type - {image_item!r}.') + image_items.append((local_file, filename, meta)) + + logging.info(f'Creating post for model version {model_version_id} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/post.create', + 
json=pc_func(model_version_id), + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + post_id = resp.json()['result']['data']['json']['id'] + + for index, (local_file, filename, meta) in enumerate(image_items): + logging.info(f'Creating image uploading request for image {filename!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/image-upload', + json={ + "filename": filename, + "metadata": {} + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + upload_id = resp.json()['id'] + upload_url = resp.json()['uploadURL'] + + logging.info(f'Uploading local image {local_file!r} as image {filename!r} ...') + with open(local_file, 'rb') as f: + resp = srequest(session, 'PUT', upload_url, data=f) + resp.raise_for_status() + + img = load_image(local_file, force_background='white', mode='RGB') + new_width, new_height = get_clamped_size(img.width, img.height, 32) + bhash = blurhash.encode(np.array(img.resize((new_width, new_height)))) + logging.info(f'Completing the uploading of {filename!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/post.addImage', + json={ + "json": { + "type": "image", + "index": index, + "uuid": str(uuid.uuid4()), + "name": filename, + "meta": meta, + "url": upload_id, + "mimeType": "image/png", + "hash": bhash, + "width": img.width, + "height": img.height, + "status": "uploading", + "message": None, + "postId": post_id, + "modelVersionId": model_version_id, + "authed": True + }, + "meta": { + "values": { + "message": [ + "undefined" + ] + } + } + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + resp.raise_for_status() + + for tag in tags: + tag_id, tag_name = civitai_query_model_tags(tag, session) + if tag_id is not None: + logging.info(f'Adding tag {tag_name!r}({tag_id}) for post {post_id!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/post.addTag', + json={ 
+ "json": { + "id": post_id, + "tagId": tag_id, + "name": tag_name, + "authed": True, + } + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + else: + logging.info(f'Creating and adding new tag {tag_name!r} for post {post_id!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/post.addTag', + json={ + "json": { + "id": post_id, + "tagId": None, + "name": tag_name, + "authed": True, + }, + "meta": { + "values": { + "tagId": ["undefined"] + } + } + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + + resp.raise_for_status() + + logging.info(f'Marking for nsfw ({nsfw!r}) ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/post.update', + json={ + 'json': { + 'id': post_id, + 'nsfw': nsfw, + 'authed': True, + } + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + resp.raise_for_status() + + return post_id + + +def civiti_publish(model_id: int, model_version_id: int, publish_at=None, session=None): + session = session or get_civitai_session() + publish_at = parse_publish_at(publish_at, keep_none=True) + + if publish_at: + logging.info(f'Publishing model {model_id!r}\'s version {model_version_id!r}, at {publish_at!r} ...') + else: + logging.info(f'Publishing model {model_id!r}\'s version {model_version_id!r} ...') + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/model.publish', + json={ + "json": { + "id": model_id, + "versionIds": [ + model_version_id + ], + "publishedAt": publish_at, + "authed": True + }, + "meta": { + "values": { + "publishedAt": [ + "undefined" if publish_at is None else "Date", + ] + } + } + }, + headers={'Referer': f'https://civitai.com/models/{model_id or 0}/wizard?step=4'}, + ) + resp.raise_for_status() + + +def try_find_title(char_name, game_name): + try: + game_cls = GAME_CHARS[game_name.lower()] + ch = game_cls.get(char_name) + if ch: + names = [] + if 
ch.enname: + names.append(str(ch.enname)) + if ch.jpname: + names.append(str(ch.jpname)) + if ch.cnname: + names.append(str(ch.cnname)) + if hasattr(ch, 'krname') and ch.krname: + names.append(str(ch.krname)) + + return f"{'/'.join(names)} ({game_cls.__official_name__})" + + else: + cname = ' '.join(list(map(str.capitalize, char_name.split(' ')))) + return f'{cname} ({game_cls.__official_name__})' + + except KeyError: + return None + + +def try_get_title_from_repo(repo): + hf_fs = get_hf_fs() + print(f'datasets/{repo}/meta.json') + if hf_fs.exists(f'datasets/{repo}/meta.json'): + data = json.loads(hf_fs.read_text(f'datasets/{repo}/meta.json')) + character_name = data['name'] + + source_name = repo.split('_')[-1] + if hf_fs.exists(f'datasets/BangumiBase/{source_name}/meta.json'): + base_data = json.loads(hf_fs.read_text(f'datasets/BangumiBase/{source_name}/meta.json')) + source_full_name = base_data['name'] + return f'{character_name} ({source_full_name})' + else: + return character_name + else: + return None + + +def _tag_decode(text): + return re.sub(r'[\s_]+', ' ', re.sub(r'\\([\\()])', r'\1', text)).strip() + + +def civitai_publish_from_hf(source, model_name: str = None, model_desc_md: str = None, + version_name: Optional[str] = None, version_desc_md: str = None, + step: Optional[int] = None, epoch: Optional[int] = None, upload_min_epoch: int = 6, + draft: bool = False, publish_at=None, allow_nsfw_images: bool = True, + force_create_model: bool = False, no_ccip_check: bool = False, session=None): + if isinstance(source, Character): + repo = f'AppleHarem/{get_ch_name(source)}' + elif isinstance(source, str): + repo = source + else: + raise TypeError(f'Unknown source type - {source!r}.') + hf_fs = get_hf_fs() + meta_json = json.loads(hf_fs.read_text(f'{repo}/meta.json')) + game_name = repo.split('_')[-1] + + dataset_info = meta_json.get('dataset') + ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type'] + with 
load_dataset_for_character(repo, size=ds_size) as (_, d): + if dataset_info and dataset_info['size']: + dataset_size = dataset_info['size'] + else: + dataset_size = len(glob.glob(os.path.join(d, '*.png'))) + core_tags, _ = load_tags_from_directory(d) + logging.info(f'Size of dataset if {dataset_size!r}.') + + ccip_feats = [] + for item in tqdm(list(LocalSource(d)[:10]), desc='Extracting features'): + ccip_feats.append(ccip_extract_feature(item.image)) + + version_name = version_name or meta_json.get('mark') or 'v1.0' + all_steps = meta_json['steps'] + logging.info(f'Available steps: {all_steps!r}.') + if step is not None: + if epoch is not None: + logging.warning(f'Step {step!r} is set, epoch value ({epoch}) will be ignored.') + else: + if epoch is not None: + step = dataset_size * epoch + else: + if 'best_step' in meta_json: + if upload_min_epoch is not None: + upload_min_step = upload_min_epoch * dataset_size + else: + upload_min_step = -1 + best_step, best_score = None, None + for score_item in meta_json["scores"]: + if best_step is None or \ + (score_item['step'] >= upload_min_step and score_item['score'] >= best_score): + best_step, best_score = score_item['step'], score_item['score'] + + if best_step is not None: + step = best_step + else: + step = meta_json['best_step'] + else: + step = max(all_steps) + + logging.info(f'Expected step is {step!r}.') + _, _actual_step = sorted([(abs(s - step), s) for s in all_steps])[0] + if _actual_step != step: + logging.info(f'Actual used step is {_actual_step!r}.') + + step = _actual_step + epoch = int(math.ceil(step / dataset_size)) + logging.info(f'Using step {step}, epoch {epoch}.') + + with TemporaryDirectory() as td: + models_dir = os.path.join(td, 'models') + os.makedirs(models_dir, exist_ok=True) + + lora_file = os.path.basename(hf_fs.glob(f'{repo}/{step}/*.safetensors')[0]) + pt_file = os.path.basename(hf_fs.glob(f'{repo}/{step}/*.pt')[0]) + trigger_word = os.path.splitext(lora_file)[0] + char_name = ' 
'.join(trigger_word.split('_')[:-1]) + + models = [] + local_lora_file = os.path.join(models_dir, lora_file) + download_file(hf_hub_url(repo, filename=f'{step}/{lora_file}'), local_lora_file) + models.append((local_lora_file, lora_file)) + local_pt_file = os.path.join(models_dir, pt_file) + download_file(hf_hub_url(repo, filename=f'{step}/{pt_file}'), local_pt_file) + models.append((local_pt_file, pt_file)) + + images_dir = os.path.join(td, 'images') + os.makedirs(images_dir, exist_ok=True) + + images = [] + tags_count = {} + tags_idx = {} + for img_file in hf_fs.glob(f'{repo}/{step}/previews/*.png'): + img_filename = os.path.basename(img_file) + img_name = os.path.splitext(img_filename)[0] + img_info_filename = f'{img_name}_info.txt' + + local_img_file = os.path.join(images_dir, img_filename) + download_file(hf_hub_url(repo, filename=f'{step}/previews/{img_filename}'), local_img_file) + local_info_file = os.path.join(images_dir, img_info_filename) + download_file(hf_hub_url(repo, filename=f'{step}/previews/{img_info_filename}'), local_info_file) + + info = {} + with open(local_info_file, 'r', encoding='utf-8') as iif: + for line in iif: + line = line.strip() + if line: + info_name, info_text = line.split(':', maxsplit=1) + info[info_name.strip()] = info_text.strip() + + meta = { + 'cfgScale': int(round(float(info.get('Guidance Scale')))), + 'negativePrompt': info.get('Neg Prompt'), + 'prompt': info.get('Prompt'), + 'sampler': info.get('Sample Method', "Euler a"), + 'seed': int(info.get('Seed')), + 'steps': int(info.get('Infer Steps')), + 'Size': f"{info['Width']}x{info['Height']}", + } + if info.get('Clip Skip'): + meta['clipSkip'] = int(info['Clip Skip']) + if info.get('Model'): + meta['Model'] = info['Model'] + pil_img_file = Image.open(local_img_file) + if pil_img_file.info.get('parameters'): + png_info_text = pil_img_file.info['parameters'] + find_hash = re.findall(r'Model hash:\s*([a-zA-Z\d]+)', png_info_text, re.IGNORECASE) + if find_hash: + model_hash = 
find_hash[0].lower() + meta['hashes'] = {"model": model_hash} + meta["resources"] = [ + { + "hash": model_hash, + "name": info['Model'], + "type": "model" + } + ] + meta["Model hash"] = model_hash + + nsfw = (info.get('Safe For Word', info.get('Safe For Work')) or '').lower() != 'yes' + if not nsfw: + cls_, score_ = nsfw_pred(local_img_file) + if cls_ not in {'hentai', 'porn', 'sexy'} and score_ >= 0.65: + pass + else: + nsfw = True + + if nsfw and not allow_nsfw_images: + logging.info(f'Image {local_img_file!r} skipped due to its nsfw.') + continue + + current_feat = ccip_extract_feature(local_img_file) + similarity = ccip_batch_same([current_feat, *ccip_feats])[0, 1:].mean() + logging.info(f'Similarity of character on image {local_img_file!r}: {similarity!r}') + if similarity < 0.6 and not no_ccip_check: + logging.info(f'Similarity of {local_img_file!r}({similarity!r}) is too low, skipped.') + continue + + if not nsfw or allow_nsfw_images: + rating_score = anime_rating_score(local_img_file) + safe_v = int(round(rating_score['safe'] * 10)) + safe_r15 = int(round(rating_score['r15'] * 10)) + safe_r18 = int(round(rating_score['r18'] * 10)) + faces = detect_faces(local_img_file) + if faces: + if len(faces) > 1: + logging.warning('Multiple face detected, skipped!') + continue + + (x0, y0, x1, y1), _, _ = faces[0] + width, height = load_image(local_img_file).size + face_area = abs((x1 - x0) * (y1 - y0)) + face_ratio = face_area * 1.0 / (width * height) + face_ratio = int(round(face_ratio * 50)) + else: + logging.warning('No face detected, skipped!') + continue + + images.append(( + (-safe_v, -safe_r15, -safe_r18) if False else 0, + -face_ratio, + 1 if nsfw else 0, + 0 if img_name.startswith('pattern_') else 1, + img_name, + (local_img_file, img_filename, meta) + )) + + for ptag in info.get('Prompt').split(','): + ptag = ptag.strip() + tags_count[ptag] = tags_count.get(ptag, 0) + 1 + if ptag not in tags_idx: + tags_idx[ptag] = len(tags_idx) + + images = [item[-1] for 
item in sorted(images)] + max_tag_cnt = max(tags_count.values()) + recommended_tags = sorted([ptag for ptag, cnt in tags_count.items() if cnt == max_tag_cnt], + key=lambda x: tags_idx[x]) + + # publish model + session = session or get_civitai_session(timeout=30) + + model_desc_default = f""" + * Thanks to Civitai's TOS, some images cannot be uploaded. **THE FULL PREVIEW IMAGES CAN BE FOUND ON [HUGGINGFACE](https://huggingface.co/{repo})**. + * **THIS MODEL HAS TWO FILES. YOU NEED TO USE THEM TOGETHER!!!** + * **The associated trigger words are only for reference, it may need to be adjusted at some times**. + * Recommended weight of pt file is 0.5-1.0, weight of LoRA is 0.5-0.85. + * Images were generated using a few fixed prompts and dataset-based clustered prompts. Random seeds were used, ruling out cherry-picking. **What you see here is what you can get.** + * No specialized training was done for outfits. You can check our provided preview post to get the prompts corresponding to the outfits. + * This model is trained with **{plural_word(dataset_size, "image")}**. + + ## How to Use This Model + + **THIS MODEL HAS TWO FILES. YOU NEED TO USE THEM TOGETHER!!!**. + In this case, you need to download both `{pt_file}` and + `{lora_file}`, then **use `{pt_file}` as texture inversion embedding, and use + `{lora_file}` as LoRA at the same time**. + + **このモデルには2つのファイルがあります。一緒に使う必要があります!!!**。 + この場合、`{pt_file}`と`{lora_file}`の両方をダウンロード + する必要があります。`{pt_file}`をテクスチャ反転埋め込みとして使用し、同時に`{lora_file}`をLoRAとして使用してください。 + + **这个模型有两个文件。你需要同时使用它们!!!**。 + 在这种情况下,您需要下载`{pt_file}`和`{lora_file}`这两个文件,然后将`{pt_file}`用作纹理反转嵌入, + 同时使用`{lora_file}`作为LoRA。 + + **이 모델은 두 개의 파일이 있습니다. 두 파일을 함께 사용해야 합니다!!!**. + 이 경우에는 `{pt_file}`와 `{lora_file}` 두 파일을 모두 다운로드하신 다음에 **`{pt_file}`을 텍스처 반전 임베딩으로 사용하고, + 동시에 `{lora_file}`을 LoRA로 사용하셔야 합니다**. + + (Translated with ChatGPT) + + The trigger word is `{trigger_word}`, and the recommended tags are `{', '.join(recommended_tags)}`. 
+ + ## How This Model Is Trained + + This model is trained with [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion). + And the auto-training framework is maintained by [DeepGHS Team](https://huggingface.co/deepghs). + And the WebUI Panel provid by [LittleAppleWebUI](https://github.com/LittleApple-fp16/LittleAppleWebUI) + + ## Why Some Preview Images Not Look Like {" ".join(map(str.capitalize, trigger_word.split("_")))} + + **All the prompt texts** used on the preview images (which can be viewed by clicking on the images) + **are automatically generated using clustering algorithms** based on feature information extracted from the + training dataset. The seed used during image generation is also randomly generated, and the images have + not undergone any selection or modification. As a result, there is a possibility of the mentioned + issues occurring. + + In practice, based on our internal testing, most models that experience such issues perform better in + actual usage than what is seen in the preview images. **The only thing you may need to do is adjusting + the tags you are using**. + + ## I Felt This Model May Be Overfitting or Underfitting, What Shall I Do + + Our model has been published on [huggingface repository - {repo}](https://huggingface.co/{repo}), where + models of all the steps are saved. Also, we published the training dataset on + [huggingface dataset - {repo}](https://huggingface.co/datasets/{repo}), which may be helpful to you. + + ## Why Not Just Using The Better-Selected Images + + Our model's entire process, from data crawling, training, to generating preview images and publishing, + is **100% automated without any human intervention**. It's an interesting experiment conducted by our team, + and for this purpose, we have developed a complete set of software infrastructure, including data filtering, + automatic training, and automated publishing. 
Therefore, if possible, we would appreciate more feedback or + suggestions as they are highly valuable to us. + + ## Why Can't the Desired Character Outfits Be Accurately Generated + + Our current training data is sourced from various image websites, and for a fully automated pipeline, + it's challenging to accurately predict which official images a character possesses. + Consequently, outfit generation relies on clustering based on labels from the training dataset + in an attempt to achieve the best possible recreation. We will continue to address this issue and attempt + optimization, but it remains a challenge that cannot be completely resolved. The accuracy of outfit + recreation is also unlikely to match the level achieved by manually trained models. + + In fact, this model's greatest strengths lie in recreating the inherent characteristics of the characters + themselves and its relatively strong generalization capabilities, owing to its larger dataset. + As such, **this model is well-suited for tasks such as changing outfits, posing characters, and, + of course, generating NSFW images of characters!**😉". + + For the following groups, it is not recommended to use this model and we express regret: + + 1. Individuals who cannot tolerate any deviations from the original character design, even in the slightest detail. + 2. Individuals who are facing the application scenarios with high demands for accuracy in recreating character outfits. + 3. Individuals who cannot accept the potential randomness in AI-generated images based on the Stable Diffusion algorithm. + 4. Individuals who are not comfortable with the fully automated process of training character models using LoRA, or those who believe that training character models must be done purely through manual operations to avoid disrespecting the characters. + 5. Individuals who finds the generated image content offensive to their values. 
+ """ + model_name = model_name or try_find_title(char_name, game_name) or \ + try_get_title_from_repo(repo) or trigger_word.replace('_', ' ') + if not force_create_model: + try: + exist_model = civitai_find_online(model_name, creator='narugo1992') + except ModelNotFound: + model_id = None + else: + logging.info(f'Existing model {exist_model.model_name}({exist_model.model_id}) found.') + model_id = exist_model.model_id + else: + model_id = None + + model_id, _ = civitai_upsert_model( + name=model_name, + description_md=model_desc_md or model_desc_default, + tags=[ + game_name, f"{game_name} {char_name}", char_name, + 'female', 'girl', 'character', 'fully-automated', + *map(_tag_decode, core_tags.keys()), + ], + exist_model_id=model_id, + session=session, + ) + + version_data = civitai_create_version( + model_id=model_id, + version_name=version_name, + description_md=version_desc_md or '', + trigger_words=[ + trigger_word, + repr_tags([key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])]), + ], + session=session, + steps=step, + epochs=epoch, + ) + version_id = version_data['id'] + + civitai_upload_models( + model_version_id=version_id, + model_files=models, + model_id=model_id, + session=session, + ) + civitai_upload_images( + model_version_id=version_id, + image_files=images, + tags=[ + game_name, f"{game_name} {char_name}", char_name, + 'female', 'girl', 'character', 'fully-automated', 'random prompt', 'random seed', + *map(_tag_decode, core_tags.keys()), + ], + model_id=model_id, + session=session, + ) + + if draft: + logging.info(f'Draft of model {model_id!r} created.') + else: + civiti_publish(model_id, version_id, publish_at, session) + return civitai_get_model_info(model_id, session)['id'] + + +def get_draft_models(session=None): + session = session or get_civitai_session() + resp = srequest( + session, 'GET', 'https://civitai.com/api/trpc/model.getMyDraftModels', + params={ + 'input': json.dumps({"json": {"page": 1, "limit": 200, "authed": 
True}}), + }, + headers={'Referer': f'https://civitai.com/user'}, + ) + return resp.json()['result']['data']['json']['items'] + + +def delete_model(model_id: int, session=None): + session = session or get_civitai_session() + resp = srequest( + session, 'POST', 'https://civitai.com/api/trpc/model.delete', + json={"json": {"id": model_id, "permanently": False, "authed": True}}, + headers={'Referer': f'https://civitai.com/models/{model_id}'}, + ) + resp.raise_for_status() diff --git a/cyberharem/publish/convert.py b/cyberharem/publish/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..32a0cb928532cad136a4e32b1ab2e6d82dd34c44 --- /dev/null +++ b/cyberharem/publish/convert.py @@ -0,0 +1,19 @@ +import logging + +from hcpdiff.ckpt_manager import auto_manager +from hcpdiff.tools.lora_convert import LoraConverter + + +def convert_to_webui_lora(lora_path, lora_path_TE, dump_path, auto_scale_alpha: bool = True): + converter = LoraConverter() + + # load lora model + logging.info(f'Converting lora model {lora_path!r} and {lora_path_TE!r} to {dump_path!r} ...') + ckpt_manager = auto_manager(lora_path)() + + sd_unet = ckpt_manager.load_ckpt(lora_path) + sd_TE = ckpt_manager.load_ckpt(lora_path_TE) + state = converter.convert_to_webui(sd_unet['lora'], sd_TE['lora']) + if auto_scale_alpha: + state = {k: (v * v.shape[1] if 'lora_up' in k else v) for k, v in state.items()} + ckpt_manager._save_ckpt(state, save_path=dump_path) diff --git a/cyberharem/publish/cyberharem_publish_huggingface.py b/cyberharem/publish/cyberharem_publish_huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..58cb738c672503ebcf15049d621b78b8f8d1c182 --- /dev/null +++ b/cyberharem/publish/cyberharem_publish_huggingface.py @@ -0,0 +1,120 @@ +import datetime +import os +import pathlib +import pytz +from typing import Optional + +from ditk import logging +from hbutils.system import TemporaryDirectory +from huggingface_hub import CommitOperationAdd, 
CommitOperationDelete +from huggingface_hub.utils import RepositoryNotFoundError + +from .export import export_workdir, _GITLFS +from .steps import find_steps_in_workdir +from ..infer.draw import _DEFAULT_INFER_MODEL +from ..utils import get_hf_client, get_hf_fs + + +def deploy_to_huggingface(workdir: str, repository=None, revision: str = 'main', n_repeats: int = 3, + pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2, + image_width: int = 512, image_height: int = 768, infer_steps: int = 30, + lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras', + model_hash: Optional[str] = None, ds_dir: str = None): + name, _ = find_steps_in_workdir(workdir) + repository = repository or f'AppleHarem/{name}' + + logging.info(f'Initializing repository {repository!r} ...') + hf_client = get_hf_client() + hf_fs = get_hf_fs() + if not hf_fs.exists(f'{repository}/.gitattributes'): + hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True) + + if not hf_fs.exists(f'{repository}/.gitattributes') or \ + '*.png filter=lfs diff=lfs merge=lfs -text' not in hf_fs.read_text(f'{repository}/.gitattributes'): + logging.info(f'Preparing for lfs attributes of repository {repository!r}.') + with TemporaryDirectory() as td: + _git_attr_file = os.path.join(td, '.gitattributes') + with open(_git_attr_file, 'w', encoding='utf-8') as f: + print(_GITLFS, file=f) + + operations = [ + CommitOperationAdd( + path_in_repo='.gitattributes', + path_or_fileobj=_git_attr_file, + ) + ] + tokyo_tz = pytz.timezone('Asia/Tokyo') + current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z') + commit_message = f'Update {name}\'s .gitattributes, on {current_time}' + logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...') + hf_client.create_commit( + repository, + operations, + commit_message=commit_message, + repo_type='model', + revision=revision, + ) + + with TemporaryDirectory() as td: + export_workdir( + 
workdir, td, n_repeats, pretrained_model, + clip_skip, image_width, image_height, infer_steps, + lora_alpha, sample_method, model_hash, ds_repo=ds_dir, # ds_repo: 本地数据集或远端数据集 + ) + + try: + hf_client.repo_info(repo_id=repository, repo_type='dataset') + except RepositoryNotFoundError: + has_dataset_repo = False + else: + has_dataset_repo = True + + readme_text = pathlib.Path(os.path.join(td, 'README.md')).read_text(encoding='utf-8') + with open(os.path.join(td, 'README.md'), 'w', encoding='utf-8') as f: + print('---', file=f) + print('license: mit', file=f) + if has_dataset_repo: + print('datasets:', file=f) + print(f'- {repository}', file=f) + print('pipeline_tag: text-to-image', file=f) + print('tags:', file=f) + print('- art', file=f) + print('---', file=f) + print('', file=f) + print(readme_text, file=f) + + _exist_files = [os.path.relpath(file, repository) for file in hf_fs.glob(f'{repository}/**')] + _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1]) + pre_exist_files = set() + for i, (file, segments) in enumerate(_exist_ps): + if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]: + continue + if file != '.': + pre_exist_files.add(file) + + operations = [] + for directory, _, files in os.walk(td): + for file in files: + filename = os.path.abspath(os.path.join(td, directory, file)) + file_in_repo = os.path.relpath(filename, td) + operations.append(CommitOperationAdd( + path_in_repo=file_in_repo, + path_or_fileobj=filename, + )) + if file_in_repo in pre_exist_files: + pre_exist_files.remove(file_in_repo) + logging.info(f'Useless files: {sorted(pre_exist_files)} ...') + for file in sorted(pre_exist_files): + operations.append(CommitOperationDelete(path_in_repo=file)) + + tokyo_tz = pytz.timezone('Asia/Tokyo') + current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z') + commit_message = f'Publish {name}\'s lora, on {current_time}' + logging.info(f'Publishing 
{name}\'s lora to repository {repository!r} ...') + hf_client.create_commit( + repository, + operations, + commit_message=commit_message, + repo_type='model', + revision=revision, + ) diff --git a/cyberharem/publish/export.py b/cyberharem/publish/export.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c0dd9c37b59abc5effe150c893c465c164fc44 --- /dev/null +++ b/cyberharem/publish/export.py @@ -0,0 +1,284 @@ +import json +import logging +import os.path +import shutil +import time +import zipfile +from textwrap import dedent +from typing import Optional + +import numpy as np +import pandas as pd +from imgutils.metrics import ccip_extract_feature, ccip_batch_same +from tqdm.auto import tqdm +from waifuc.source import LocalSource + +try: + import torch +except (ImportError, ModuleNotFoundError): + torch = None + +from .convert import convert_to_webui_lora +from .steps import find_steps_in_workdir +from ..dataset import load_dataset_for_character +from ..dataset.tags import sort_draw_names +from ..infer.draw import _DEFAULT_INFER_MODEL +from ..infer.draw import draw_with_workdir +from ..utils import repr_tags, load_tags_from_directory + +KNOWN_MODEL_HASHES = { + 'AIARTCHAN/anidosmixV2': 'EB49192009', + 'stablediffusionapi/anything-v5': None, + 'stablediffusionapi/cetusmix': 'B42B09FF12', + 'Meina/MeinaMix_V10': 'D967BCAE4A', + 'Meina/MeinaMix_V11': '54EF3E3610', + 'Lykon/DreamShaper': 'C33104F6', + 'digiplay/majicMIX_realistic_v6': 'EBDB94D4', + 'stablediffusionapi/abyssorangemix2nsfw': 'D6992792', + 'AIARTCHAN/expmixLine_v2': 'D91B18D1', + 'Yntec/CuteYuki2': 'FBE372BA', + 'stablediffusionapi/counterfeit-v30': '12047227', + 'jzli/XXMix_9realistic-v4': '5D22F204', + 'stablediffusionapi/flat-2d-animerge': 'F279CF76', + 'redstonehero/cetusmix_v4': '838408E0', + 'Meina/Unreal_V4.1': '0503BFAD', + 'Meina/MeinaHentai_V4': '39C0C3B6', + 'Meina/MeinaPastel_V6': 'DA1D535E', + 'KBlueLeaf/kohaku-v4-rev1.2': '87F9E45D', + 
def export_workdir(workdir: str, export_dir: str, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None, ds_repo: Optional[str] = None):
    """
    Export every trained step found in ``workdir`` into ``export_dir``.

    For each step, preview images are drawn and saved, the raw checkpoint files are copied,
    a WebUI LoRA file and a zip bundle are produced, and the previews are scored against the
    dataset images via CCIP. Afterwards ``meta.json``, ``.gitattributes`` and ``README.md``
    are written to ``export_dir``.

    :param workdir: Training work directory, containing ``ckpts`` and optionally ``meta.json``.
    :type workdir: str
    :param export_dir: Directory that receives the exported files.
    :type export_dir: str
    :param n_repeats: Passed to ``draw_with_workdir``; incremented whenever drawing raises
        ``RuntimeError``, and the incremented value is kept for the following steps.
    :type n_repeats: int
    :param pretrained_model: Base model used for drawing the preview images.
    :type pretrained_model: str
    :param clip_skip: Passed to ``draw_with_workdir``.
    :type clip_skip: int
    :param image_width: Preview image width.
    :type image_width: int
    :param image_height: Preview image height.
    :type image_height: int
    :param infer_steps: Passed to ``draw_with_workdir``.
    :type infer_steps: int
    :param lora_alpha: Passed to ``draw_with_workdir``.
    :type lora_alpha: float
    :param sample_method: Passed to ``draw_with_workdir``.
    :type sample_method: str
    :param model_hash: Hash of the base model; looked up in ``KNOWN_MODEL_HASHES`` when not given.
    :type model_hash: Optional[str]
    :param ds_repo: Dataset source handed to ``load_dataset_for_character``;
        defaults to ``AppleHarem/<name>``.
    :type ds_repo: Optional[str]
    :raises ValueError: If no step is found in ``workdir``.
    """
    name, steps = find_steps_in_workdir(workdir)
    logging.info(f'Starting export trained artifacts of {name!r}, with steps: {steps!r}')
    if not steps:
        # Without this check the failure only surfaced at np.argmax() below, after the
        # dataset had already been loaded.
        raise ValueError(f'No exportable steps found in {workdir!r}.')

    model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model, None)
    if model_hash:
        logging.info(f'Model hash {model_hash!r} detected for model {pretrained_model!r}.')

    meta_file = os.path.join(workdir, 'meta.json')
    if os.path.exists(meta_file):
        with open(meta_file, 'r', encoding='utf-8') as f:
            dataset_info = json.load(f)['dataset']
    else:
        dataset_info = None

    ds_repo = ds_repo or f'AppleHarem/{name}'
    ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type']
    logging.info(f'Loading dataset {ds_repo!r}, {ds_size!r} ...')
    with load_dataset_for_character(ds_repo, ds_size) as (ch, ds_dir):
        core_tags, _ = load_tags_from_directory(ds_dir)
        ds_source = LocalSource(ds_dir)
        ds_feats = []
        for item in tqdm(list(ds_source), desc='Extract Dataset Feature'):
            ds_feats.append(ccip_extract_feature(item.image))

    d_names = set()
    all_drawings = {}
    nsfw_count = {}
    all_scores = {}
    all_scores_lst = []
    for step in steps:
        logging.info(f'Exporting for {name}-{step} ...')
        step_dir = os.path.join(export_dir, f'{step}')
        os.makedirs(step_dir, exist_ok=True)

        preview_dir = os.path.join(step_dir, 'previews')
        os.makedirs(preview_dir, exist_ok=True)

        # NOTE(review): this retry loop has no upper bound; a RuntimeError that does not
        # depend on n_repeats would loop forever. Left unbounded to keep existing behavior.
        while True:
            try:
                drawings = draw_with_workdir(
                    workdir, model_steps=step, n_repeats=n_repeats,
                    pretrained_model=pretrained_model,
                    width=image_width, height=image_height, infer_steps=infer_steps,
                    lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
                    model_hash=model_hash,
                )
            except RuntimeError as err:
                n_repeats += 1
                logging.warning(f'Drawing for step {step} failed ({err!r}), '
                                f'retrying with n_repeats={n_repeats} ...')
            else:
                break

        all_image_files = []
        image_feats = []
        for draw in drawings:
            img_file = os.path.join(preview_dir, f'{draw.name}.png')
            image_feats.append(ccip_extract_feature(draw.image))
            draw.image.save(img_file, pnginfo=draw.pnginfo)
            all_image_files.append(img_file)

            with open(os.path.join(preview_dir, f'{draw.name}_info.txt'), 'w', encoding='utf-8') as f:
                print(draw.preview_info, file=f)
            d_names.add(draw.name)
            all_drawings[(draw.name, step)] = draw
            if not draw.sfw:
                nsfw_count[draw.name] = nsfw_count.get(draw.name, 0) + 1

        pt_file = os.path.join(workdir, 'ckpts', f'{name}-{step}.pt')
        unet_file = os.path.join(workdir, 'ckpts', f'unet-{step}.safetensors')
        text_encoder_file = os.path.join(workdir, 'ckpts', f'text_encoder-{step}.safetensors')
        raw_dir = os.path.join(step_dir, 'raw')
        os.makedirs(raw_dir, exist_ok=True)
        for raw_file in (pt_file, unet_file, text_encoder_file):
            shutil.copyfile(raw_file, os.path.join(raw_dir, os.path.basename(raw_file)))

        shutil.copyfile(pt_file, os.path.join(step_dir, f'{name}.pt'))
        convert_to_webui_lora(unet_file, text_encoder_file, os.path.join(step_dir, f'{name}.safetensors'))
        with zipfile.ZipFile(os.path.join(step_dir, f'{name}.zip'), 'w') as zf:
            zf.write(os.path.join(step_dir, f'{name}.pt'), f'{name}.pt')
            zf.write(os.path.join(step_dir, f'{name}.safetensors'), f'{name}.safetensors')
            for img_file in all_image_files:
                zf.write(img_file, os.path.basename(img_file))

        # Score = mean of the preview-vs-dataset block of the pairwise CCIP "same" matrix.
        same_matrix = ccip_batch_same([*image_feats, *ds_feats])
        score = same_matrix[:len(image_feats), len(image_feats):].mean()
        all_scores[step] = score
        all_scores_lst.append(score)
        logging.info(f'Score of step {step} is {score}.')

    lst_scores = np.array(all_scores_lst)
    lst_steps = np.array(steps)
    if dataset_info and 'size' in dataset_info:
        # Prefer steps of at least 6x the dataset size, unless that leaves no candidate.
        min_best_steps = 6 * dataset_info['size']
        _lst_scores = lst_scores[lst_steps >= min_best_steps]
        _lst_steps = lst_steps[lst_steps >= min_best_steps]
        if _lst_scores.shape[0] > 0:
            lst_steps, lst_scores = _lst_steps, _lst_scores

    best_idx = np.argmax(lst_scores)
    best_step = lst_steps[best_idx].item()
    # Fraction of exported steps in which each drawing was flagged as not SFW.
    nsfw_ratio = {d_name: count * 1.0 / len(steps) for d_name, count in nsfw_count.items()}
    with open(os.path.join(export_dir, 'meta.json'), 'w', encoding='utf-8') as f:
        json.dump({
            'name': name,
            'steps': steps,
            'mark': EXPORT_MARK,
            'time': time.time(),
            'dataset': dataset_info,
            'scores': [
                {
                    'step': step,
                    'score': score,
                } for step, score in sorted(all_scores.items())
            ],
            'best_step': best_step,
        }, f, ensure_ascii=False, indent=4)
    with open(os.path.join(export_dir, '.gitattributes'), 'w', encoding='utf-8') as f:
        print(_GITLFS, file=f)
    with open(os.path.join(export_dir, 'README.md'), 'w', encoding='utf-8') as f:
        print(f'# Lora of {name}', file=f)
        print('', file=f)

        print('This model is trained with [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion). '
              'And the auto-training framework is maintained by '
              '[DeepGHS Team](https://huggingface.co/deepghs). '
              'And the WebUI Panel is provided by '
              '[LittleAppleWebUI](https://github.com/LittleApple-fp16/LittleAppleWebUI)', file=f)
        print('', file=f)

        print('The base model used during training is [NAI](https://huggingface.co/deepghs/animefull-latest), '
              f'and the base model used for generating preview images is '
              f'[{pretrained_model}](https://huggingface.co/{pretrained_model}).', file=f)
        print('', file=f)

        print(f'After downloading the pt and safetensors files for the specified step, '
              f'you need to use them simultaneously. The pt file will be used as an embedding, '
              f'while the safetensors file will be loaded for Lora.', file=f)
        print('', file=f)
        print(f'For example, if you want to use the model from step {best_step}, '
              f'you need to download `{best_step}/{name}.pt` as the embedding and '
              f'`{best_step}/{name}.safetensors` for loading Lora. '
              f'By using both files together, you can generate images for the desired characters.', file=f)
        print('', file=f)

        print(dedent(f"""
**The best step we recommend is {best_step}**, with the score of {all_scores[best_step]:.3f}. The trigger words are:
1. `{name}`
2. `{repr_tags([key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])])}`
        """).strip(), file=f)
        print('', file=f)

        print(dedent("""
For the following groups, it is not recommended to use this model and we express regret:
1. Individuals who cannot tolerate any deviations from the original character design, even in the slightest detail.
2. Individuals who are facing the application scenarios with high demands for accuracy in recreating character outfits.
3. Individuals who cannot accept the potential randomness in AI-generated images based on the Stable Diffusion algorithm.
4. Individuals who are not comfortable with the fully automated process of training character models using LoRA, or those who believe that training character models must be done purely through manual operations to avoid disrespecting the characters.
5. Individuals who finds the generated image content offensive to their values.
        """).strip(), file=f)
        print('', file=f)

        print(f'These are available steps:', file=f)
        print('', file=f)

        d_names = sort_draw_names(list(d_names))
        columns = ['Steps', 'Score', 'Download', *d_names]
        t_data = []

        # One table row per step, newest step first; the best step is rendered in bold.
        for step in steps[::-1]:
            d_mds = []
            for dname in d_names:
                file = os.path.join(str(step), 'previews', f'{dname}.png')
                if (dname, step) in all_drawings:
                    if nsfw_ratio.get(dname, 0.0) < 0.35:
                        d_mds.append(f'![{dname}-{step}]({file})')
                    else:
                        # Link instead of embedding; the link needs visible text,
                        # otherwise it renders as nothing.
                        d_mds.append(f'[NSFW, click to see]({file})')
                else:
                    d_mds.append('')

            t_data.append((
                str(step) if step != best_step else f'**{step}**',
                f'{all_scores[step]:.3f}' if step != best_step else f'**{all_scores[step]:.3f}**',
                f'[Download]({step}/{name}.zip)' if step != best_step else f'[**Download**]({step}/{name}.zip)',
                *d_mds,
            ))

        df = pd.DataFrame(columns=columns, data=t_data)
        print(df.to_markdown(index=False), file=f)
        print('', file=f)
def deploy_to_huggingface(workdir: str, repository=None, revision: str = 'main', n_repeats: int = 3,
                          pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                          image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                          lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                          model_hash: Optional[str] = None, ds_dir: str = None):
    """
    Export a training work directory and publish the result to a Hugging Face model repository.

    The repository is created when its ``.gitattributes`` cannot be found, the LFS attributes
    are committed when missing or lacking the ``*.png`` rule, and then the output of
    :func:`export_workdir` is committed in a single commit that also deletes remote files
    which are not part of the new export.

    :param workdir: Training work directory.
    :type workdir: str
    :param repository: Target model repository; defaults to ``AppleHarem/<name>``.
    :param revision: Revision the commits are created on.
    :type revision: str
    :param n_repeats: Forwarded to :func:`export_workdir`.
    :type n_repeats: int
    :param pretrained_model: Forwarded to :func:`export_workdir`.
    :type pretrained_model: str
    :param clip_skip: Forwarded to :func:`export_workdir`.
    :type clip_skip: int
    :param image_width: Forwarded to :func:`export_workdir`.
    :type image_width: int
    :param image_height: Forwarded to :func:`export_workdir`.
    :type image_height: int
    :param infer_steps: Forwarded to :func:`export_workdir`.
    :type infer_steps: int
    :param lora_alpha: Forwarded to :func:`export_workdir`.
    :type lora_alpha: float
    :param sample_method: Forwarded to :func:`export_workdir`.
    :type sample_method: str
    :param model_hash: Forwarded to :func:`export_workdir`.
    :type model_hash: Optional[str]
    :param ds_dir: Forwarded to :func:`export_workdir` as its ``ds_repo`` argument.
    :type ds_dir: str
    """
    name, _ = find_steps_in_workdir(workdir)
    repository = repository or f'AppleHarem/{name}'

    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    if not hf_fs.exists(f'{repository}/.gitattributes'):
        hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)

    # (Re)write .gitattributes when it is missing or does not yet track PNG files with LFS.
    if not hf_fs.exists(f'{repository}/.gitattributes') or \
            '*.png filter=lfs diff=lfs merge=lfs -text' not in hf_fs.read_text(f'{repository}/.gitattributes'):
        logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
        with TemporaryDirectory() as td:
            _git_attr_file = os.path.join(td, '.gitattributes')
            with open(_git_attr_file, 'w', encoding='utf-8') as f:
                print(_GITLFS, file=f)

            operations = [
                CommitOperationAdd(
                    path_in_repo='.gitattributes',
                    path_or_fileobj=_git_attr_file,
                )
            ]
            tokyo_tz = pytz.timezone('Asia/Tokyo')
            current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')
            commit_message = f'Update {name}\'s .gitattributes, on {current_time}'
            logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
            hf_client.create_commit(
                repository,
                operations,
                commit_message=commit_message,
                repo_type='model',
                revision=revision,
            )

    with TemporaryDirectory() as td:
        export_workdir(
            workdir, td, n_repeats, pretrained_model,
            clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash, ds_repo=ds_dir,  # ds_repo: local dataset or remote dataset
        )

        # Check whether a dataset repository with the same id exists; if so, the model card links to it.
        try:
            hf_client.repo_info(repo_id=repository, repo_type='dataset')
        except RepositoryNotFoundError:
            has_dataset_repo = False
        else:
            has_dataset_repo = True

        # Prepend the YAML metadata header to the README produced by export_workdir.
        readme_text = pathlib.Path(os.path.join(td, 'README.md')).read_text(encoding='utf-8')
        with open(os.path.join(td, 'README.md'), 'w', encoding='utf-8') as f:
            print('---', file=f)
            print('license: mit', file=f)
            if has_dataset_repo:
                print('datasets:', file=f)
                print(f'- {repository}', file=f)
            print('pipeline_tag: text-to-image', file=f)
            print('tags:', file=f)
            print('- art', file=f)
            print('---', file=f)
            print('', file=f)
            print(readme_text, file=f)

        # Collect the paths currently in the repository. After sorting by path segments,
        # an entry whose segments are a prefix of the next entry's segments is skipped,
        # which drops directories and keeps leaf paths only.
        _exist_files = [os.path.relpath(file, repository) for file in hf_fs.glob(f'{repository}/**')]
        _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
        pre_exist_files = set()
        for i, (file, segments) in enumerate(_exist_ps):
            if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
                continue
            if file != '.':
                pre_exist_files.add(file)

        # Add every exported file; whatever remains in pre_exist_files afterwards exists
        # only remotely and is deleted in the same commit.
        operations = []
        for directory, _, files in os.walk(td):
            for file in files:
                filename = os.path.abspath(os.path.join(td, directory, file))
                file_in_repo = os.path.relpath(filename, td)
                operations.append(CommitOperationAdd(
                    path_in_repo=file_in_repo,
                    path_or_fileobj=filename,
                ))
                if file_in_repo in pre_exist_files:
                    pre_exist_files.remove(file_in_repo)
        logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
        for file in sorted(pre_exist_files):
            operations.append(CommitOperationDelete(path_in_repo=file))

        tokyo_tz = pytz.timezone('Asia/Tokyo')
        current_time = datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')
        commit_message = f'Publish {name}\'s lora, on {current_time}'
        logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )
def find_steps_in_workdir(workdir: str) -> Tuple[str, List[int]]:
    """
    Detect the model name and the fully available steps in ``<workdir>/ckpts``.

    A step counts as available only when the ``<name>-<step>.pt``, ``unet-<step>.safetensors``
    and ``text_encoder-<step>.safetensors`` files all exist.

    :param workdir: Work directory containing the ``ckpts`` sub-directory.
    :type workdir: str
    :returns: The name shared by all ``.pt`` files (``None`` when there is none) and the
        ascending list of steps present in all three file groups.
    :rtype: Tuple[str, List[int]]
    :raises NameError: If the ``.pt`` files do not all share the same name.
    """
    ckpts_dir = os.path.join(workdir, 'ckpts')

    def _stem_parts(path: str) -> List[str]:
        # File name without extension, split on dashes; the last part is the step.
        stem, _ext = os.path.splitext(os.path.basename(path))
        return stem.split('-')

    def _steps_matching(pattern: str) -> set:
        return {
            int(_stem_parts(path)[-1])
            for path in glob.glob(os.path.join(ckpts_dir, pattern))
        }

    pt_name = None
    pt_steps = set()
    for pt_path in glob.glob(os.path.join(ckpts_dir, '*-*.pt')):
        *name_parts, step_text = _stem_parts(pt_path)
        current_name = '-'.join(name_parts)
        if pt_name is None:
            pt_name = current_name
        elif pt_name != current_name:
            raise NameError(f'Name not match, {pt_name!r} vs {current_name!r}.')
        pt_steps.add(int(step_text))

    unet_steps = _steps_matching('unet-*.safetensors')
    text_encoder_steps = _steps_matching('text_encoder-*.safetensors')

    return pt_name, sorted(pt_steps & unet_steps & text_encoder_steps)
@cli.command('download', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Download trained ckpts from huggingface.')
@click.option('-r', '--repository', 'repository', type=str, required=True,
              help='Repository.', show_default=True)
@click.option('-w', '--workdir', 'workdir', type=str, default=None,
              help='Work directory', show_default=True)
@click.option('--no-tags', 'no_tags', is_flag=True, type=bool, default=False,
              help='Do not generate tags.', show_default=True)
def download(repository, workdir, no_tags):
    """
    Download the raw checkpoints of a published repository into a local work directory.

    Every file matching ``<repository>/*/raw/*`` is stored flat in ``<workdir>/ckpts``.
    Unless ``no_tags`` is set, the recommended tags are regenerated afterwards.

    :param repository: Repository to download from.
    :param workdir: Local work directory; defaults to ``runs/<last segment of repository>``.
    :param no_tags: Skip regenerating the tags.
    :raises click.ClickException: If tags are requested but no ``.pt`` checkpoint is found
        in the work directory after downloading.
    """
    logging.try_init_root(logging.INFO)
    workdir = workdir or os.path.join('runs', repository.split('/')[-1])

    logging.info(f'Downloading models for {workdir!r} ...')
    hf_fs = get_hf_fs()
    ckpts_dir = os.path.join(workdir, 'ckpts')
    for f in tqdm(hf_fs.glob(f'{repository}/*/raw/*')):
        rel_file = os.path.relpath(f, repository)
        local_file = os.path.join(ckpts_dir, os.path.basename(rel_file))
        os.makedirs(ckpts_dir, exist_ok=True)
        download_file(
            hf_hub_url(repository, filename=rel_file),
            local_file
        )

    if not no_tags:
        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        if pt_name is None:
            # find_steps_in_workdir yields None when no '*-*.pt' file exists; splitting it
            # below used to fail with an unhelpful AttributeError.
            raise click.ClickException(
                f'No checkpoints found in {workdir!r} after downloading from {repository!r}.')

        # The checkpoint name is '<character name>_<game name>'; split at the last underscore.
        name, _, game_name = pt_name.rpartition('_')

        from gchar.games.dispatch.access import GAME_CHARS
        if game_name in GAME_CHARS:
            ch_cls = GAME_CHARS[game_name]
            ch = ch_cls.get(name)
        else:
            ch = None

        # Fall back to the repository itself as the tag source when no character is resolved.
        source = repository if ch is None else ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir)
        logging.info('Success!')
_DEFAULT_EMBEDDING_DIR = 'embs'
_DEFAULT_TRAIN_MODEL = 'deepghs/animefull-latest'


def create_embedding(
        name: str, n_words: int = 4, init_text: str = '*0.017', replace: bool = False,
        pretrained_model: str = _DEFAULT_TRAIN_MODEL, embs_dir: str = _DEFAULT_EMBEDDING_DIR
):
    """
    Create a word embedding file through hcpdiff's ``PTCreator``.

    :param name: Name of the embedding (trigger word) to create.
    :type name: str
    :param n_words: Forwarded to ``PTCreator.creat_word_pt``.
    :type n_words: int
    :param init_text: Forwarded to ``PTCreator.creat_word_pt``.
    :type init_text: str
    :param replace: Forwarded to ``PTCreator.creat_word_pt``.
    :type replace: bool
    :param pretrained_model: Model the ``PTCreator`` is built from.
    :type pretrained_model: str
    :param embs_dir: Embedding directory handed to ``PTCreator``. Defaults to
        ``_DEFAULT_EMBEDDING_DIR``; it previously defaulted to the training model id by mistake.
    :type embs_dir: str
    """
    pt_creator = PTCreator(pretrained_model, embs_dir)
    # 'creat_word_pt' is the method name as used by this file's original code.
    pt_creator.creat_word_pt(name, n_words, init_text, replace)
import Optional, Union + +from gchar.games.base import Character +from hbutils.string import plural_word +from hbutils.system import TemporaryDirectory +from hcpdiff.train_ac import Trainer +from hcpdiff.train_ac_single import TrainerSingleCard +from hcpdiff.utils import load_config_with_cli + +from .embedding import create_embedding, _DEFAULT_TRAIN_MODEL +from ..dataset import load_dataset_for_character, save_recommended_tags +from ..utils import data_to_cli_args, get_ch_name + +_DEFAULT_TRAIN_CFG = 'cfgs/train/examples/lora_anime_character.yaml' + + +def _min_training_steps(dataset_size: int, unit: int = 20): + steps = 4000.9 + (720.9319 - 4000.9) / (1 + (dataset_size / 297.2281) ** 0.6543184) + return int(round(steps / unit)) * unit + + +def train_plora( + source: Union[str, Character], name: Optional[str] = None, + epochs: int = 13, min_steps: Optional[int] = None, + save_for_times: int = 15, no_min_steps: bool = False, + batch_size: int = 4, pretrained_model: str = _DEFAULT_TRAIN_MODEL, + workdir: str = None, emb_n_words: int = 4, emb_init_text: str = '*[0.017, 1]', + unet_rank: float = 8, text_encoder_rank: float = 4, + cfg_file: str = _DEFAULT_TRAIN_CFG, single_card: bool = True, + dataset_type: str = 'stage3-1200', use_ratio: bool = True, +): + with load_dataset_for_character(source, dataset_type) as (ch, ds_dir): + if ch is None: + if name is None: + raise ValueError(f'Name should be specified when using custom source - {source!r}.') + else: + name = name or get_ch_name(ch) + + dataset_size = len(glob.glob(os.path.join(ds_dir, '*.png'))) + logging.info(f'{plural_word(dataset_size, "image")} found in dataset.') + + actual_steps = epochs * dataset_size + if not no_min_steps: + actual_steps = max(actual_steps, _min_training_steps(dataset_size, 20)) + if min_steps is not None: + actual_steps = max(actual_steps, min_steps) + save_per_steps = max(int(math.ceil(actual_steps / save_for_times / 20) * 20), 20) + steps = int(math.ceil(actual_steps / save_per_steps) * 
def train_plora(
        source: Union[str, Character], name: Optional[str] = None,
        epochs: int = 13, min_steps: Optional[int] = None,
        save_for_times: int = 15, no_min_steps: bool = False,
        batch_size: int = 4, pretrained_model: str = _DEFAULT_TRAIN_MODEL,
        workdir: str = None, emb_n_words: int = 4, emb_init_text: str = '*[0.017, 1]',
        unet_rank: float = 8, text_encoder_rank: float = 4,
        cfg_file: str = _DEFAULT_TRAIN_CFG, single_card: bool = True,
        dataset_type: str = 'stage3-1200', use_ratio: bool = True,
):
    """
    Train a LoRA together with a freshly created embedding for one character dataset.

    :param source: Dataset source handed to ``load_dataset_for_character``.
    :type source: Union[str, Character]
    :param name: Model name. Required when the source does not resolve to a character;
        otherwise derived with ``get_ch_name`` when omitted.
    :type name: Optional[str]
    :param epochs: Requested epochs; the step count starts as ``epochs * dataset size``.
    :type epochs: int
    :param min_steps: Optional lower bound on the number of training steps.
    :type min_steps: Optional[int]
    :param save_for_times: Approximate number of checkpoints to save during training.
    :type save_for_times: int
    :param no_min_steps: Skip the ``_min_training_steps`` lower bound.
    :type no_min_steps: bool
    :param batch_size: Batch size written into the training config.
    :type batch_size: int
    :param pretrained_model: Base model for the embedding and the training.
    :type pretrained_model: str
    :param workdir: Work directory; defaults to ``runs/<name>``.
    :type workdir: str
    :param emb_n_words: Forwarded to ``create_embedding``.
    :type emb_n_words: int
    :param emb_init_text: Forwarded to ``create_embedding``.
    :type emb_init_text: str
    :param unet_rank: Written into the training config as ``unet_rank``.
    :type unet_rank: float
    :param text_encoder_rank: Written into the training config as ``text_encoder_rank``.
    :type text_encoder_rank: float
    :param cfg_file: Training config file the overrides are applied to.
    :type cfg_file: str
    :param single_card: Use ``TrainerSingleCard`` instead of ``Trainer``.
    :type single_card: bool
    :param dataset_type: Dataset type handed to ``load_dataset_for_character``.
    :type dataset_type: str
    :param use_ratio: Use the ``RatioBucket`` configuration instead of ``SizeBucket``.
    :type use_ratio: bool
    :raises ValueError: If no name can be determined, or the dataset contains no PNG image.
    """
    with load_dataset_for_character(source, dataset_type) as (ch, ds_dir):
        if ch is None:
            if name is None:
                raise ValueError(f'Name should be specified when using custom source - {source!r}.')
        else:
            name = name or get_ch_name(ch)

        dataset_size = len(glob.glob(os.path.join(ds_dir, '*.png')))
        logging.info(f'{plural_word(dataset_size, "image")} found in dataset.')
        if dataset_size == 0:
            # An empty dataset used to crash with ZeroDivisionError in the epoch computation below.
            raise ValueError(f'No PNG images found in dataset directory {ds_dir!r}.')

        actual_steps = epochs * dataset_size
        if not no_min_steps:
            actual_steps = max(actual_steps, _min_training_steps(dataset_size, 20))
        if min_steps is not None:
            actual_steps = max(actual_steps, min_steps)
        # Save interval: a multiple of 20 (at least 20), sized for about save_for_times checkpoints.
        save_per_steps = max(int(math.ceil(actual_steps / save_for_times / 20) * 20), 20)
        # Round the total up so that training ends exactly on a save point.
        steps = int(math.ceil(actual_steps / save_per_steps) * save_per_steps)
        # From here on the epoch count is only used for the log message.
        epochs = int(math.ceil(steps / dataset_size))
        logging.info(f'Training for {plural_word(steps, "step")}, {plural_word(epochs, "epoch")}, '
                     f'save per {plural_word(save_per_steps, "step")} ...')

        workdir = workdir or os.path.join('runs', name)
        os.makedirs(workdir, exist_ok=True)
        save_recommended_tags(ds_dir, name, workdir)
        with open(os.path.join(workdir, 'meta.json'), 'w', encoding='utf-8') as f:
            json.dump({
                'dataset': {
                    'size': dataset_size,
                    'type': dataset_type,
                },
            }, f, indent=4, sort_keys=True, ensure_ascii=False)

        # The embedding directory is temporary and must stay alive until training has finished.
        with TemporaryDirectory() as embs_dir:
            logging.info(f'Creating embeddings {name!r} at {embs_dir!r}, '
                         f'n_words: {emb_n_words!r}, init_text: {emb_init_text!r}, '
                         f'pretrained_model: {pretrained_model!r}.')
            create_embedding(
                name, emb_n_words, emb_init_text,
                replace=True,
                pretrained_model=pretrained_model,
                embs_dir=embs_dir,
            )

            cli_args = data_to_cli_args({
                'train': {
                    'train_steps': steps,
                    'save_step': save_per_steps,
                    'scheduler': {
                        'num_training_steps': steps,
                    }
                },
                'model': {
                    'pretrained_model_name_or_path': pretrained_model,
                },
                'character_name': name,
                'dataset_dir': ds_dir,
                'exp_dir': workdir,
                'unet_rank': unet_rank,
                'text_encoder_rank': text_encoder_rank,
                'tokenizer_pt': {
                    'emb_dir': embs_dir,
                },
                'data': {
                    'dataset1': {
                        'batch_size': batch_size,

                        'bucket': {
                            '_target_': 'hcpdiff.data.bucket.RatioBucket.from_files',
                            'target_area': '${times:512,512}',
                            'num_bucket': 5,
                        } if use_ratio else {
                            '_target_': 'hcpdiff.data.bucket.SizeBucket.from_files',
                            'target_area': '---',
                            'num_bucket': 1,
                        }
                    },
                },
            })
            conf = load_config_with_cli(cfg_file, args_list=cli_args)  # skip --cfg

            logging.info(f'Training with {cfg_file!r}, args: {cli_args!r} ...')
            if single_card:
                logging.info('Training with single card ...')
                trainer = TrainerSingleCard(conf)
            else:
                logging.info('Training with non-single cards ...')
                trainer = Trainer(conf)

            trainer.train()
def get_pure_name(name: str) -> str:
    """
    Normalize a name into lower-case words joined by underscores.

    The lower-cased name is split on runs of non-word characters and underscores;
    empty pieces are dropped.

    :param name: Original name.
    :type name: str
    :returns: Normalized name, possibly empty.
    :rtype: str
    """
    pieces = re.split(r'[\W_]+', name.lower())
    return '_'.join(filter(None, pieces))


def get_alphabet_name(name: str) -> str:
    """
    Keep only the runs of ASCII letters, digits (``\\d``) and ``+`` of a name.

    The runs are taken from the lower-cased name and joined by underscores.

    :param name: Original name.
    :type name: str
    :returns: Alphabet-based name, possibly empty.
    :rtype: str
    """
    lowered = name.lower()
    runs = re.findall(r'[a-zA-Z\d+]+', lowered)
    return '_'.join(runs)
def get_ch_name(ch: Character):
    """
    Build the model name ``<alphabet name>_<game name>`` for a character.

    All English, Chinese and Japanese names of the character are rated with
    ``_name_alphabet_ratio``; the best rated one is used, with ties going to the earlier name.

    :param ch: Character to name.
    :type ch: Character
    :returns: Alphabet-based short name followed by ``_`` and the character's ``__game_name__``.
    :rtype: str
    :raises ValueError: If the character has no names, or the best name is rated below the threshold.
    """
    names = [
        *map(str, ch.ennames),
        *map(str, ch.cnnames),
        *map(str, ch.jpnames),
    ]
    if not names:
        # Previously this surfaced as a bare IndexError from indexing an empty list.
        raise ValueError(f'No suitable alphabet-based name for {ch!r}.')

    # Highest ratio first; the index breaks ties in favor of the earlier name.
    name, ratio, _ = min(
        ((name, _name_alphabet_ratio(name), i) for i, name in enumerate(names)),
        key=lambda x: (-x[1], x[2]),
    )

    # NOTE(review): thefuzz ratios are documented as integers in 0..100, so this 0.9
    # threshold accepts almost anything; it looks like 90 was intended. Not changed here
    # because it would alter which characters are accepted - confirm with the owner.
    if ratio >= 0.9:
        short_name = get_alphabet_name(name)
    else:
        raise ValueError(f'No suitable alphabet-based name for {ch!r}.')

    return f'{short_name}_{ch.__game_name__}'
def _yaml_recursive(data, segments: Optional[list] = None):
    """
    Yield one ``dotted.key=<json value>`` string per leaf of a nested structure.

    Mappings contribute their keys and lists/tuples their indices to the dotted key;
    any other value is a leaf and is serialized with ``json.dumps``.

    :param data: Nested data to flatten.
    :param segments: Key segments leading to ``data``.
    :type segments: Optional[list]
    """
    path = list(segments) if segments else []

    if isinstance(data, Mapping):
        children = data.items()
    elif isinstance(data, (list, tuple)):
        children = enumerate(data)
    else:
        dotted_key = '.'.join(str(part) for part in path)
        yield f'{dotted_key}={json.dumps(data)}'
        return

    for label, child in children:
        yield from _yaml_recursive(child, path + [label])


def data_to_cli_args(data) -> List[str]:
    """
    Flatten nested data into a list of ``key.path=value`` command line overrides.

    :param data: Nested mappings, sequences and scalar values.
    :returns: One override string per leaf value, in traversal order.
    :rtype: List[str]
    """
    return [*_yaml_recursive(data)]
def download_file(url, filename, expected_size: int = None, desc=None, session=None, silent: bool = False, **kwargs):
    """
    Downloads a file from the given URL and saves it to the specified filename.

    The streamed response is always closed, and an incomplete file is removed when the
    transfer fails or the final size does not match the expected one.

    :param url: The URL of the file to download.
    :type url: str
    :param filename: The filename to save the downloaded file to.
    :type filename: str
    :param expected_size: The expected size of the file in bytes. If not provided, the ``Content-Length``
        response header is used when present. (default: None)
    :type expected_size: int
    :param desc: The description of the download progress. If not provided, the filename is used. (default: None)
    :type desc: str
    :param session: An existing requests Session object to use for the download. If not provided, a new Session object is created. (default: None)
    :type session: requests.Session
    :param silent: Whether to silence the progress bar. If True, no progress bar is displayed. (default: False)
    :type silent: bool
    :param kwargs: Additional keyword arguments to pass to the `srequest` function.
    :type kwargs: dict
    :returns: The filename of the downloaded file.
    :rtype: str
    :raises requests.exceptions.HTTPError: If the downloaded size differs from the expected size.
    """
    session = session or get_requests_session()
    response = srequest(session, 'GET', url, stream=True, allow_redirects=True, **kwargs)
    try:
        expected_size = expected_size or response.headers.get('Content-Length', None)
        expected_size = int(expected_size) if expected_size is not None else expected_size

        desc = desc or os.path.basename(filename)
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        try:
            with open(filename, 'wb') as f:
                with _with_tqdm(expected_size, desc, silent) as pbar:
                    for chunk in response.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(len(chunk))
        except Exception:
            # Do not leave a truncated file behind that could be mistaken for a complete one.
            if os.path.exists(filename):
                os.remove(filename)
            raise
    finally:
        # With stream=True the connection is only released once the body is fully read
        # or the response is closed; close explicitly so error paths do not leak it.
        response.close()

    actual_size = os.path.getsize(filename)
    if expected_size is not None and actual_size != expected_size:
        os.remove(filename)
        raise requests.exceptions.HTTPError(f"Downloaded file is not of expected size, "
                                            f"{expected_size} expected but {actual_size} found.")

    return filename
# Size buckets as (tag, inclusive lower bound, exclusive upper bound).
# The buckets are contiguous, so every non-negative count falls into exactly one of them.
_NUM_TAGS = [
    ('n<1K', 0, 1_000),
    ('1K<n<10K', 1_000, 10_000),
    ('10K<n<100K', 10_000, 100_000),
    ('100K<n<1M', 100_000, 1_000_000),
    ('1M<n<10M', 1_000_000, 10_000_000),
    ('10M<n<100M', 10_000_000, 100_000_000),
    ('100M<n<1B', 100_000_000, 1_000_000_000),
    ('1B<n<10B', 1_000_000_000, 10_000_000_000),
    ('10B<n<100B', 10_000_000_000, 100_000_000_000),
    ('100B<n<1T', 100_000_000_000, 1_000_000_000_000),
    ('n>1T', 1_000_000_000_000, math.inf),
]


def number_to_tag(v):
    """
    Map a count to its size-category tag.

    :param v: The number of items.
    :returns: The tag of the bucket whose ``[min, max)`` range contains ``v``.
    :rtype: str
    :raises ValueError: If no bucket contains ``v`` (e.g. a negative number).
    """
    for tag, min_, max_ in _NUM_TAGS:
        if min_ <= v < max_:
            return tag

    raise ValueError(f'No tags found for {v!r}')
class TimeoutHTTPAdapter(HTTPAdapter):
    """
    HTTP adapter that applies a default timeout to requests which do not set one.

    Inherits from `HTTPAdapter`.

    Usage:
        - Create an instance of `TimeoutHTTPAdapter` and pass it to a `requests.Session` object's `mount` method.

    Example:
        ```python
        session = requests.Session()
        adapter = TimeoutHTTPAdapter(timeout=10)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        ```

    :param timeout: The default timeout value in seconds. (default: ``DEFAULT_TIMEOUT``)
    :type timeout: int
    """

    def __init__(self, *args, **kwargs):
        # ``timeout`` is taken out of kwargs here so that it is not forwarded to ``HTTPAdapter.__init__``.
        self.timeout = kwargs.pop("timeout", DEFAULT_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs):
        """
        Sends a request, falling back to the adapter's default timeout.

        :param request: The request to send.
        :type request: PreparedRequest
        :param kwargs: Additional keyword arguments. A missing or ``None`` ``timeout`` is replaced
            by the adapter's default.
        :type kwargs: dict
        :returns: The response from the request.
        :rtype: Response
        """
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, **kwargs)
def srequest(session: requests.Session, method, url, *, max_retries: int = 5,
             sleep_time: float = 5.0, raise_for_status: bool = True, **kwargs) -> requests.Response:
    """
    Send a request through the given session, retrying when the request itself fails.

    :param session: The requests Session object to use for the request. A list or tuple of
        sessions is also accepted, in which case one of them is picked at random.
    :type session: requests.Session
    :param method: The HTTP method for the request.
    :type method: str
    :param url: The URL for the request.
    :type url: str
    :param max_retries: The maximum number of attempts. (default: 5)
    :type max_retries: int
    :param sleep_time: The sleep time between attempts in seconds. (default: 5.0)
    :type sleep_time: float
    :param raise_for_status: Whether to raise an exception for non-successful response status codes. (default: True)
    :type raise_for_status: bool
    :param kwargs: Additional keyword arguments for the request.
    :type kwargs: dict
    :returns: The response from the request.
    :rtype: requests.Response
    :raises AssertionError: If no attempt produced a response; the last request error is chained as its cause.
    """
    if isinstance(session, (list, tuple)):
        session = random.choice(session)

    resp, last_err = None, None
    for attempt in range(max_retries):
        try:
            resp = session.request(method, url, **kwargs)
        except RequestException as err:
            last_err = err
            logging.error(f'Request error - {err!r}')
            # Sleeping after the final attempt would only delay the failure report.
            if attempt < max_retries - 1:
                time.sleep(sleep_time)
        else:
            break

    if resp is None:
        # Raised explicitly rather than with an ``assert`` statement, so the check still runs under
        # ``python -O``. The exception type is kept as AssertionError for existing callers.
        raise AssertionError(f'Request failed for {max_retries} time(s).') from last_err
    if raise_for_status:
        resp.raise_for_status()

    return resp
def load_tags_from_directory(directory: str, core_threshold: float = 0.35, threshold: float = 0.45) \
        -> Tuple[Mapping[str, float], List[Mapping[str, float]]]:
    """
    Derive core tags and per-cluster tag patterns from the ``*.txt`` tag files of a directory.

    Each file is read as a comma-separated tag list, and tags containing a globally blacklisted
    word are dropped. Every file becomes a 0/1 tag-presence vector; the mean over all files gives
    the tag frequencies used to pick the core tags, and the vectors are clustered with OPTICS
    (cosine metric) to find recurring tag patterns.

    :param directory: Directory that holds the ``*.txt`` tag files.
    :type directory: str
    :param core_threshold: Frequency threshold passed to ``find_core_tags`` for tags containing a core word. (default: 0.35)
    :type core_threshold: float
    :param threshold: Frequency threshold for ordinary tags, both overall and inside a cluster. (default: 0.45)
    :type threshold: float
    :returns: ``(core_tags, patterns)``. Each pattern maps tag to score: core tags carry ``1.0 + frequency``,
        the cluster's remaining tags carry their in-cluster frequency. ``({}, [])`` when no usable tags are found.
    :rtype: Tuple[Mapping[str, float], List[Mapping[str, float]]]
    """
    word_lists = []
    for txt_file in glob.glob(os.path.join(directory, '*.txt')):
        origin_text = pathlib.Path(txt_file).read_text().strip()
        words = [word.strip() for word in re.split(r'\s*,\s*', origin_text) if word.strip()]
        word_lists.append([word for word in words if not _contains_global_blacklisted_word(word)])

    all_words = sorted({word for words in word_lists for word in words})
    if not all_words:
        # No tag files, or none with a usable tag: there is nothing to vectorise or cluster.
        return {}, []

    all_words_map = {word: i for i, word in enumerate(all_words)}

    # One row per file, one column per known tag; 1.0 marks that the file carries the tag.
    features = np.zeros((len(word_lists), len(all_words)), dtype=float)
    for row, words in enumerate(word_lists):
        for word in words:
            features[row, all_words_map[word]] = 1.0

    mf = features.mean(axis=0)
    all_wds = dict(sorted(zip(all_words, mf.tolist()), key=lambda x: (-x[1], x[0])))
    core_tags = find_core_tags(all_wds, core_threshold, threshold)
    logging.info(f'Core tags found: {core_tags!r}.')

    cluster = OPTICS(metric='cosine', min_samples=5, xi=0.01)
    cluster.fit(features)
    mx = np.max(cluster.labels_).item()

    feats = []
    # OPTICS numbers its clusters from 0 and uses -1 for noise, so the loop has to start at 0;
    # starting at 1 would always drop the first cluster.
    for i in range(mx + 1):
        mean_feat = features[cluster.labels_ == i].mean(axis=0)
        wds = {
            word: value for word, value in
            sorted(zip(all_words, mean_feat.tolist()), key=lambda x: (-x[1], x[0]))
            if value >= threshold
        }
        # Core tags come first and are lifted above 1.0 so they outrank every cluster-specific tag.
        pattern_tags = {
            **{key: 1.0 + value for key, value in sorted(core_tags.items(), key=lambda x: -x[1])},
            **{key: value for key, value in wds.items() if key not in core_tags}
        }
        pattern_tags = {key: value for key, value in pattern_tags.items() if not _contains_global_blacklisted_word(key)}
        feats.append(pattern_tags)
        logging.info(f'Pattern {i} found: {pattern_tags!r}.')

    return core_tags, feats
#!/bin/bash
# Usage: start.sh <comma-separated-character-list> <hf-token>
# Exports the token and the CPU execution provider, then runs waifu_get.py for the given characters.
export HF_TOKEN="$2"
export ONNX_MODE='CPUExecutionProvider'
# Quoted so that a character list containing spaces reaches --char as a single argument.
python waifu_get.py --char "$1"
def main():
    """
    Command-line entry point: crawl a dataset for each requested name and upload it.

    ``--artist`` takes precedence over ``--char``; both accept a comma-separated list.
    Artist names are crawled with ``DanbooruSource``, character names with the crawler's default.
    ``--token``, when given, is exported as ``HF_TOKEN``; otherwise any ``HF_TOKEN`` already present
    in the environment is left untouched.
    """
    os.environ['ONNX_MODE'] = 'CPUExecutionProvider'
    parser = argparse.ArgumentParser()
    parser.add_argument('--char', type=str, help='角色列表')
    parser.add_argument('--artist', type=str, help='画师列表')
    parser.add_argument('--token', type=str, help='token')
    # Arguments must be parsed before any of them is read.
    args = parser.parse_args()
    if args.token:
        # os.environ only accepts strings, so a missing --token must not be assigned.
        os.environ['HF_TOKEN'] = args.token

    if args.artist:
        names, extra_args = args.artist, (DanbooruSource,)
    elif args.char:
        names, extra_args = args.char, ()
    else:
        parser.error('one of --char or --artist is required')

    for ch in names.split(','):
        crawl_dataset_to_huggingface(ch, *extra_args)
        print(ch + "完成")
    print("全部完成")
TemporaryDirectory() as jsondir: + print("Downloading jsons..") + hf_fs = get_hf_fs() + _dgdatas = [file for file in hf_fs.glob(f'datasets/{dgrepo}/*/pixiv_characters.json')] + for name in _dgdatas: + os.makedirs(os.path.basename(os.path.dirname(name)), exist_ok=True) + # print(f'https://huggingface.co/{dgrepo}/blob/main/{os.path.basename(os.path.dirname(name))}/{os.path.basename(name)}') + js = hf_hub_download( + # f'https://huggingface.co/{dgrepo}/blob/main/{os.path.basename(os.path.dirname(name))}/{os.path.basename(name)}', + # hf_hub_url(dgrepo, filename=os.path.relpath(name, dgrepo)), + repo_id=dgrepo, repo_type='dataset', + # os.path.join(os.path.basename(os.path.dirname(name)), 'pixiv_characters.json'), + filename=os.path.join(os.path.basename(os.path.dirname(name)), 'pixiv_characters.json'), + token=os.environ['HF_TOKEN'] + ) + # with open(os.path.join(os.path.basename(os.path.dirname(name)), 'pixiv_characters.json'), 'r') as f: + with open(js, 'r', encoding='utf-8') as f: + jt = json.load(f) + chs = jt['characters'] + for jp in chs: + jp = jp['jpname'] + print(jp, 'start...') + crawl_dataset_to_huggingface(jp) + print(jp + "完成") + return "完成" + + +with gr.Blocks() as jblock: + hf_token = gr.Textbox(label="访问令牌", interactive=True) + char_list = gr.Textbox(label="角色列表", info="用,分隔", placeholder="《输入角色名然后你的数据集就出现在抱脸了》", interactive=True) + is_cpu = gr.Checkbox(label="无显卡", info="不使用显卡", value=True, interactive=True) + use_dghs = gr.Checkbox(label="从dghs", info="override", value=False, interactive=True) + start_button = gr.Button("开始上传", interactive=True) + opt_msg = gr.Textbox(interactive=False) + start_button.click(start_func, [hf_token, char_list, is_cpu, use_dghs], [opt_msg], api_name="crawlup") + +if __name__ == "__main__": + jblock.queue(max_size=64) + jblock.launch() + +# if __name__ == "__main__": +# jblock.launch(server_port=args.port)