Sunday01's picture
up
9dce458
from typing import List
from PIL import Image
from .common import CommonUpscaler, OfflineUpscaler
from .waifu2x import Waifu2xUpscaler
from .esrgan import ESRGANUpscaler
from .esrgan_pytorch import ESRGANUpscalerPytorch
# Registry mapping each upscaler key (the string callers pass to
# get_upscaler / prepare / dispatch) to the class that implements it.
UPSCALERS = {
'waifu2x': Waifu2xUpscaler,
'esrgan': ESRGANUpscaler,
'4xultrasharp': ESRGANUpscalerPytorch,
}
# Instances already constructed by get_upscaler(), keyed by the same strings
# as UPSCALERS, so each upscaler class is instantiated at most once per key.
upscaler_cache = {}
def get_upscaler(key: str, *args, **kwargs) -> CommonUpscaler:
    """Return the shared upscaler instance registered under *key*.

    On the first request for a key, the class found in ``UPSCALERS`` is
    called with ``*args`` / ``**kwargs`` and the result is stored in
    ``upscaler_cache``. Later calls with the same key return the stored
    instance; any arguments passed on those calls are not used.

    Raises:
        ValueError: If *key* is not a key of ``UPSCALERS``.
    """
    if key not in UPSCALERS:
        # Build the message with a single f-string. The previous version
        # applied ``%`` formatting to an f-string that already contained the
        # caller-supplied key, so a key with a ``%`` in it broke the
        # formatting and raised the wrong error.
        choices = ','.join(UPSCALERS)
        raise ValueError(f'Could not find upscaler for: "{key}". Choose from the following: {choices}')
    # Membership test rather than truthiness: an instance that happens to
    # evaluate as falsy must still count as cached.
    if key not in upscaler_cache:
        upscaler_cache[key] = UPSCALERS[key](*args, **kwargs)
    return upscaler_cache[key]
async def prepare(upscaler_key: str):
    """Fetch the upscaler for *upscaler_key* and download it if needed.

    The instance is obtained through ``get_upscaler``. When it is an
    ``OfflineUpscaler``, its ``download()`` coroutine is awaited; for any
    other upscaler nothing further happens.
    """
    instance = get_upscaler(upscaler_key)
    if not isinstance(instance, OfflineUpscaler):
        return
    await instance.download()
async def dispatch(upscaler_key: str, image_batch: List[Image.Image], upscale_ratio: int, device: str = 'cpu') -> List[Image.Image]:
    """Upscale *image_batch* with the upscaler registered as *upscaler_key*.

    A ratio of 1 returns *image_batch* itself without looking up any
    upscaler. Otherwise the instance is obtained through ``get_upscaler``;
    if it is an ``OfflineUpscaler`` its ``load(device)`` is awaited first,
    and the result of awaiting ``upscale(image_batch, upscale_ratio)`` is
    returned.
    """
    if upscale_ratio != 1:
        engine = get_upscaler(upscaler_key)
        # Only offline upscalers are loaded onto the requested device here.
        if isinstance(engine, OfflineUpscaler):
            await engine.load(device)
        return await engine.upscale(image_batch, upscale_ratio)
    # Nothing to scale: hand the input back unchanged.
    return image_batch