import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import numpy as np


class NSFWDetector:
    """Classify an image as Safe / Questionable / Unsafe with a FocalNet model.

    Wraps a Hugging Face ``FocalNetForImageClassification`` checkpoint together
    with the ``AutoProcessor`` loaded from the same path. ``check_image`` runs
    one image through the model and reports the top-1 category and its
    softmax confidence.
    """

    # Checkpoint used when no explicit ``model_path`` is given.
    DEFAULT_MODEL_PATH = "TostAI/nsfw-image-detection-large"

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Load the processor and model weights.

        Args:
            model_path: Hugging Face hub id or local directory passed to
                ``from_pretrained`` for both the processor and the model.
                Defaults to ``"TostAI/nsfw-image-detection-large"``.
        """
        self.model_path = model_path
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        # Inference only: put layers such as dropout into evaluation mode.
        self.model.eval()

        # NOTE(review): check_image does not use this pipeline; preprocessing
        # there is done by ``self.feature_extractor``. Kept so that any
        # external code reading ``detector.transform`` keeps working.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Maps the model's label names to human-readable categories.
        # Presumably these match the checkpoint's ``config.id2label``
        # values; verify when changing ``model_path``.
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe",
        }

    def check_image(self, image):
        """Classify a single image.

        Args:
            image: An object with a ``convert("RGB")`` method (presumably a
                ``PIL.Image.Image``; confirm with callers).

        Returns:
            A tuple ``(is_flagged, category, confidence_pct)``:
            ``is_flagged`` is ``True`` whenever the category is not
            ``"Safe"``; ``category`` is the mapped category name, or the raw
            model label if it is missing from ``label_to_category``;
            ``confidence_pct`` is the top-1 softmax probability scaled to
            0-100.
        """
        # The processor expects 3-channel input; normalise modes such as
        # RGBA, L or P first.
        image = image.convert("RGB")

        inputs = self.feature_extractor(images=image, return_tensors="pt")

        # No gradients are needed for inference.
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Top-1 over the class dimension; ``.item()`` below relies on
            # exactly one image being in the batch.
            confidence, predicted = torch.max(probabilities, dim=-1)

        label = self.model.config.id2label[predicted.item()]
        # Unknown labels fall through unchanged and are therefore treated as
        # not safe (the conservative choice).
        category = self.label_to_category.get(label, label)

        return category != "Safe", category, confidence.item() * 100