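"""Custom handler for an AI-art detector served as an inference endpoint.

Follows the Hugging Face Inference Endpoints convention: handler.py exposes an
EndpointHandler class with __init__(path) and __call__(data). The model is a
timm EVA02 classifier fine-tuned for 3 classes (the AIArtDetector checkpoint below).
"""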
import base64
import io
from functools import partial

import requests
import timm
import torch
from PIL import Image
try:
from apex import amp
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
# try:
# from functorch.compile import memory_efficient_fusion
# has_functorch = True
# except ImportError as e:
# has_functorch = False
has_compile = hasattr(torch, 'compile')
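# has_apex / has_native_amp / has_compile mirror the capability probes in timm's
# reference inference script; only has_native_amp is actually checked below.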
# ImageFile.LOAD_TRUNCATED_IMAGES = True
###
class EndpointHandler:
    def __init__(self, path=""):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            # May sacrifice a bit of accuracy, depending on our needs
        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
        # float16 autocast on CUDA; bfloat16 is the safer dtype for CPU autocast.
        amp_dtype = torch.float16 if self.device.type == 'cuda' else torch.bfloat16
        # Keep the autocast context on self so __call__ can run inference under mixed precision.
        self.amp_autocast = partial(torch.autocast, device_type=self.device.type, dtype=amp_dtype)
# data_config = resolve_data_config(vars(args), model=model)
# self.aiGeneratorModel = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', num_classes=9, in_chans=3, checkpoint_path=path + 'AIModelDetector.pth-6ff3631e.pth')
self.aiArtModel = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', num_classes=3, in_chans=3, checkpoint_path=path + 'AIArtDetector.pth-af59f7fa.pth')
# self.aiGeneratorModel = self.aiGeneratorModel.to(self.device)
self.aiArtModel = self.aiArtModel.to(self.device)
# self.aiGeneratorModel.eval()
self.aiArtModel.eval()
self.transform = timm.data.create_transform(input_size=(3, 448, 448),
is_training=False,
use_prefetcher=False,
no_aug=False,
scale=None,
ratio=None,
hflip=0,
vflip=0.,
color_jitter=0,
auto_augment=None,
interpolation='bicubic',
# mean=(0.5, 0.5, 0.5),
# std=(0.5, 0.5, 0.5),
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
crop_pct=1.0,
# crop_mode='center',
crop_mode='squash',
tf_preprocessing=False,
separate=False)
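        # crop_pct=1.0 with crop_mode='squash' resizes every image to the full 448x448
        # input (distorting aspect ratio) rather than center-cropping, so no border
        # content is discarded before inference.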
# assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
# torch._dynamo.reset()
# model = torch.compile(model, backend=args.torchcompile)
self.supported_formats = ["JPEG", "PNG", "BMP", "TIFF", "WEBP", "RAW"] #GIF requires its own special implementation to get its frames
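        # NOTE: supported_formats is informational only; nothing below validates against it.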
print("initialized handler.py successfully")
# self.label_map = {0: 'Dall-E 2', 1: 'DiscoDiff', 2: 'Midjourney', 3: 'NightCafe', 4: 'NovelAI', 5: 'Stable Diffusion', 6: 'StarryAI', 7: 'WomboDream', 8: 'Artbreeder'}
    def __call__(self, data):
        """
        data args:
            inputs: List[str] - image URLs and/or base64 data URIs (max 50)
        Return:
            A :obj:`dict` keyed by input src; will be serialized and returned
        """
inputs = data.pop("inputs")
if len(inputs) > 50:
return {'error': 'Exceeds max limit of images (50)'}
image_paths = inputs #['https://google_image.png', '']
batch_size = 1 # Set your desired batch size
results = {}
for i in range(0, len(image_paths), batch_size): # For each batch
batch_paths = image_paths[i:i+batch_size]
validUrls = []
batch_images = []
for j, src in enumerate(batch_paths): # Get all valid images open and inputted in batch_images
try:
# Image.open(batch_paths[j]).load() # Tests if image is okay to run inference on.
pos = src.find("base64")
if pos != -1:
# Assuming base64_str is the string value without 'data:image/jpeg;base64,'
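                    # pos + 7 skips the literal "base64," (6 chars plus the comma) in a data URI.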
new = Image.open(io.BytesIO(base64.decodebytes(bytes(src[pos+7:], "utf-8")))).convert("RGB")
# new.load() Necessary? Does this catch any edge cases? Without this, we don't actually load the image pixels.
batch_images.append(new)
validUrls.append(src)
else:
try:
# r = requests.get(src, stream=True)
# r.raw.decode_content = True
# new = Image.open(r.raw).convert("RGB")
# new = Image.open(urlopen(src))
headers = {
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36'
}
                            r = requests.get(src, headers=headers, timeout=30)  # timeout keeps one dead URL from hanging the endpoint
new = Image.open(io.BytesIO(r.content)).convert("RGB")
# new.load()
batch_images.append(new)
validUrls.append(src)
except Exception as e:
results[src] = {'error': 'Failed to process image'}
# invalid_indices.append(j)
continue
# batch_images.append(batch_paths[j])
except Exception as e:
results[src] = {'error': 'Failed to process image w/ base64 in url'}
continue
# width, height = new.size
# if (width < 250 or height < 250) and len(request.data['srcs']) == 1:
# res['error'] = 'Please use a higher quality image'
# return JsonResponse(res, safe=False, status=status.HTTP_400_BAD_REQUEST)
            if not batch_images:
                continue  # every image in this batch failed to load; torch.stack([]) would raise
            batch_tensors = torch.stack([self.transform(img) for img in batch_images]).to(self.device)  # reuse the device resolved in __init__
# batch_tensors = torch.unsqueeze(batch_tensors, 0)
# batch_images = [Image.open(path) for path in batch_paths]
# batch_tensors = torch.stack([preprocess(img) for img in batch_images])
            with torch.no_grad(), self.amp_autocast():
                output1 = self.aiArtModel(batch_tensors)  # only aiArtModel is instantiated above; aiGeneratorModel is commented out
            for k, tensor in enumerate(output1):
                output = tensor.softmax(-1)
                output, indices = output.topk(3)  # the art model has num_classes=3, so topk(9) would fail
                # The 9-class generator label_map above is commented out and would not match
                # the 3-class art model, so raw class indices are returned instead.
                labels = indices.cpu().numpy().tolist()
                probabilities = [round(i * 100, 2) for i in output.cpu().numpy().tolist()]
                single_res = {'prob': probabilities, 'indices': labels}
                results[validUrls[k]] = single_res
return results
# handler = EndpointHandler()
# handler.__call__({'inputs': ['']})
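# Minimal local smoke test (not part of the endpoint contract); the URL is a
# hypothetical placeholder - substitute any reachable image URL or base64 data URI.
if __name__ == "__main__":
    handler = EndpointHandler()
    print(handler({'inputs': ['https://example.com/test.jpg']}))  # placeholder URL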