import os
import stat
import sys
import tempfile
import re
import torch
import shutil
import filecmp
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Tuple

from .generic import (
    BASE_PATH,
    Context,
    download_url_with_progressbar,
    prompt_yes_no,
    replace_prefix,
    get_digest,
    get_filename_from_url,
)
from .log import get_logger


class InfererModule(ABC):
    '''
    Base class for inference modules. Provides every instance with a logger
    named after its concrete class.
    '''

    def __init__(self):
        self.logger = get_logger(self.__class__.__name__)
        super().__init__()

    def parse_args(self, args: Context):
        """May be overwritten by subclasses to parse commandline arguments"""
        pass

# class InfererModuleManager(ABC):
#     _KEY = ''
#     _VARIANTS = []

#     def __init__(self):
#         self.onstart: Callable = None
#         self.onfinish: Callable = None

#     def validate(self):
#         """
#         Throws exception if a
#         """
#         ...

#     async def prepare(self):
#         ...

#     async def dispatch(self):
#         ...


class ModelVerificationException(Exception):
    '''Raised when the hash of a downloaded file does not match the expected one.'''
    pass


class InvalidModelMappingException(ValueError):
    '''
    Raised when an entry of `_MODEL_MAPPING` is malformed or cannot be applied.

    The message is prefixed with `[<cls>-><map_key>]` so the offending entry
    can be identified.
    '''

    def __init__(self, cls: str, map_key: str, error_msg: str):
        error = f'[{cls}->{map_key}] Invalid _MODEL_MAPPING - {error_msg}'
        super().__init__(error)


