import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
from PIL import Image, ImageDraw, ImageFont
import numpy as np

class NSFWDetector:
    """Flag images as Safe, Questionable or Unsafe using a FocalNet classifier."""

    def __init__(self):
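        self.model_path = "MichalMlodawski/nsfw-image-detection-large"
        # Load the checkpoint's own preprocessor and the FocalNet classifier
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        # Inference mode: disables dropout and other training-only behavior
        self.model.eval()

        # Manual torchvision preprocessing (512x512 resize, ImageNet mean/std).
        # Note: check_image() does not use this; it relies on the AutoProcessor.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])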
        
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe"
        }

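    def check_image(self, image):
        """Classify a PIL image.

        Returns a tuple (is_nsfw, category, confidence): is_nsfw is True for
        any category other than "Safe", and confidence is the top-class
        probability expressed as a percentage (0-100).
        """
        # Ensure a 3-channel RGB image (e.g. RGBA, grayscale or palette input)
        image = image.convert("RGB")

        # Preprocess with the checkpoint's processor into a PyTorch batch
        inputs = self.feature_extractor(images=image, return_tensors="pt")

        # Forward pass without gradient tracking; softmax turns the logits
        # into class probabilities, and max picks the most likely class
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted = torch.max(probabilities, 1)

        # Class index -> model label (e.g. "LABEL_2") -> readable category;
        # labels missing from label_to_category are passed through unchanged
        label = self.model.config.id2label[predicted.item()]
        category = self.label_to_category.get(label, label)

        return category != "Safe", category, confidence.item() * 100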
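
# Illustrative sketch (hypothetical helper, not part of the original module):
# the softmax -> argmax -> label lookup performed in NSFWDetector.check_image,
# written with the standard library instead of torch so the arithmetic can be
# checked by hand. It assumes `logits` is a flat list of class scores for one
# image and `id2label` maps class indices to "LABEL_i" strings.
import math


def logits_to_category(logits, id2label, label_to_category):
    """Return (is_nsfw, category, confidence_percent) for one row of logits."""
    # Subtracting the max before exponentiating keeps exp() from overflowing;
    # it does not change the softmax result.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probabilities = [e / total for e in exps]
    # Index of the most probable class (the first one wins on ties).
    predicted = max(range(len(probabilities)), key=probabilities.__getitem__)
    label = id2label[predicted]
    category = label_to_category.get(label, label)
    return category != "Safe", category, probabilities[predicted] * 100

# e.g. with the three labels in NSFWDetector.label_to_category, logits
# [0.0, 0.0, 5.0] give (True, "Unsafe", ~98.67).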

def create_error_image(message="NSFW Content Detected"):
    """Return a 512x512 black placeholder image with `message` centered in white."""
    img = Image.new('RGB', (512, 512), color='black')
    draw = ImageDraw.Draw(img)

    try:
        # Pillow's bundled default font
        font = ImageFont.load_default()

        # Measure the rendered text. The bbox is relative to the (0, 0) anchor
        # and its left/top values can be non-zero, so subtract them when centering.
        text_bbox = draw.textbbox((0, 0), message, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        x = (512 - text_width) // 2 - text_bbox[0]
        y = (512 - text_height) // 2 - text_bbox[1]

        # Draw white text
        draw.text((x, y), message, fill='white', font=font)

    except Exception as e:
        # Fall back to a plain black image if text rendering fails
        print(f"Error adding text to image: {e}")

    return img
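

# Illustrative sketch (hypothetical helper, not part of the original module):
# one way a caller might combine NSFWDetector.check_image with
# create_error_image to replace flagged images. The detector and the
# placeholder factory are passed in as arguments, so any object with a
# compatible check_image() works. The optional `min_confidence` cut-off
# (in percent) is an assumed policy knob that the original code does not define.
def filter_image(detector, image, make_placeholder, min_confidence=0.0):
    """Return (output_image, category, confidence_percent)."""
    is_nsfw, category, confidence = detector.check_image(image)
    # Swap in the placeholder only when the image is flagged and the model is
    # at least `min_confidence` percent sure of its prediction.
    if is_nsfw and confidence >= min_confidence:
        message = f"{category} content detected ({confidence:.1f}%)"
        return make_placeholder(message), category, confidence
    return image, category, confidence

# Usage (the file name is hypothetical):
#   output, category, confidence = filter_image(
#       NSFWDetector(), Image.open("input.png"), create_error_image)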