|
from typing import List |
|
from PIL import Image |
|
|
|
from .common import CommonUpscaler, OfflineUpscaler |
|
from .waifu2x import Waifu2xUpscaler |
|
from .esrgan import ESRGANUpscaler |
|
from .esrgan_pytorch import ESRGANUpscalerPytorch |
|
|
|
# Registry of available upscalers: maps the key callers pass to
# get_upscaler()/dispatch() to the class that implements it. Insertion order
# matters only cosmetically: it is the order in which keys are listed in the
# "unknown upscaler" error message.
UPSCALERS = {
    'waifu2x': Waifu2xUpscaler,
    'esrgan': ESRGANUpscaler,
    '4xultrasharp': ESRGANUpscalerPytorch,
}

# Module-level cache of constructed upscaler instances, keyed like UPSCALERS,
# so each upscaler is instantiated at most once per process.
upscaler_cache = {}
|
|
|
def get_upscaler(key: str, *args, **kwargs) -> CommonUpscaler:
    """Return the shared upscaler instance registered under ``key``.

    The instance is constructed on the first request using ``*args`` and
    ``**kwargs`` and stored in ``upscaler_cache``. Later calls for the same
    key return that cached instance; any arguments passed to them are ignored.

    :param key: Name of the upscaler; must be a key of ``UPSCALERS``.
    :returns: The (possibly cached) upscaler instance for ``key``.
    :raises ValueError: If ``key`` is not a registered upscaler.
    """
    if key not in UPSCALERS:
        choices = ','.join(UPSCALERS)
        # A single f-string: mixing an f-string with ``%`` formatting would
        # re-parse ``key`` as a format string, so a key containing '%' raised
        # TypeError instead of this ValueError.
        raise ValueError(f'Could not find upscaler for: "{key}". Choose from the following: {choices}')
    # Compare against None instead of testing truthiness, so a valid cached
    # instance that happens to evaluate as falsy is not silently rebuilt.
    if upscaler_cache.get(key) is None:
        upscaler_cache[key] = UPSCALERS[key](*args, **kwargs)
    return upscaler_cache[key]
|
|
|
async def prepare(upscaler_key: str):
    """Run the download step for the named upscaler, if it has one.

    The upscaler is obtained through get_upscaler() (and thus cached). Only
    OfflineUpscaler instances have their ``download()`` coroutine awaited;
    any other upscaler needs no preparation here.
    """
    instance = get_upscaler(upscaler_key)
    if not isinstance(instance, OfflineUpscaler):
        return
    await instance.download()
|
|
|
async def dispatch(
    upscaler_key: str,
    image_batch: List[Image.Image],
    upscale_ratio: int,
    device: str = 'cpu',
) -> List[Image.Image]:
    """Upscale ``image_batch`` by ``upscale_ratio`` with the named upscaler.

    A ratio of 1 is a no-op: the caller's list object is returned as-is, and
    no upscaler is looked up or loaded. Otherwise the cached upscaler is
    fetched via get_upscaler(). For OfflineUpscaler instances,
    ``load(device)`` is awaited first. The awaited result of ``upscale`` is
    returned.
    """
    if upscale_ratio == 1:
        # Nothing to scale; hand the input list straight back.
        return image_batch

    model = get_upscaler(upscaler_key)
    needs_loading = isinstance(model, OfflineUpscaler)
    if needs_loading:
        await model.load(device)

    upscaled = await model.upscale(image_batch, upscale_ratio)
    return upscaled
|
|