class ModelWrapper(ABC):
    r"""
    A class that provides a unified interface for downloading models and making forward passes.
    All model inferer classes should extend it.

    Download specifications can be made through overwriting the `_MODEL_MAPPING` property.

    ```python
    _MODEL_MAPPING = {
        'model_id': {
            **PARAMETERS
        },
        ...
    }
    ```

    Parameters:

    model_id            - Used for temporary caches and debug messages

    url                 - A direct download url

    hash                - Hash of downloaded file, Can be obtained upon ModelVerificationException

    file                - File download destination, If set to '.' the filename will be inferred
                          from the url (fallback is `model_id` value)

    archive             - Dict that contains all files/folders that are to be extracted from
                          the downloaded archive and their destinations, Mutually exclusive with `file`

    executables         - List of files that need to have the executable flag set
    """
    _MODEL_DIR = os.path.join(BASE_PATH, 'models')
    _MODEL_SUB_DIR = ''
    _MODEL_MAPPING = {}
    _KEY = ''

    def __init__(self):
        os.makedirs(self.model_dir, exist_ok=True)
        # Fall back to the class name when no explicit key was configured.
        self._key = self._KEY or self.__class__.__name__
        self._loaded = False
        self._check_for_malformed_model_mapping()
        self._downloaded = self._check_downloaded()

    def is_loaded(self) -> bool:
        return self._loaded

    def is_downloaded(self) -> bool:
        return self._downloaded

    @property
    def model_dir(self):
        '''Directory in which the files of this model are stored.'''
        return os.path.join(self._MODEL_DIR, self._MODEL_SUB_DIR)

    def _get_file_path(self, *args) -> str:
        '''Joins the given path components onto `model_dir`.'''
        return os.path.join(self.model_dir, *args)

    def _get_used_gpu_memory(self) -> Tuple[int, int]:
        '''
        Returns the result of `torch.cuda.mem_get_info()`, i.e. the
        `(free, total)` memory of the current CUDA device rather than a
        single "used" figure.
        (Can be used in the future to determine whether a model should be loaded into vram or ram or automatically choose a model size).
        TODO: Use together with `--use-cuda-limited` flag to enforce stricter memory checks
        '''
        return torch.cuda.mem_get_info()

    def _check_for_malformed_model_mapping(self):
        '''
        Validates every entry of `_MODEL_MAPPING`.

        Entries that specify neither `file` nor `archive` get `file` set to
        '.' in place.

        Raises:
            InvalidModelMappingException: if `url` is missing or does not
                start with http(s)://, or if both `file` and `archive` are set.
        '''
        for map_key, mapping in self._MODEL_MAPPING.items():
            if 'url' not in mapping:
                raise InvalidModelMappingException(self._key, map_key, 'Missing url property')
            elif not re.search(r'^https?://', mapping['url']):
                raise InvalidModelMappingException(self._key, map_key, 'Malformed url property: "%s"' % mapping['url'])
            if 'file' not in mapping and 'archive' not in mapping:
                mapping['file'] = '.'
            elif 'file' in mapping and 'archive' in mapping:
                raise InvalidModelMappingException(self._key, map_key, 'Properties file and archive are mutually exclusive')

    async def _download_file(self, url: str, path: str):
        '''Downloads `url` to `path` while showing a progress bar.'''
        print(f' -- Downloading: "{url}"')
        download_url_with_progressbar(url, path)

    async def _verify_file(self, sha256_pre_calculated: str, path: str):
        '''
        Compares the digest of the file at `path` (case-insensitively) with
        the expected one and calls `_on_verify_failure` on mismatch.
        '''
        print(f' -- Verifying: "{path}"')
        sha256_calculated = get_digest(path).lower()
        sha256_pre_calculated = sha256_pre_calculated.lower()

        if sha256_calculated != sha256_pre_calculated:
            self._on_verify_failure(sha256_calculated, sha256_pre_calculated)
        else:
            print(' -- Verifying: OK!')

    def _on_verify_failure(self, sha256_calculated: str, sha256_pre_calculated: str):
        '''
        Reports the hash mismatch.

        Raises:
            ModelVerificationException: always.
        '''
        print(f' -- Mismatch between downloaded and created hash: "{sha256_calculated}" <-> "{sha256_pre_calculated}"')
        raise ModelVerificationException()

    @cached_property
    def _temp_working_directory(self):
        '''Per-model scratch directory inside the system temp dir (created on first access).'''
        p = os.path.join(tempfile.gettempdir(), 'manga-image-translator', self._key.lower())
        os.makedirs(p, exist_ok=True)
        return p

    async def download(self, force=False):
        '''
        Downloads required models.

        Retries after a failed verification for as long as the user agrees
        to; raises `KeyboardInterrupt` once the user declines.
        '''
        if force or not self.is_downloaded():
            while True:
                try:
                    await self._download()
                    self._downloaded = True
                    break
                except ModelVerificationException:
                    if not prompt_yes_no('Failed to verify signature. Do you want to restart the download?', default=True):
                        print('Aborting.', end='')
                        raise KeyboardInterrupt()

    async def _download(self):
        '''
        Downloads models as defined in `_MODEL_MAPPING`. Can be overwritten (together
        with `_check_downloaded`) to implement unconventional download logic.

        Raises:
            ModelVerificationException: if a downloaded file fails its hash check.
            InvalidModelMappingException: if an `archive` mapping is empty,
                names a file missing from the archive, or would overwrite a
                different file that already exists.
        '''
        print(f'\nDownloading models into {self.model_dir}\n')
        for map_key, mapping in self._MODEL_MAPPING.items():
            if self._check_downloaded_map(map_key):
                print(f' -- Skipping {map_key} as it\'s already downloaded')
                continue

            is_archive = 'archive' in mapping
            if is_archive:
                # Trailing '' makes the path end with a separator, i.e. a directory.
                download_path = os.path.join(self._temp_working_directory, map_key, '')
            else:
                download_path = self._get_file_path(mapping['file'])
            if not os.path.basename(download_path):
                os.makedirs(download_path, exist_ok=True)
            if os.path.basename(download_path) in ('', '.'):
                # Destination is a directory: derive the filename from the url.
                download_path = os.path.join(download_path, get_filename_from_url(mapping['url'], map_key))

            # `file` may point into a sub-directory that does not exist yet;
            # create it in case the downloader does not do so itself.
            parent_dir = os.path.dirname(download_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            if not is_archive:
                # Download under a temporary name so an interrupted download
                # is never mistaken for a complete model file.
                download_path += '.part'

            if 'hash' in mapping:
                downloaded = False
                if os.path.isfile(download_path):
                    try:
                        print(' -- Found existing file')
                        await self._verify_file(mapping['hash'], download_path)
                        downloaded = True
                    except ModelVerificationException:
                        print(' -- Resuming interrupted download')
                if not downloaded:
                    await self._download_file(mapping['url'], download_path)
                    await self._verify_file(mapping['hash'], download_path)
            else:
                await self._download_file(mapping['url'], download_path)

            if download_path.endswith('.part'):
                # Verified (or unverifiable) download: drop the '.part' suffix.
                p = download_path[:-len('.part')]
                shutil.move(download_path, p)
                download_path = p

            if is_archive:
                extracted_path = os.path.join(os.path.dirname(download_path), 'extracted')
                print(' -- Extracting files')
                shutil.unpack_archive(download_path, extracted_path)

                def get_real_archive_files():
                    # Lists every extracted file relative to the extraction
                    # root; only used to build helpful error messages.
                    archive_files = []
                    for root, _dirs, files in os.walk(extracted_path):
                        for name in files:
                            file_path = replace_prefix(os.path.join(root, name), extracted_path, '')
                            archive_files.append(file_path)
                    return archive_files

                # Move every specified file from archive to destination
                for orig, dest in mapping['archive'].items():
                    p1 = os.path.join(extracted_path, orig)
                    if os.path.exists(p1):
                        p2 = self._get_file_path(dest)
                        if os.path.basename(p2) in ('', '.'):
                            p2 = os.path.join(p2, os.path.basename(p1))
                        if os.path.isfile(p2):
                            if filecmp.cmp(p1, p2):
                                # Identical file is already in place.
                                continue
                            raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" already exists at "{dest}"')
                        os.makedirs(os.path.dirname(p2), exist_ok=True)
                        shutil.move(p1, p2)
                    else:
                        raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" does not exist within archive' +
                                        '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))
                if not mapping['archive']:
                    raise InvalidModelMappingException(self._key, map_key, 'No archive files specified' +
                                        '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))

                self._grant_execute_permissions(map_key)

                # Remove temporary files. Best effort: failing to delete one
                # must neither abort the download nor skip the other cleanup.
                try:
                    os.remove(download_path)
                except OSError:
                    pass
                shutil.rmtree(extracted_path, ignore_errors=True)

            print()
            self._on_download_finished(map_key)

    def _on_download_finished(self, map_key):
        '''
        Can be overwritten to further process the downloaded files
        '''
        pass

    def _check_downloaded(self) -> bool:
        '''
        Scans filesystem for required files as defined in `_MODEL_MAPPING`.
        Returns `False` if files should be redownloaded.
        '''
        for map_key in self._MODEL_MAPPING:
            if not self._check_downloaded_map(map_key):
                return False
        return True

    def _check_downloaded_map(self, map_key: str) -> bool:
        '''
        Checks whether all files of the mapping `map_key` exist on disk.

        When they do, the execute permissions of the mapping's executables
        are (re)applied before returning `True`.
        '''
        mapping = self._MODEL_MAPPING[map_key]

        if 'file' in mapping:
            path = mapping['file']
            if os.path.basename(path) in ('.', ''):
                path = os.path.join(path, get_filename_from_url(mapping['url'], map_key))
            if not os.path.exists(self._get_file_path(path)):
                return False

        elif 'archive' in mapping:
            for orig, dest in mapping['archive'].items():
                if os.path.basename(dest) in ('', '.'):
                    # Destination is a directory: the entry keeps its own name
                    # (ignoring a trailing slash on directory entries).
                    dest = os.path.join(dest, os.path.basename(orig[:-1] if orig.endswith('/') else orig))
                if not os.path.exists(self._get_file_path(dest)):
                    return False

        self._grant_execute_permissions(map_key)

        return True

    def _grant_execute_permissions(self, map_key: str):
        '''
        Sets the executable bits (user, group, other) on every file listed
        under `executables` of the mapping. Only acts on linux.

        Raises:
            InvalidModelMappingException: if a listed executable does not exist.
        '''
        mapping = self._MODEL_MAPPING[map_key]

        if sys.platform == 'linux':
            # Grant permission to executables
            for file in mapping.get('executables', []):
                p = self._get_file_path(file)
                if os.path.basename(p) in ('', '.'):
                    p = os.path.join(p, file)
                if not os.path.isfile(p):
                    raise InvalidModelMappingException(self._key, map_key, f'File "{file}" does not exist')
                if not os.access(p, os.X_OK):
                    os.chmod(p, os.stat(p).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

    async def reload(self, device: str, *args, **kwargs):
        '''Unloads the model (if loaded) and loads it again on `device`.'''
        await self.unload()
        await self.load(*args, **kwargs, device=device)

    async def load(self, device: str, *args, **kwargs):
        '''
        Loads models into memory. Has to be called before `forward`.
        Downloads the models first if they are not present yet.
        '''
        if not self.is_downloaded():
            await self.download()
        if not self.is_loaded():
            await self._load(*args, **kwargs, device=device)
            self._loaded = True

    async def unload(self):
        '''Unloads the model; does nothing if it is not loaded.'''
        if self.is_loaded():
            await self._unload()
            self._loaded = False

    async def infer(self, *args, **kwargs):
        '''
        Makes a forward pass through the network.

        Raises:
            Exception: if the model has not been loaded.
        '''
        if not self.is_loaded():
            raise Exception(f'{self._key}: Tried to forward pass without having loaded the model.')
        return await self._infer(*args, **kwargs)

    @abstractmethod
    async def _load(self, device: str, *args, **kwargs):
        pass

    @abstractmethod
    async def _unload(self):
        pass

    @abstractmethod
    async def _infer(self, *args, **kwargs):
        pass