File size: 2,451 Bytes
8cb8f64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85cd204
8cb8f64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import math
import os
import random
from functools import lru_cache

import numpy as np
from PIL import Image
from huggingface_hub import HfFileSystem, HfApi
from imgutils.utils import open_onnx_model
from natsort import natsorted

# Optional Hub credentials read from the environment (None when HF_TOKEN is unset).
hf_token = os.getenv('HF_TOKEN')
hf_fs = HfFileSystem(token=hf_token)
hf_client = HfApi(token=hf_token)

REPOSITORY = 'mf666/shit-checker'
# Model names are the top-level ``*.onnx`` files of REPOSITORY with the
# directory prefix and the ``.onnx`` extension stripped, in natural order.
# NOTE: this listing is evaluated at import time via HfFileSystem.
MODELS = natsorted(
    os.path.splitext(os.path.relpath(onnx_path, REPOSITORY))[0]
    for onnx_path in hf_fs.glob(f'{REPOSITORY}/*.onnx')
)
DEFAULT_MODEL = 'mobilenet.xs.v2'


@lru_cache(maxsize=128)
def _open_model(model_name):
    """Fetch ``<model_name>.onnx`` from REPOSITORY and open it as an ONNX model.

    The result is memoised per ``model_name`` (up to 128 entries), so repeated
    calls with the same name reuse the already-opened model object.
    """
    onnx_file = hf_client.hf_hub_download(REPOSITORY, f'{model_name}.onnx')
    return open_onnx_model(onnx_file)


_DEFAULT_ORDER = 'HWC'


def _get_hwc_map(order_):
    return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper())


def _encode_channels(image, channels_order='CHW', is_float=True):
    """Convert ``image`` to an RGB numpy array laid out as ``channels_order``.

    :param image: object with a ``convert('RGB')`` method (a PIL image here).
    :param channels_order: target axis order over ``H``/``W``/``C``; the
        source array is HWC and is transposed accordingly.
    :param is_float: when true, scale pixel values by ``1/255`` into
        ``float32``; otherwise return the raw ``uint8`` array.
    :return: the transposed (and optionally rescaled) array.
    """
    rgb_hwc = np.asarray(image.convert('RGB'))
    arranged = rgb_hwc.transpose(_get_hwc_map(channels_order))
    if is_float:
        scaled = (arranged / 255.0).astype(np.float32)
        assert scaled.dtype == np.float32
        return scaled

    # Raw path: an RGB conversion is expected to yield 8-bit channels.
    assert arranged.dtype == np.uint8
    return arranged


def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """Resize ``image`` and encode it as a normalised CHW ``float32`` array.

    :param image: PIL image to encode.
    :param size: target ``(width, height)`` passed to ``Image.resize``
        (bilinear resampling).
    :param normalize: ``(mean, std)`` applied as ``(x - mean) / std`` to the
        ``[0, 1]``-scaled pixels; each may be a scalar (shared by every
        channel) or a per-channel sequence, reshaped to ``(C, 1, 1)`` for
        broadcasting. The default ``(0.5, 0.5)`` maps pixels to ``[-1, 1]``.
        Pass ``None`` to skip normalisation.
    :return: ``float32`` array in CHW order.
    """
    resized = image.resize(size, Image.BILINEAR)
    chw = _encode_channels(resized, channels_order='CHW')

    if normalize is None:
        return chw.astype(np.float32)

    mean_value, std_value = normalize
    mean = np.asarray([mean_value]).reshape((-1, 1, 1))
    std = np.asarray([std_value]).reshape((-1, 1, 1))
    return ((chw - mean) / std).astype(np.float32)


def _raw_predict(images, model_name=DEFAULT_MODEL):
    """Run the named ONNX model on ``images`` and average the outputs.

    Each image is converted to RGB, encoded with ``_img_encode`` and stacked
    into a single batch fed to the model's ``'input'``. The ``'output'``
    tensor is averaged over axis 0 (presumably the batch axis of an
    ``(N, num_classes)`` score array — TODO confirm against the model).
    ``images`` must be non-empty, otherwise ``np.stack`` raises.
    """
    batch = np.stack([_img_encode(img.convert('RGB')) for img in images])
    (output,) = _open_model(model_name).run(['output'], {'input': batch})
    return output.mean(axis=0)


def predict(image, model_name=DEFAULT_MODEL, max_batch_size=8, rng=None):
    """Score ``image`` by averaging model outputs over random 384x384 crops.

    :param image: PIL image to classify.
    :param model_name: name of the ONNX model in REPOSITORY to use.
    :param max_batch_size: upper bound on the number of crops evaluated.
    :param rng: optional random source with a ``randint(a, b)`` method, e.g.
        ``random.Random(seed)``, for reproducible crop positions. Defaults to
        the module-level ``random`` generator (the previous behaviour).
    :return: ``{'shat': float, 'normal': float}`` built from the first two
        averaged model scores, in that order.
    """
    rng = random if rng is None else rng
    crop = 384
    width, height = image.width, image.height

    # Roughly one crop per crop-sized tile of area (rounded up) plus one,
    # capped at max_batch_size but never fewer than one crop.
    tiles = math.ceil(width * height / (crop * crop))
    batch_size = int(max(min(tiles + 1, max_batch_size), 1))

    blocks = []
    for _ in range(batch_size):
        # Pick a top-left corner so the crop stays inside the image; images
        # smaller than `crop` on an axis are taken whole along that axis.
        x0 = rng.randint(0, max(0, width - crop))
        y0 = rng.randint(0, max(0, height - crop))
        x1 = min(x0 + crop, width)
        y1 = min(y0 + crop, height)
        blocks.append(image.crop((x0, y0, x1, y1)))

    scores = _raw_predict(blocks, model_name)
    # zip() stops at the shorter input: extra model scores beyond two are ignored.
    return dict(zip(['shat', 'normal'], (score.item() for score in scores)))