# SDXL-Turbo / nsfw_detector.py
import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import numpy as np
class NSFWDetector:
    def __init__(self):
        self.model_path = "TostAI/nsfw-image-detection-large"
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        self.model.eval()

        # Note: this transform is defined but not used in check_image, which
        # relies on the AutoProcessor for preprocessing instead.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Map the model's raw output labels to human-readable categories
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe"
        }
    def check_image(self, image):
        """Classify a PIL image; returns (is_flagged, category, confidence_pct)."""
        # Convert image to RGB if it isn't already
        image = image.convert("RGB")

        # Preprocess the image with the model's processor
        inputs = self.feature_extractor(images=image, return_tensors="pt")

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted = torch.max(probabilities, 1)

        # Map the predicted class index to its label, then to a category name
        label = self.model.config.id2label[predicted.item()]
        category = self.label_to_category.get(label, label)

        return category != "Safe", category, confidence.item() * 100
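

# --- Minimal usage sketch (not part of the original file; assumes a local
# image at the hypothetical path "example.png") ---
if __name__ == "__main__":
    from PIL import Image

    detector = NSFWDetector()
    img = Image.open("example.png")  # placeholder path, replace with a real image
    is_flagged, category, confidence = detector.check_image(img)
    print(f"flagged={is_flagged} category={category} confidence={confidence:.1f}%")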