# eva02-ai-art-detector / handler.py
# Last commit by Noah-Wang: "Update handler.py" (6737fe8, verified)
import timm
import torch
from PIL import Image
from timm.utils import ParseKwargs
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
###
import os
import time
from contextlib import suppress
from functools import partial
import numpy as np
import pandas as pd
import torch
from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
# Optional NVIDIA APEX mixed-precision support; absence is not an error.
try:
    from apex import amp
except ImportError:
    has_apex = False
else:
    has_apex = True

# Native AMP is considered available when torch.cuda.amp exposes `autocast`.
try:
    has_native_amp = torch.cuda.amp.autocast is not None
except AttributeError:
    has_native_amp = False

# torch.compile() only exists on PyTorch 2.x.
has_compile = hasattr(torch, 'compile')
import PIL
import requests
import io
import base64
# ImageFile.LOAD_TRUNCATED_IMAGES = True
###
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the EVA-02 AI-art detector.

    Loads a 3-class ``eva02_base_patch14_448`` timm model from the checkpoint in
    ``path`` and, on each call, classifies images given as http(s) URLs or
    base64-encoded strings (e.g. data URIs).
    """

    # Requests with more images than this are rejected outright.
    MAX_IMAGES = 50
    # Number of images pushed through the model per forward pass.
    BATCH_SIZE = 1
    # Seconds to wait for a remote image; requests has no timeout by default.
    REQUEST_TIMEOUT = 10
    # Browser-like User-Agent, presumably because some image hosts reject the
    # default python-requests UA.
    REQUEST_HEADERS = {
        'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36'
    }

    def __init__(self, path=""):
        """Load the model and build the eval-time transform.

        Args:
            path: Directory that contains ``AIArtDetector.pth-af59f7fa.pth``.
                The default ``""`` means the current working directory.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            # TF32 matmuls and cuDNN autotuning: faster, may sacrifice a bit of accuracy.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

        # os.path.join works whether or not `path` ends with a separator
        # (plain concatenation produced e.g. "/repositoryAIArtDetector...").
        self.aiArtModel = timm.create_model(
            'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k',
            num_classes=3,
            in_chans=3,
            checkpoint_path=os.path.join(path, 'AIArtDetector.pth-af59f7fa.pth'),
        )
        self.aiArtModel = self.aiArtModel.to(self.device)
        self.aiArtModel.eval()

        # crop_pct=1.0 with crop_mode='squash' is meant to resize the whole image
        # to 448x448 instead of center-cropping (see timm create_transform).
        self.transform = timm.data.create_transform(input_size=(3, 448, 448),
                                                    is_training=False,
                                                    use_prefetcher=False,
                                                    no_aug=False,
                                                    scale=None,
                                                    ratio=None,
                                                    hflip=0,
                                                    vflip=0.,
                                                    color_jitter=0,
                                                    auto_augment=None,
                                                    interpolation='bicubic',
                                                    re_prob=0.,
                                                    re_mode='const',
                                                    re_count=1,
                                                    re_num_splits=0,
                                                    crop_pct=1.0,
                                                    crop_mode='squash',
                                                    tf_preprocessing=False,
                                                    separate=False)

        # Not consulted by __call__; GIF would need its own frame handling.
        self.supported_formats = ["JPEG", "PNG", "BMP", "TIFF", "WEBP", "RAW"]
        # NOTE(review): class names for this 3-class model are not defined in this
        # file, so predictions report raw class indices until a mapping is filled in.
        self.label_map = {}
        print("initialized handler.py successfully")

    @staticmethod
    def _is_base64(src):
        """Return True when ``src`` should be decoded inline rather than fetched.

        Any string mentioning "base64" qualifies, except http(s) URLs. Those are
        always fetched, because a URL path may legitimately contain "base64".
        """
        return "base64" in src and not src.lower().startswith(("http://", "https://"))

    def _load_image(self, src):
        """Turn one input source into an RGB PIL image.

        Returns ``(image, None)`` on success or ``(None, error_message)`` on
        failure, so one bad source never aborts the whole request.
        """
        if not isinstance(src, str):
            return None, 'Failed to process image'
        if self._is_base64(src):
            try:
                # Skip "base64" plus its one-char separator (the comma in a data URI).
                payload = src[src.find("base64") + 7:]
                raw = base64.decodebytes(bytes(payload, "utf-8"))
                # convert() makes PIL decode the pixel data here, inside the try.
                image = Image.open(io.BytesIO(raw)).convert("RGB")
            except Exception:  # best-effort per image: any decode failure is reported, not raised
                return None, 'Failed to process image w/ base64 in url'
            return image, None
        try:
            response = requests.get(src, headers=self.REQUEST_HEADERS, timeout=self.REQUEST_TIMEOUT)
            # Don't try to parse an HTTP error page as an image.
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception:  # network, HTTP and decode errors are all reported per image
            return None, 'Failed to process image'
        return image, None

    def _predict(self, batch_tensors):
        """Run the model on a stacked batch; return one result dict per row.

        Each dict holds every class ranked by descending probability:
        ``prob`` (percentages rounded to 2 decimals) and ``indices`` (labels from
        ``self.label_map``, falling back to the raw class index).
        """
        label_map = self.label_map or {}
        with torch.no_grad():
            logits = self.aiArtModel(batch_tensors)
        predictions = []
        for row in logits:
            probs = row.softmax(-1)
            # topk over all classes == full descending ranking, whatever the class count.
            top_probs, top_idx = probs.topk(probs.shape[-1])
            labels = [label_map.get(idx, idx) for idx in top_idx.cpu().tolist()]
            probabilities = [round(p * 100, 2) for p in top_probs.cpu().tolist()]
            predictions.append({'prob': probabilities, 'indices': labels})
        return predictions

    def __call__(self, data):
        """Classify the images listed in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is a list of sources (or a
                single source string). Each source is an http(s) URL or a
                base64 string. At most ``MAX_IMAGES`` sources are accepted.

        Returns:
            dict: ``{'error': msg}`` when the request as a whole is invalid.
            Otherwise it maps each source to either ``{'prob': [...], 'indices': [...]}``
            or ``{'error': msg}`` if that source could not be loaded.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            return {'error': "Missing 'inputs' field"}
        if isinstance(inputs, str):
            # A bare string is one source, not a sequence of characters.
            inputs = [inputs]
        if len(inputs) > self.MAX_IMAGES:
            return {'error': f'Exceeds max limit of images ({self.MAX_IMAGES})'}

        results = {}
        for start in range(0, len(inputs), self.BATCH_SIZE):
            valid_srcs = []
            images = []
            for src in inputs[start:start + self.BATCH_SIZE]:
                image, error = self._load_image(src)
                if error is not None:
                    results[src] = {'error': error}
                    continue
                images.append(image)
                valid_srcs.append(src)
            if not images:
                # Every source in this batch failed; torch.stack([]) would raise.
                continue
            batch_tensors = torch.stack([self.transform(img) for img in images]).to(self.device)
            for src, prediction in zip(valid_srcs, self._predict(batch_tensors)):
                results[src] = prediction
        return results
# handler = EndpointHandler()
# handler.__call__({'inputs': ['']})