File size: 1,538 Bytes
5307262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import numpy as np

class NSFWDetector:
    """Classify images as Safe / Questionable / Unsafe with a FocalNet model.

    The image processor and the classification model are both loaded with
    ``from_pretrained`` from the same ``model_path`` (a Hugging Face hub id
    or a local directory).
    """

    def __init__(self, model_path="TostAI/nsfw-image-detection-large"):
        """Load the processor and model and put the model in eval mode.

        Args:
            model_path: Model id or directory passed to ``from_pretrained``.
                Defaults to the previously hard-coded
                ``"TostAI/nsfw-image-detection-large"``.
        """
        self.model_path = model_path
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        # Inference only: switch off training-mode behaviour (e.g. dropout).
        self.model.eval()

        # NOTE: check_image() does NOT use this pipeline -- preprocessing is
        # done by ``self.feature_extractor``. It is kept only so code that
        # references ``detector.transform`` keeps working.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Maps the raw labels from ``model.config.id2label`` to readable
        # category names; labels not listed here are passed through as-is.
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe"
        }

    def check_image(self, image):
        """Classify a single image.

        Args:
            image: An object with a PIL-style ``convert(mode)`` method
                (presumably a ``PIL.Image.Image`` -- confirm with callers).
                It is converted to RGB before preprocessing.

        Returns:
            A tuple ``(is_flagged, category, confidence)``:
            ``category`` is the top-scoring label mapped through
            ``label_to_category`` (or the raw label if unmapped),
            ``is_flagged`` is ``True`` for any category other than ``"Safe"``,
            and ``confidence`` is that label's softmax probability expressed
            as a percentage (0-100).
        """
        # Normalise any input mode (RGBA, L, P, ...) to 3-channel RGB.
        image = image.convert("RGB")

        inputs = self.feature_extractor(images=image, return_tensors="pt")

        # Keep inputs on the model's device so that moving the model
        # (e.g. ``detector.model.to("cuda")``) does not cause a device
        # mismatch. On CPU this is a no-op.
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = {
                name: value.to(device) if isinstance(value, torch.Tensor) else value
                for name, value in inputs.items()
            }

        with torch.no_grad():
            logits = self.model(**inputs).logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            # Top class per batch row; a single image gives shape (1,).
            confidence, predicted = torch.max(probabilities, 1)

        label = self.model.config.id2label[predicted.item()]
        category = self.label_to_category.get(label, label)

        return category != "Safe", category, confidence.item() * 100