# Sana-1.6B / nsfw_detector.py
# Created by VamooseBambel ("Create nsfw_detector.py", commit 6a9d129, verified; 2.38 kB).
import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
from PIL import Image, ImageDraw, ImageFont
import numpy as np
class NSFWDetector:
    """Classify images as "Safe", "Questionable" or "Unsafe" with a FocalNet model.

    The image processor and classifier are loaded with ``transformers`` when the
    detector is constructed.
    """

    def __init__(self, model_path="MichalMlodawski/nsfw-image-detection-large"):
        """Load the image processor and classification model.

        Args:
            model_path: Hugging Face model id or local directory passed to
                ``from_pretrained``. Defaults to the checkpoint this detector
                was originally written for.
        """
        self.model_path = model_path
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        # Switch the model to evaluation mode for inference.
        self.model.eval()
        # NOTE(review): check_image() does not use this pipeline (it relies on
        # feature_extractor instead); kept as-is in case external callers use it.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Maps the raw id2label names of the checkpoint to readable categories;
        # labels missing from this table are passed through unchanged.
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe"
        }

    def check_image(self, image):
        """Classify a single image.

        Args:
            image: a PIL image in any mode; it is converted to RGB first.

        Returns:
            A tuple ``(is_nsfw, category, confidence)``:
            ``is_nsfw`` is True for every category other than "Safe";
            ``category`` is the mapped label (or the raw model label when it
            has no mapping); ``confidence`` is the winning class probability
            expressed as a percentage (0-100).
        """
        image = image.convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        # Keep the inputs on the same device as the model so inference still
        # works after a caller moves self.model (e.g. to a GPU).
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        # Single-image batch: .item() below assumes exactly one prediction.
        confidence, predicted = torch.max(probabilities, dim=1)
        label = self.model.config.id2label[predicted.item()]
        category = self.label_to_category.get(label, label)
        return category != "Safe", category, confidence.item() * 100
def create_error_image(message="NSFW Content Detected", size=(512, 512)):
    """Build a placeholder image showing ``message`` in white on black.

    Args:
        message: text to draw, centered on the image.
        size: ``(width, height)`` of the image in pixels; defaults to 512x512.

    Returns:
        A new RGB PIL image. If drawing the text fails, the error is printed
        and the plain black image is returned instead.
    """
    width, height = size
    img = Image.new('RGB', (width, height), color='black')
    draw = ImageDraw.Draw(img)
    try:
        # Pillow's built-in default font, so no font file on disk is needed.
        font = ImageFont.load_default()
        # textbbox can report a non-zero left/top offset relative to the draw
        # origin; subtract it so the visible glyphs are actually centered.
        left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
        x = (width - (right - left)) // 2 - left
        y = (height - (bottom - top)) // 2 - top
        draw.text((x, y), message, fill='white', font=font)
    except Exception as e:
        # Best effort: the text is optional and the black image is still usable.
        print(f"Error adding text to image: {e}")
    return img