import gradio as gr
import numpy as np
import json
import os
from PIL import Image
import onnxruntime as rt
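
# The loader below expects a "signature.json" file next to the model file.
# Judging by the keys it reads, the layout appears to match Lobe's ONNX
# export. A minimal sketch of the assumed structure (illustrative values,
# not an authoritative schema):
#
#   {
#     "filename": "model.onnx",
#     "export_model_version": 1,
#     "inputs":  {"Image": {"name": "Image:0", "shape": [null, 224, 224, 3]}},
#     "outputs": {"Confidences": {"name": "Confidences:0"}},
#     "classes": {"Label": ["class_a", "class_b"]}
#   }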

# Signature version this app was written against; bump together with the model.
EXPORT_MODEL_VERSION = 1


class ONNXModel:
    def __init__(self, dir_path) -> None:
        """Load the model metadata; the ONNX session itself is created by load()."""
        model_dir = os.path.dirname(dir_path)
        with open(os.path.join(model_dir, "signature.json"), "r") as f:
            self.signature = json.load(f)
        self.model_file = os.path.join(model_dir, self.signature.get("filename"))
        if not os.path.isfile(self.model_file):
            raise FileNotFoundError("Model file does not exist.")
        self.signature_inputs = self.signature.get("inputs")
        self.signature_outputs = self.signature.get("outputs")
        if "Image" not in self.signature_inputs:
            raise ValueError("ONNX model must have an 'Image' input. Check signature.json.")
        # Warn (but do not abort) on an export version mismatch.
        version = self.signature.get("export_model_version")
        if version is None or version != EXPORT_MODEL_VERSION:
            print(f"Warning: Expected model version {EXPORT_MODEL_VERSION}, but found {version}.")
        self.session = None
    def load(self) -> None:
        """Load the ONNX model with execution providers."""
        self.session = rt.InferenceSession(self.model_file, providers=["CPUExecutionProvider"])
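    # A possible extension (untested sketch; requires the onnxruntime-gpu
    # package): prefer GPU execution when it is available, e.g.
    #
    #   available = rt.get_available_providers()
    #   providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    #                if p in available]
    #   self.session = rt.InferenceSession(self.model_file, providers=providers)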
    def predict(self, image: Image.Image) -> dict:
        """Process image and run ONNX model inference."""
        img = self.process_image(image, self.signature_inputs["Image"]["shape"])
        # Wrapping img in a list adds the batch dimension the model expects.
        feed = {self.signature_inputs["Image"]["name"]: [img]}
        output_names = [self.signature_outputs[key]["name"] for key in self.signature_outputs]
        outputs = self.session.run(output_names=output_names, input_feed=feed)
        return self.process_output(outputs)
    def process_image(self, image: Image.Image, input_shape: list) -> np.ndarray:
        """Center-crop to a square, resize to the model's input size, and normalize."""
        width, height = image.size
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Center-crop to the largest square that fits inside the image.
        square_size = min(width, height)
        left = (width - square_size) / 2
        top = (height - square_size) / 2
        right = (width + square_size) / 2
        bottom = (height + square_size) / 2
        image = image.crop((left, top, right, bottom))
        # Signature shapes are assumed NHWC: [batch, height, width, channels].
        input_height, input_width = input_shape[1:3]
        image = image.resize((input_width, input_height))
        # Scale pixel values to [0, 1].
        image = np.asarray(image) / 255.0
        return image.astype(np.float32)
    def process_output(self, outputs: list) -> dict:
        """Pair class labels with confidences and sort by descending confidence."""
        out_keys = ["label", "confidence"]
        # Strip the batch dimension from each output tensor.
        results = {key: outputs[i].tolist()[0] for i, key in enumerate(self.signature_outputs)}
        confs = results["Confidences"]
        labels = self.signature["classes"]["Label"]
        output = [dict(zip(out_keys, group)) for group in zip(labels, confs)]
        return {"predictions": sorted(output, key=lambda x: x["confidence"], reverse=True)}


model = ONNXModel(dir_path="model.onnx")
model.load()


def predict(image):
    """Run inference on the given image."""
    image = Image.fromarray(np.uint8(image), "RGB")
    prediction = model.predict(image)
    for output in prediction["predictions"]:
        output["confidence"] = round(output["confidence"], 4)
    return prediction
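
# Gradio's Image component hands predict a numpy array (height x width x 3,
# uint8 for an RGB upload), which is converted back to a PIL image above.
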
inputs = gr.Image(image_mode="RGB")
outputs = gr.JSON()
description = (
    "This is a web interface for the Naked Detector model. "
    "Upload an image and get predictions for the presence of nudity.\n"
    "Model and website created by KUO SUKO, C110156115 NKUST."
)
interface = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Naked Detector",
    description=description,
)
interface.launch()
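
# Once the app is running, it can also be queried programmatically, e.g. with
# gradio_client (a sketch; the URL and filename are placeholders):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict(handle_file("some_image.jpg")))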
# This was revised with ChatGPT's help; if it runs badly, don't blame me ><