import os
import stat
import sys
import tempfile
import re
import torch
import shutil
import filecmp
from abc import ABC, abstractmethod
from functools import cached_property

from .generic import (
    BASE_PATH,
    Context,
    download_url_with_progressbar,
    prompt_yes_no,
    replace_prefix,
    get_digest,
    get_filename_from_url,
)
from .log import get_logger


class InfererModule(ABC):
    def __init__(self):
        self.logger = get_logger(self.__class__.__name__)
        super().__init__()

    def parse_args(self, args: Context):
        """May be overwritten by super classes to parse commandline arguments"""
        pass

# class InfererModuleManager(ABC):
#     _KEY = ''
#     _VARIANTS = []

#     def __init__(self):
#         self.onstart: Callable = None
#         self.onfinish: Callable = None

#     def validate(self):
#         """
#         Throws exception if a 
#         """
#         ...

#     async def prepare(self):
#         ...

#     async def dispatch(self):
#         ...


class ModelVerificationException(Exception):
    pass

class InvalidModelMappingException(ValueError):
    def __init__(self, cls: str, map_key: str, error_msg: str):
        error = f'[{cls}->{map_key}] Invalid _MODEL_MAPPING - {error_msg}'
        super().__init__(error)

class ModelWrapper(ABC):
    r"""
    A class that provides a unified interface for downloading models and making forward passes.
    All model inferer classes should extend it.

    Download specifications can be made through overwriting the `_MODEL_MAPPING` property.

    ```python
    _MODEL_MAPPING = {
        'model_id': {
            **PARAMETERS
        },
        ...
    }
    ```

    Parameters:

    model_id            - Used for temporary caches and debug messages

        url                 - A direct download url

        hash                - Hash of the downloaded file; can be obtained from a ModelVerificationException

        file                - File download destination. If set to '.' the filename will be inferred
                              from the url (fallback is the `model_id` value)

        archive             - Dict mapping the files/folders to be extracted from the downloaded
                              archive to their destinations. Mutually exclusive with `file`

        executables         - List of files that need to have the executable flag set
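
    For illustration, a hypothetical mapping could look like the following (the urls, hash
    and file names are placeholders, not real download endpoints):

    ```python
    _MODEL_MAPPING = {
        'detector': {
            'url': 'https://example.com/models/detector-v1.ckpt',  # placeholder url
            'hash': 'abc123...',                                    # placeholder sha256 digest
            'file': '.',                                            # keep the filename from the url
        },
        'tools': {
            'url': 'https://example.com/models/tools.zip',          # placeholder url
            'archive': {
                'bin/run': 'tools/',                                 # archive path -> destination dir
            },
            'executables': ['tools/run'],
        },
    }
    ```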
    """
    _MODEL_DIR = os.path.join(BASE_PATH, 'models')
    _MODEL_SUB_DIR = ''
    _MODEL_MAPPING = {}
    _KEY = ''

    def __init__(self):
        os.makedirs(self.model_dir, exist_ok=True)
        self._key = self._KEY or self.__class__.__name__
        self._loaded = False
        self._check_for_malformed_model_mapping()
        self._downloaded = self._check_downloaded()

    def is_loaded(self) -> bool:
        return self._loaded

    def is_downloaded(self) -> bool:
        return self._downloaded

    @property
    def model_dir(self):
        return os.path.join(self._MODEL_DIR, self._MODEL_SUB_DIR)

    def _get_file_path(self, *args) -> str:
        return os.path.join(self.model_dir, *args)

    def _get_used_gpu_memory(self) -> tuple:
        '''
        Gets the free and total GPU memory of the current device in bytes (can be used in the future
        to determine whether a model should be loaded into vram or ram, or to automatically choose a model size).
        TODO: Use together with `--use-cuda-limited` flag to enforce stricter memory checks
        '''
        return torch.cuda.mem_get_info()

    def _check_for_malformed_model_mapping(self):
        for map_key, mapping in self._MODEL_MAPPING.items():
            if 'url' not in mapping:
                raise InvalidModelMappingException(self._key, map_key, 'Missing url property')
            elif not re.search(r'^https?://', mapping['url']):
                raise InvalidModelMappingException(self._key, map_key, 'Malformed url property: "%s"' % mapping['url'])
            if 'file' not in mapping and 'archive' not in mapping:
                mapping['file'] = '.'
            elif 'file' in mapping and 'archive' in mapping:
                raise InvalidModelMappingException(self._key, map_key, 'Properties file and archive are mutually exclusive')

    async def _download_file(self, url: str, path: str):
        print(f' -- Downloading: "{url}"')
        download_url_with_progressbar(url, path)

    async def _verify_file(self, sha256_pre_calculated: str, path: str):
        print(f' -- Verifying: "{path}"')
        sha256_calculated = get_digest(path).lower()
        sha256_pre_calculated = sha256_pre_calculated.lower()

        if sha256_calculated != sha256_pre_calculated:
            self._on_verify_failure(sha256_calculated, sha256_pre_calculated)
        else:
            print(' -- Verifying: OK!')

    def _on_verify_failure(self, sha256_calculated: str, sha256_pre_calculated: str):
        print(f' -- Mismatch between downloaded and created hash: "{sha256_calculated}" <-> "{sha256_pre_calculated}"')
        raise ModelVerificationException()

    @cached_property
    def _temp_working_directory(self):
        p = os.path.join(tempfile.gettempdir(), 'manga-image-translator', self._key.lower())
        os.makedirs(p, exist_ok=True)
        return p

    async def download(self, force=False):
        '''
        Downloads required models.
        '''
        if force or not self.is_downloaded():
            while True:
                try:
                    await self._download()
                    self._downloaded = True
                    break
                except ModelVerificationException:
                    if not prompt_yes_no('Failed to verify signature. Do you want to restart the download?', default=True):
                        print('Aborting.', end='')
                        raise KeyboardInterrupt()

    async def _download(self):
        '''
        Downloads models as defined in `_MODEL_MAPPING`. Can be overridden (together
        with `_check_downloaded`) to implement unconventional download logic.
        '''
        print(f'\nDownloading models into {self.model_dir}\n')
        for map_key, mapping in self._MODEL_MAPPING.items():
            if self._check_downloaded_map(map_key):
                print(f' -- Skipping {map_key} as it\'s already downloaded')
                continue

            is_archive = 'archive' in mapping
            if is_archive:
                download_path = os.path.join(self._temp_working_directory, map_key, '')
            else:
                download_path = self._get_file_path(mapping['file'])
            if not os.path.basename(download_path):
                os.makedirs(download_path, exist_ok=True)
            if os.path.basename(download_path) in ('', '.'):
                download_path = os.path.join(download_path, get_filename_from_url(mapping['url'], map_key))
            if not is_archive:
                download_path += '.part'

            if 'hash' in mapping:
                downloaded = False
                if os.path.isfile(download_path):
                    try:
                        print(' -- Found existing file')
                        await self._verify_file(mapping['hash'], download_path)
                        downloaded = True
                    except ModelVerificationException:
                        print(' -- Resuming interrupted download')
                if not downloaded:
                    await self._download_file(mapping['url'], download_path)
                    await self._verify_file(mapping['hash'], download_path)
            else:
                await self._download_file(mapping['url'], download_path)

            if download_path.endswith('.part'):
                p = download_path[:-len('.part')]
                shutil.move(download_path, p)
                download_path = p

            if is_archive:
                extracted_path = os.path.join(os.path.dirname(download_path), 'extracted')
                print(' -- Extracting files')
                shutil.unpack_archive(download_path, extracted_path)

                def get_real_archive_files():
                    archive_files = []
                    for root, dirs, files in os.walk(extracted_path):
                        for name in files:
                            file_path = replace_prefix(os.path.join(root, name), extracted_path, '')
                            archive_files.append(file_path)
                    return archive_files

                # Move every specified file from archive to destination
                for orig, dest in mapping['archive'].items():
                    p1 = os.path.join(extracted_path, orig)
                    if os.path.exists(p1):
                        p2 = self._get_file_path(dest)
                        if os.path.basename(p2) in ('', '.'):
                            p2 = os.path.join(p2, os.path.basename(p1))
                        if os.path.isfile(p2):
                            if filecmp.cmp(p1, p2):
                                continue
                            raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" already exists at "{dest}"')
                        os.makedirs(os.path.dirname(p2), exist_ok=True)
                        shutil.move(p1, p2)
                    else:
                        raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" does not exist within archive' +
                                    '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))
                if len(mapping['archive']) == 0:
                    raise InvalidModelMappingException(self._key, map_key, 'No archive files specified' +
                                        '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))

                self._grant_execute_permissions(map_key)

                # Remove temporary files
                try:
                    os.remove(download_path)
                    shutil.rmtree(extracted_path)
                except Exception:
                    pass

            print()
            self._on_download_finished(map_key)

    def _on_download_finished(self, map_key):
        '''
        Can be overridden to further process the downloaded files.
        '''
        pass

    def _check_downloaded(self) -> bool:
        '''
        Scans filesystem for required files as defined in `_MODEL_MAPPING`.
        Returns `False` if files should be redownloaded.
        '''
        for map_key in self._MODEL_MAPPING:
            if not self._check_downloaded_map(map_key):
                return False
        return True

    def _check_downloaded_map(self, map_key: str) -> bool:
        mapping = self._MODEL_MAPPING[map_key]

        if 'file' in mapping:
            path = mapping['file']
            if os.path.basename(path) in ('.', ''):
                path = os.path.join(path, get_filename_from_url(mapping['url'], map_key))
            if not os.path.exists(self._get_file_path(path)):
                return False

        elif 'archive' in mapping:
            for orig, dest in mapping['archive'].items():
                if os.path.basename(dest) in ('', '.'):
                    dest = os.path.join(dest, os.path.basename(orig[:-1] if orig.endswith('/') else orig))
                if not os.path.exists(self._get_file_path(dest)):
                    return False

        self._grant_execute_permissions(map_key)

        return True

    def _grant_execute_permissions(self, map_key: str):
        mapping = self._MODEL_MAPPING[map_key]

        if sys.platform == 'linux':
            # Grant permission to executables
            for file in mapping.get('executables', []):
                p = self._get_file_path(file)
                if os.path.basename(p) in ('', '.'):
                    p = os.path.join(p, file)
                if not os.path.isfile(p):
                    raise InvalidModelMappingException(self._key, map_key, f'File "{file}" does not exist')
                if not os.access(p, os.X_OK):
                    os.chmod(p, os.stat(p).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

    async def reload(self, device: str, *args, **kwargs):
        await self.unload()
        await self.load(*args, **kwargs, device=device)

    async def load(self, device: str, *args, **kwargs):
        '''
        Loads models into memory. Has to be called before `infer`.
        '''
        if not self.is_downloaded():
            await self.download()
        if not self.is_loaded():
            await self._load(*args, **kwargs, device=device)
            self._loaded = True

    async def unload(self):
        if self.is_loaded():
            await self._unload()
            self._loaded = False

    async def infer(self, *args, **kwargs):
        '''
        Makes a forward pass through the network.
        '''
        if not self.is_loaded():
            raise Exception(f'{self._key}: Tried to run a forward pass without having loaded the model.')
        
        return await self._infer(*args, **kwargs)

    @abstractmethod
    async def _load(self, device: str, *args, **kwargs):
        pass

    @abstractmethod
    async def _unload(self):
        pass

    @abstractmethod
    async def _infer(self, *args, **kwargs):
        pass
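

# A minimal sketch of how a concrete inferer might extend ModelWrapper, assuming a
# hypothetical `MyModel` network and placeholder download data; the actual subclasses in
# this repository define their own `_MODEL_MAPPING`, `_load`, `_unload` and `_infer`.
#
#     class ExampleInferer(ModelWrapper):
#         _MODEL_SUB_DIR = 'example'
#         _MODEL_MAPPING = {
#             'example': {
#                 'url': 'https://example.com/models/example.ckpt',  # placeholder url
#                 'hash': 'abc123...',                                # placeholder sha256 digest
#                 'file': '.',                                        # keep the filename from the url
#             },
#         }
#
#         async def _load(self, device: str):
#             self.model = MyModel()  # hypothetical torch.nn.Module
#             self.model.load_state_dict(torch.load(self._get_file_path('example.ckpt'), map_location='cpu'))
#             self.model.to(device).eval()
#
#         async def _unload(self):
#             del self.model
#
#         async def _infer(self, image):
#             with torch.no_grad():
#                 return self.model(image)
#
# Typical call order: `await inferer.load(device)` (downloads missing files on demand),
# then `await inferer.infer(...)`, and finally `await inferer.unload()`.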