VamooseBambel committed on
Commit 5307262 · verified · 1 Parent(s): 2ac128d

Create nsfw_detector.py

Files changed (1)
  1. nsfw_detector.py +42 -0
nsfw_detector.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ from torchvision import transforms
+ from transformers import AutoProcessor, FocalNetForImageClassification
+ import numpy as np
+
+ class NSFWDetector:
+     def __init__(self):
+         self.model_path = "TostAI/nsfw-image-detection-large"
+         self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
+         self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
+         self.model.eval()
+
+         self.transform = transforms.Compose([
+             transforms.Resize((512, 512)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+         self.label_to_category = {
+             "LABEL_0": "Safe",
+             "LABEL_1": "Questionable",
+             "LABEL_2": "Unsafe"
+         }
+
+     def check_image(self, image):
+         # Convert image to RGB if it isn't already
+         image = image.convert("RGB")
+
+         # Process image
+         inputs = self.feature_extractor(images=image, return_tensors="pt")
+
+         # Get prediction
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+             probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+             confidence, predicted = torch.max(probabilities, 1)
+
+         # Get the label
+         label = self.model.config.id2label[predicted.item()]
+         category = self.label_to_category.get(label, label)
+
+         return category != "Safe", category, confidence.item() * 100
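
For context, a minimal usage sketch of the NSFWDetector class added in this commit: check_image takes a PIL image and returns a flag, a category label, and a confidence percentage. The image path and the printed output format below are illustrative assumptions, not part of the commit.

# Usage sketch (assumes the file above is saved as nsfw_detector.py and that
# "example.jpg" exists locally; the path is a placeholder, not from the commit)
from PIL import Image
from nsfw_detector import NSFWDetector

detector = NSFWDetector()
image = Image.open("example.jpg")
is_nsfw, category, confidence = detector.check_image(image)
print(f"Category: {category}, confidence: {confidence:.1f}%, flagged: {is_nsfw}")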