File size: 2,093 Bytes
9dce458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from PIL import Image
from typing import List
from abc import abstractmethod

from ..utils import InfererModule, ModelWrapper

class CommonUpscaler(InfererModule):
    """Base class for upscalers whose backend supports only fixed scale factors.

    Subclasses list the factors their backend handles natively in
    ``_VALID_UPSCALE_RATIOS`` and implement ``_upscale``. ``upscale`` reaches
    an arbitrary target ratio by chaining supported factors (which compose
    multiplicatively) and, if the chain overshoots, shrinking the result with
    a plain PIL resize.
    """

    # Scale factors the backend supports natively; every entry must be > 1.
    # Subclasses override this. It is only read here, never mutated.
    _VALID_UPSCALE_RATIOS = []

    # Absolute slack used when comparing float ratios, so values such as
    # 1.0000000000000002 produced by division are treated as exactly 1.
    _RATIO_TOLERANCE = 1e-6

    async def upscale(self, image_batch: List[Image.Image], upscale_ratio: float) -> List[Image.Image]:
        """Scale every image in ``image_batch`` by ``upscale_ratio``.

        Each backend pass uses the smallest supported factor that covers the
        remaining ratio, or the largest supported factor if none does. After
        each pass the remaining ratio is divided by the factor just applied.
        Once it drops to <= 1, any overshoot is undone with a PIL resize.
        At least one backend pass is made whenever ``upscale_ratio != 1``.

        Args:
            image_batch: Images to upscale.
            upscale_ratio: Overall factor to scale each image by; must be > 0.

        Returns:
            The scaled images (``image_batch`` itself when the ratio is 1).

        Raises:
            ValueError: If ``upscale_ratio`` is not positive, or if
                ``_VALID_UPSCALE_RATIOS`` is empty or contains a value <= 1.
        """
        if upscale_ratio == 1:
            return image_batch
        if upscale_ratio <= 0:
            raise ValueError(f'upscale_ratio must be positive, got {upscale_ratio}')

        # Sort a copy: sorting the class attribute in place would mutate state
        # shared by every instance of the subclass.
        valid_ratios = sorted(self._VALID_UPSCALE_RATIOS)
        if not valid_ratios or valid_ratios[0] <= 1:
            # A ratio <= 1 would never shrink ratio_left, so the loop could not end.
            raise ValueError(
                f'_VALID_UPSCALE_RATIOS must be non-empty with all values > 1, got {valid_ratios}'
            )
        max_ratio = valid_ratios[-1]
        tolerance = self._RATIO_TOLERANCE

        # Scale factors multiply, so track what is still needed as a quotient.
        ratio_left = upscale_ratio
        while True:
            ratio = next(
                (valid for valid in valid_ratios if ratio_left <= valid + tolerance),
                max_ratio,
            )
            ratio_left /= ratio
            if upscale_ratio > max_ratio:
                # Only multi-pass upscales are worth logging step by step.
                self.logger.info(f'Upscaling image by {ratio}; left: {ratio_left}')
            image_batch = await self._upscale(image_batch, ratio)
            if ratio_left <= 1 + tolerance:
                break

        if ratio_left < 1 - tolerance:
            # The last pass overshot the target; shrink back to the requested size.
            downscale_ratio = ratio_left
            self.logger.info(f'Downscaling image by {downscale_ratio} to correct upscale ratio')
            image_batch = [
                image.resize((
                    max(1, round(image.size[0] * downscale_ratio)),
                    max(1, round(image.size[1] * downscale_ratio)),
                ))
                for image in image_batch
            ]
        return image_batch

    @abstractmethod
    async def _upscale(self, image_batch: List[Image.Image], upscale_ratio: float) -> List[Image.Image]:
        """Run one backend pass scaling ``image_batch`` by ``upscale_ratio``.

        ``upscale_ratio`` is always one of ``_VALID_UPSCALE_RATIOS`` when
        called from ``upscale``.
        """
        pass

class OfflineUpscaler(CommonUpscaler, ModelWrapper):
    """Upscaler whose backend is a locally stored model.

    Each backend pass requested through ``CommonUpscaler._upscale`` is
    forwarded unchanged to ``self.infer``, presumably provided by
    ``ModelWrapper`` and presumably dispatching to ``_infer``; confirm against
    ``ModelWrapper``. Concrete subclasses implement ``_infer``.
    """

    # Model sub-directory name handed to ModelWrapper (assumed to locate this
    # module's model files; TODO confirm against ModelWrapper).
    _MODEL_SUB_DIR = 'upscaling'

    async def _upscale(self, *positional, **keyword):
        # Pure pass-through: every argument goes to infer exactly as received.
        upscaled = await self.infer(*positional, **keyword)
        return upscaled

    @abstractmethod
    async def _infer(self, image_batch: List[Image.Image], upscale_ratio: float) -> List[Image.Image]:
        """Upscale the images with the underlying model.

        Args:
            image_batch: Images to be upscaled.
            upscale_ratio: Factor by which each image should be enlarged.

        Returns:
            The upscaled images, as a list.
        """