Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,381 Bytes
6a9d129 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
from PIL import Image, ImageDraw, ImageFont
import numpy as np
class NSFWDetector:
    """Three-way image content classifier backed by a FocalNet checkpoint.

    Loads ``MichalMlodawski/nsfw-image-detection-large`` via Hugging Face
    Transformers and reports whether an image falls into any category other
    than "Safe".
    """

    def __init__(self):
        # Checkpoint identifier; both the processor and the model come from it.
        self.model_path = "MichalMlodawski/nsfw-image-detection-large"

        # The processor performs all preprocessing used by ``check_image``.
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)

        # Inference only: switch the network to evaluation mode.
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        self.model.eval()

        # NOTE(review): this torchvision pipeline is not used anywhere in this
        # class -- ``check_image`` preprocesses through ``feature_extractor``.
        # Kept in case code elsewhere reads ``detector.transform``.
        self.transform = transforms.Compose(
            [
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

        # Raw config labels -> human-readable category names. Labels missing
        # from this table are passed through unchanged by ``check_image``.
        self.label_to_category = dict(
            LABEL_0="Safe",
            LABEL_1="Questionable",
            LABEL_2="Unsafe",
        )

    def check_image(self, image):
        """Classify *image* and report whether it should be treated as NSFW.

        Args:
            image: An object with a PIL-style ``convert`` method (presumably a
                ``PIL.Image.Image``); it is converted to RGB before processing.

        Returns:
            A tuple ``(is_nsfw, category, confidence)``:
            ``is_nsfw`` is True for any category other than "Safe";
            ``category`` is the (mapped) label of the top-scoring class;
            ``confidence`` is that class's softmax probability times 100.
        """
        rgb = image.convert("RGB")
        batch = self.feature_extractor(images=rgb, return_tensors="pt")

        # No autograd bookkeeping is needed for a forward pass at inference.
        with torch.no_grad():
            logits = self.model(**batch).logits

        probs = logits.softmax(dim=-1)
        top_prob, top_idx = probs.max(dim=1)

        raw_label = self.model.config.id2label[top_idx.item()]
        category = self.label_to_category.get(raw_label, raw_label)
        is_nsfw = category != "Safe"
        return is_nsfw, category, top_prob.item() * 100
def create_error_image(message="NSFW Content Detected", size=(512, 512)):
    """Return a black placeholder image with *message* drawn in white at its centre.

    Judging by the default message, this is meant as a stand-in for an image
    rejected by the NSFW check.

    Args:
        message: Text to render on the placeholder.
        size: ``(width, height)`` of the returned image in pixels, or a single
            int for a square image. Defaults to 512x512, the size this function
            always produced before the parameter existed.

    Returns:
        A new RGB ``PIL.Image.Image``. If rendering the text fails for any
        reason, the error is printed and the plain black image is returned,
        so callers always get an image back.
    """
    if isinstance(size, int):
        size = (size, size)
    width, height = size

    img = Image.new('RGB', (width, height), color='black')
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.load_default()
        # textbbox() gives the ink box for text drawn at (0, 0). Its left/top
        # edges can be non-zero (glyph bearing / ascender padding), so they are
        # subtracted to centre the visible text rather than the draw origin.
        left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
        x = (width - (right - left)) // 2 - left
        y = (height - (bottom - top)) // 2 - top
        draw.text((x, y), message, fill='white', font=font)
    except Exception as e:
        # Deliberately broad: the text is cosmetic, and a failure here (e.g. a
        # Pillow version without ImageDraw.textbbox) should not prevent
        # returning the placeholder image.
        print(f"Error adding text to image: {e}")

    return img