VamooseBambel commited on
Commit
6a9d129
·
verified ·
1 Parent(s): 4afe8be

Create nsfw_detector.py

Browse files
Files changed (1) hide show
  1. nsfw_detector.py +69 -0
nsfw_detector.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from transformers import AutoProcessor, FocalNetForImageClassification
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import numpy as np
6
+
7
+ class NSFWDetector:
8
+ def __init__(self):
9
+ self.model_path = "MichalMlodawski/nsfw-image-detection-large"
10
+ self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
11
+ self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
12
+ self.model.eval()
13
+
14
+ self.transform = transforms.Compose([
15
+ transforms.Resize((512, 512)),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
+ ])
19
+
20
+ self.label_to_category = {
21
+ "LABEL_0": "Safe",
22
+ "LABEL_1": "Questionable",
23
+ "LABEL_2": "Unsafe"
24
+ }
25
+
26
+ def check_image(self, image):
27
+ # Convert image to RGB if it isn't already
28
+ image = image.convert("RGB")
29
+
30
+ # Process image
31
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
32
+
33
+ # Get prediction
34
+ with torch.no_grad():
35
+ outputs = self.model(**inputs)
36
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
37
+ confidence, predicted = torch.max(probabilities, 1)
38
+
39
+ # Get the label
40
+ label = self.model.config.id2label[predicted.item()]
41
+ category = self.label_to_category.get(label, label)
42
+
43
+ return category != "Safe", category, confidence.item() * 100
44
+
45
+ def create_error_image(message="NSFW Content Detected"):
46
+ # Create a black image
47
+ img = Image.new('RGB', (512, 512), color='black')
48
+ draw = ImageDraw.Draw(img)
49
+
50
+ # Use default font
51
+ try:
52
+ # Try to get a default system font
53
+ font = ImageFont.load_default()
54
+
55
+ # Calculate text position to center it
56
+ text_bbox = draw.textbbox((0, 0), message, font=font)
57
+ text_width = text_bbox[2] - text_bbox[0]
58
+ text_height = text_bbox[3] - text_bbox[1]
59
+
60
+ x = (512 - text_width) // 2
61
+ y = (512 - text_height) // 2
62
+
63
+ # Draw white text
64
+ draw.text((x, y), message, fill='white', font=font)
65
+
66
+ except Exception as e:
67
+ print(f"Error adding text to image: {e}")
68
+
69
+ return img