|
import os |
|
import stat |
|
import sys |
|
import tempfile |
|
import re |
|
import torch |
|
import shutil |
|
import filecmp |
|
from abc import ABC, abstractmethod |
|
from functools import cached_property |
|
|
|
from .generic import ( |
|
BASE_PATH, |
|
Context, |
|
download_url_with_progressbar, |
|
prompt_yes_no, |
|
replace_prefix, |
|
get_digest, |
|
get_filename_from_url, |
|
) |
|
from .log import get_logger |
|
|
|
|
|
class InfererModule(ABC):
    """Common base class for all inferer modules.

    On construction, attaches a logger named after the concrete subclass so
    that log output can be attributed to the module that produced it.
    """

    def __init__(self):
        # Logger is keyed on the subclass name, not on InfererModule itself.
        self.logger = get_logger(type(self).__name__)
        super().__init__()

    def parse_args(self, args: Context):
        """May be overwritten by subclasses to parse commandline arguments."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelVerificationException(Exception):
    """Raised when a downloaded model file does not match its expected hash."""
|
|
|
class InvalidModelMappingException(ValueError):
    """Signals a malformed `_MODEL_MAPPING` entry.

    The message identifies both the wrapper class and the offending map key.
    """

    def __init__(self, cls: str, map_key: str, error_msg: str):
        message = f'[{cls}->{map_key}] Invalid _MODEL_MAPPING - {error_msg}'
        super().__init__(message)
|
|
|
class ModelWrapper(ABC):
    r"""
    A class that provides a unified interface for downloading models and making forward passes.
    All model inferer classes should extend it.

    Download specifications can be made through overwriting the `_MODEL_MAPPING` property.

    ```python
    _MODEL_MAPPING = {
        'model_id': {
            **PARAMETERS
        },
        ...
    }
    ```

    Parameters:

    model_id - Used for temporary caches and debug messages

    url - A direct download url

    hash - Hash of downloaded file, Can be obtained upon ModelVerificationException

    file - File download destination, If set to '.' the filename will be inferred
           from the url (fallback is `model_id` value)

    archive - Dict that contains all files/folders that are to be extracted from
              the downloaded archive and their destinations, Mutually exclusive with `file`

    executables - List of files that need to have the executable flag set
    """
    _MODEL_DIR = os.path.join(BASE_PATH, 'models')
    _MODEL_SUB_DIR = ''
    _MODEL_MAPPING = {}
    _KEY = ''

    def __init__(self):
        os.makedirs(self.model_dir, exist_ok=True)
        # Fall back to the class name when no explicit key is configured.
        self._key = self._KEY or self.__class__.__name__
        self._loaded = False
        self._check_for_malformed_model_mapping()
        self._downloaded = self._check_downloaded()

    def is_loaded(self) -> bool:
        return self._loaded

    def is_downloaded(self) -> bool:
        return self._downloaded

    @property
    def model_dir(self) -> str:
        # Directory all model files of this inferer are stored in
        return os.path.join(self._MODEL_DIR, self._MODEL_SUB_DIR)

    def _get_file_path(self, *args) -> str:
        return os.path.join(self.model_dir, *args)

    def _get_used_gpu_memory(self) -> tuple:
        '''
        Returns the `(free, total)` GPU memory of the current device in bytes,
        as reported by `torch.cuda.mem_get_info()` (Can be used in the future
        to determine whether a model should be loaded into vram or ram or automatically choose a model size).
        TODO: Use together with `--use-cuda-limited` flag to enforce stricter memory checks
        '''
        # BUGFIX: was annotated `-> bool`; mem_get_info() returns a tuple.
        return torch.cuda.mem_get_info()

    def _check_for_malformed_model_mapping(self):
        '''
        Validates `_MODEL_MAPPING` and normalizes entries (missing `file`
        defaults to '.'). Raises InvalidModelMappingException on bad entries.
        '''
        for map_key, mapping in self._MODEL_MAPPING.items():
            if 'url' not in mapping:
                raise InvalidModelMappingException(self._key, map_key, 'Missing url property')
            elif not re.search(r'^https?://', mapping['url']):
                raise InvalidModelMappingException(self._key, map_key, 'Malformed url property: "%s"' % mapping['url'])
            if 'file' not in mapping and 'archive' not in mapping:
                mapping['file'] = '.'
            elif 'file' in mapping and 'archive' in mapping:
                raise InvalidModelMappingException(self._key, map_key, 'Properties file and archive are mutually exclusive')

    async def _download_file(self, url: str, path: str):
        print(f' -- Downloading: "{url}"')
        download_url_with_progressbar(url, path)

    async def _verify_file(self, sha256_pre_calculated: str, path: str):
        '''Compares the file's digest against the expected one; raises on mismatch.'''
        print(f' -- Verifying: "{path}"')
        sha256_calculated = get_digest(path).lower()
        sha256_pre_calculated = sha256_pre_calculated.lower()

        if sha256_calculated != sha256_pre_calculated:
            self._on_verify_failure(sha256_calculated, sha256_pre_calculated)
        else:
            print(' -- Verifying: OK!')

    def _on_verify_failure(self, sha256_calculated: str, sha256_pre_calculated: str):
        print(f' -- Mismatch between downloaded and created hash: "{sha256_calculated}" <-> "{sha256_pre_calculated}"')
        raise ModelVerificationException()

    @cached_property
    def _temp_working_directory(self) -> str:
        # Per-inferer scratch directory for archive downloads; created lazily.
        p = os.path.join(tempfile.gettempdir(), 'manga-image-translator', self._key.lower())
        os.makedirs(p, exist_ok=True)
        return p

    async def download(self, force: bool = False):
        '''
        Downloads required models. Retries on verification failure until the
        user declines, in which case a KeyboardInterrupt is raised.
        '''
        if force or not self.is_downloaded():
            while True:
                try:
                    await self._download()
                    self._downloaded = True
                    break
                except ModelVerificationException:
                    if not prompt_yes_no('Failed to verify signature. Do you want to restart the download?', default=True):
                        print('Aborting.', end='')
                        raise KeyboardInterrupt()

    async def _download(self):
        '''
        Downloads models as defined in `_MODEL_MAPPING`. Can be overwritten (together
        with `_check_downloaded`) to implement unconventional download logic.
        '''
        print(f'\nDownloading models into {self.model_dir}\n')
        for map_key, mapping in self._MODEL_MAPPING.items():
            if self._check_downloaded_map(map_key):
                print(f' -- Skipping {map_key} as it\'s already downloaded')
                continue

            # Archives are downloaded into a temp dir, plain files straight
            # into the model dir (with a '.part' suffix until verified).
            is_archive = 'archive' in mapping
            if is_archive:
                download_path = os.path.join(self._temp_working_directory, map_key, '')
            else:
                download_path = self._get_file_path(mapping['file'])
            if not os.path.basename(download_path):
                os.makedirs(download_path, exist_ok=True)
            if os.path.basename(download_path) in ('', '.'):
                # Infer the filename from the url (fallback: map_key)
                download_path = os.path.join(download_path, get_filename_from_url(mapping['url'], map_key))
            if not is_archive:
                download_path += '.part'

            if 'hash' in mapping:
                downloaded = False
                if os.path.isfile(download_path):
                    try:
                        print(' -- Found existing file')
                        await self._verify_file(mapping['hash'], download_path)
                        downloaded = True
                    except ModelVerificationException:
                        print(' -- Resuming interrupted download')
                if not downloaded:
                    await self._download_file(mapping['url'], download_path)
                    await self._verify_file(mapping['hash'], download_path)
            else:
                await self._download_file(mapping['url'], download_path)

            if download_path.endswith('.part'):
                # Download completed and verified; drop the '.part' suffix
                p = download_path[:-5]
                shutil.move(download_path, p)
                download_path = p

            if is_archive:
                extracted_path = os.path.join(os.path.dirname(download_path), 'extracted')
                print(' -- Extracting files')
                shutil.unpack_archive(download_path, extracted_path)

                def get_real_archive_files():
                    # Lists all files inside the extracted archive, relative to its root
                    archive_files = []
                    for root, dirs, files in os.walk(extracted_path):
                        for name in files:
                            file_path = replace_prefix(os.path.join(root, name), extracted_path, '')
                            archive_files.append(file_path)
                    return archive_files

                # Move every specified file from the archive to its destination
                for orig, dest in mapping['archive'].items():
                    p1 = os.path.join(extracted_path, orig)
                    if os.path.exists(p1):
                        p2 = self._get_file_path(dest)
                        # If dest points at a directory, keep the original filename
                        if os.path.basename(p2) in ('', '.'):
                            p2 = os.path.join(p2, os.path.basename(p1))
                        if os.path.isfile(p2):
                            if filecmp.cmp(p1, p2):
                                continue
                            # BUGFIX: was a plain string; {orig}/{dest} were printed literally
                            raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" already exists at "{dest}"')
                        os.makedirs(os.path.dirname(p2), exist_ok=True)
                        shutil.move(p1, p2)
                    else:
                        raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" does not exist within archive' +
                                                           '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))
                if len(mapping['archive']) == 0:
                    raise InvalidModelMappingException(self._key, map_key, 'No archive files specified' +
                                                       '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))

                self._grant_execute_permissions(map_key)

                # Best-effort removal of the temporary archive and extraction dir
                try:
                    os.remove(download_path)
                    shutil.rmtree(extracted_path)
                except Exception:
                    pass

            print()
            self._on_download_finished(map_key)

    def _on_download_finished(self, map_key):
        '''
        Can be overwritten to further process the downloaded files
        '''
        pass

    def _check_downloaded(self) -> bool:
        '''
        Scans filesystem for required files as defined in `_MODEL_MAPPING`.
        Returns `False` if files should be redownloaded.
        '''
        for map_key in self._MODEL_MAPPING:
            if not self._check_downloaded_map(map_key):
                return False
        return True

    def _check_downloaded_map(self, map_key: str) -> bool:
        '''
        Returns `True` if all files of a single `_MODEL_MAPPING` entry exist on
        disk (and re-applies executable permissions as a side effect).
        '''
        # BUGFIX: return annotation was `-> str`; the method returns a bool.
        mapping = self._MODEL_MAPPING[map_key]

        if 'file' in mapping:
            path = mapping['file']
            if os.path.basename(path) in ('.', ''):
                path = os.path.join(path, get_filename_from_url(mapping['url'], map_key))
            if not os.path.exists(self._get_file_path(path)):
                return False

        elif 'archive' in mapping:
            for orig, dest in mapping['archive'].items():
                if os.path.basename(dest) in ('', '.'):
                    dest = os.path.join(dest, os.path.basename(orig[:-1] if orig.endswith('/') else orig))
                if not os.path.exists(self._get_file_path(dest)):
                    return False

        self._grant_execute_permissions(map_key)

        return True

    def _grant_execute_permissions(self, map_key: str):
        '''Sets the executable bit on all `executables` of a mapping (linux only).'''
        mapping = self._MODEL_MAPPING[map_key]

        if sys.platform == 'linux':
            # Grant permission to executables
            for file in mapping.get('executables', []):
                p = self._get_file_path(file)
                if os.path.basename(p) in ('', '.'):
                    p = os.path.join(p, file)
                if not os.path.isfile(p):
                    raise InvalidModelMappingException(self._key, map_key, f'File "{file}" does not exist')
                if not os.access(p, os.X_OK):
                    os.chmod(p, os.stat(p).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

    async def reload(self, device: str, *args, **kwargs):
        await self.unload()
        await self.load(*args, **kwargs, device=device)

    async def load(self, device: str, *args, **kwargs):
        '''
        Loads models into memory. Has to be called before `forward`.
        Downloads missing model files first.
        '''
        if not self.is_downloaded():
            await self.download()
        if not self.is_loaded():
            await self._load(*args, **kwargs, device=device)
            self._loaded = True

    async def unload(self):
        if self.is_loaded():
            await self._unload()
            self._loaded = False

    async def infer(self, *args, **kwargs):
        '''
        Makes a forward pass through the network.
        Raises if `load` has not been called beforehand.
        '''
        if not self.is_loaded():
            raise Exception(f'{self._key}: Tried to forward pass without having loaded the model.')

        return await self._infer(*args, **kwargs)

    @abstractmethod
    async def _load(self, device: str, *args, **kwargs):
        pass

    @abstractmethod
    async def _unload(self):
        pass

    @abstractmethod
    async def _infer(self, *args, **kwargs):
        pass
|
|