|
|
|
import timm |
|
import torch |
|
from PIL import Image |
|
from timm.utils import ParseKwargs |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT |
|
|
|
|
|
|
|
import os |
|
import time |
|
from contextlib import suppress |
|
from functools import partial |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
|
|
from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset |
|
from timm.layers import apply_test_time_pool |
|
from timm.models import create_model |
|
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs |
|
|
|
# Optional NVIDIA Apex mixed-precision support; ``amp`` is bound only when Apex imports.
try:
    from apex import amp
except ImportError:
    has_apex = False
else:
    has_apex = True

# "Native AMP" is reported when ``torch.cuda.amp`` exposes a non-None ``autocast``.
try:
    has_native_amp = getattr(torch.cuda.amp, 'autocast') is not None
except AttributeError:
    has_native_amp = False

# ``torch.compile`` is only present on newer PyTorch releases.
has_compile = hasattr(torch, 'compile')
|
|
|
import PIL |
|
import requests |
|
import io |
|
import base64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
class EndpointHandler():
    """Inference Endpoint handler for the EVA-02 AI-art detector.

    On construction it loads a timm ``eva02_base_patch14_448`` model with a
    3-class head from ``AIArtDetector.pth-af59f7fa.pth`` and builds the eval
    transform. Each call classifies a list of images given as URLs or
    base64 payloads and returns per-image top-class probabilities.
    """

    # Maximum number of images accepted in one request.
    max_images = 50
    # Number of images sent through the model per forward pass.
    batch_size = 1
    # Seconds to wait on a remote image before treating it as failed.
    request_timeout = 30
    # Browser-like user agent, presumably so hosts that block non-browser
    # clients still serve the image.
    _headers = {
        'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36'
    }

    def __init__(self, path=""):
        """Load the detector weights and build the evaluation transform.

        Args:
            path: Directory that contains ``AIArtDetector.pth-af59f7fa.pth``.
                It is joined with ``os.path.join``, so it works with or without
                a trailing separator. ``""`` means the current directory.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if torch.cuda.is_available():
            # Allow TF32 matmuls and let cuDNN autotune its algorithms.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

        num_classes = 3
        self.aiArtModel = timm.create_model(
            'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k',
            num_classes=num_classes,
            in_chans=3,
            checkpoint_path=os.path.join(path, 'AIArtDetector.pth-af59f7fa.pth'),
        )
        self.aiArtModel = self.aiArtModel.to(self.device)
        self.aiArtModel.eval()

        # Class index -> value reported under 'indices'. The real class names
        # are not defined in this file, so indices are reported unchanged.
        # TODO(review): replace with the detector's actual class names.
        self.label_map = {idx: idx for idx in range(num_classes)}

        # Eval-only preprocessing. crop_mode='squash' with crop_pct=1.0 resizes
        # straight to 448x448 (bicubic) instead of doing an aspect-preserving crop.
        self.transform = timm.data.create_transform(input_size=(3, 448, 448),
                                                    is_training=False,
                                                    use_prefetcher=False,
                                                    no_aug=False,
                                                    scale=None,
                                                    ratio=None,
                                                    hflip=0,
                                                    vflip=0.,
                                                    color_jitter=0,
                                                    auto_augment=None,
                                                    interpolation='bicubic',
                                                    re_prob=0.,
                                                    re_mode='const',
                                                    re_count=1,
                                                    re_num_splits=0,
                                                    crop_pct=1.0,
                                                    crop_mode='squash',
                                                    tf_preprocessing=False,
                                                    separate=False)

        self.supported_formats = ["JPEG", "PNG", "BMP", "TIFF", "WEBP", "RAW"]
        print("initialized handler.py successfully")

    def __call__(self, data):
        """Classify every image referenced in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` (popped from the dict)
                holds image sources, either as a list or as a single string.
                A source containing ``base64`` is decoded from the text that
                starts 7 characters after that marker, e.g.
                ``data:image/png;base64,<payload>``. Any other source is
                fetched as a URL.

        Returns:
            A dict that maps each source either to
            ``{'prob': [...], 'indices': [...]}`` or to ``{'error': msg}``.
            The ``prob`` list holds probabilities in percent, rounded to 2
            decimals, highest first. If ``inputs`` is missing, or holds more
            than ``max_images`` entries, the result is a single
            ``{'error': msg}`` dict instead.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            return {'error': "Missing 'inputs'"}
        if isinstance(inputs, str):
            # A bare string would otherwise be iterated character by character.
            inputs = [inputs]
        if len(inputs) > self.max_images:
            return {'error': f'Exceeds max limit of images ({self.max_images})'}

        results = {}
        for start in range(0, len(inputs), self.batch_size):
            valid_srcs = []
            images = []
            for src in inputs[start:start + self.batch_size]:
                image, error = self._load_image(src)
                if error is not None:
                    results[src] = {'error': error}
                    continue
                valid_srcs.append(src)
                images.append(image)

            if not images:
                # Every source in this batch failed; torch.stack rejects an empty list.
                continue

            for src, prediction in zip(valid_srcs, self._predict(images)):
                results[src] = prediction

        return results

    def _load_image(self, src):
        """Return ``(image, None)`` on success or ``(None, error_message)`` on failure."""
        if not isinstance(src, str):
            return None, 'Failed to process image'
        pos = src.find("base64")
        if pos != -1:
            try:
                # Skip "base64" plus one separator character (usually ",").
                return self._decode_base64_image(src[pos + 7:]), None
            except Exception:
                return None, 'Failed to process image w/ base64 in url'
        try:
            return self._download_image(src), None
        except Exception:
            return None, 'Failed to process image'

    @staticmethod
    def _decode_base64_image(payload):
        """Decode a base64 text payload into an RGB PIL image."""
        raw = base64.decodebytes(payload.encode("utf-8"))
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _download_image(self, url):
        """Fetch ``url`` and decode the response body into an RGB PIL image.

        Raises on network errors, timeouts and HTTP error statuses, so an
        error page is never passed to the image decoder.
        """
        response = requests.get(url, headers=self._headers, timeout=self.request_timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    def _predict(self, images):
        """Run the detector on ``images`` and return one result dict per image, in order."""
        batch = torch.stack([self.transform(img) for img in images]).to(self.device)
        with torch.no_grad():
            logits = self.aiArtModel(batch)

        predictions = []
        for row in logits:
            probs = row.softmax(-1)
            # torch.topk raises if k exceeds the number of classes, so clamp k.
            top_probs, top_idx = probs.topk(min(9, probs.shape[-1]))
            labels = [self.label_map[idx] for idx in top_idx.cpu().tolist()]
            percents = [round(p * 100, 2) for p in top_probs.cpu().tolist()]
            predictions.append({'prob': percents, 'indices': labels})
        return predictions
|
|
|
|
|
|
|
|