Spaces:
Runtime error
Runtime error
Commit
·
61dbd85
1
Parent(s):
50ffd5f
Delete nsfw_detector.py
Browse files- nsfw_detector.py +0 -65
nsfw_detector.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
from torchvision.transforms import Normalize
|
| 2 |
-
import torchvision.transforms as T
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
from PIL import Image
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
import timm
|
| 8 |
-
from tqdm import tqdm
|
| 9 |
-
|
| 10 |
-
# https://github.com/Whiax/NSFW-Classifier/raw/main/nsfwmodel_281.pth
|
| 11 |
-
normalize_t = Normalize((0.4814, 0.4578, 0.4082), (0.2686, 0.2613, 0.2757))
|
| 12 |
-
|
| 13 |
-
#nsfw classifier
|
| 14 |
-
class NSFWClassifier(nn.Module):
|
| 15 |
-
def __init__(self):
|
| 16 |
-
super().__init__()
|
| 17 |
-
nsfw_model=self
|
| 18 |
-
nsfw_model.root_model = timm.create_model('convnext_base_in22ft1k', pretrained=True)
|
| 19 |
-
nsfw_model.linear_probe = nn.Linear(1024, 1, bias=False)
|
| 20 |
-
|
| 21 |
-
def forward(self, x):
|
| 22 |
-
nsfw_model = self
|
| 23 |
-
x = normalize_t(x)
|
| 24 |
-
x = nsfw_model.root_model.stem(x)
|
| 25 |
-
x = nsfw_model.root_model.stages(x)
|
| 26 |
-
x = nsfw_model.root_model.head.global_pool(x)
|
| 27 |
-
x = nsfw_model.root_model.head.norm(x)
|
| 28 |
-
x = nsfw_model.root_model.head.flatten(x)
|
| 29 |
-
x = nsfw_model.linear_probe(x)
|
| 30 |
-
return x
|
| 31 |
-
|
| 32 |
-
def is_nsfw(self, img_paths, threshold = 0.98):
|
| 33 |
-
skip_step = 1
|
| 34 |
-
total_len = len(img_paths)
|
| 35 |
-
if total_len < 100: skip_step = 1
|
| 36 |
-
if total_len > 100 and total_len < 500: skip_step = 10
|
| 37 |
-
if total_len > 500 and total_len < 1000: skip_step = 20
|
| 38 |
-
if total_len > 1000 and total_len < 10000: skip_step = 50
|
| 39 |
-
if total_len > 10000: skip_step = 100
|
| 40 |
-
|
| 41 |
-
for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
|
| 42 |
-
_img = Image.open(img_paths[idx]).convert('RGB')
|
| 43 |
-
img = _img.resize((224, 224))
|
| 44 |
-
img = np.array(img)/255
|
| 45 |
-
img = T.ToTensor()(img).unsqueeze(0).float()
|
| 46 |
-
if next(self.parameters()).is_cuda:
|
| 47 |
-
img = img.cuda()
|
| 48 |
-
with torch.no_grad():
|
| 49 |
-
score = self.forward(img).sigmoid()[0].item()
|
| 50 |
-
if score > threshold:
|
| 51 |
-
print(f"Detected nsfw score:{score}")
|
| 52 |
-
_img.save("nsfw.jpg")
|
| 53 |
-
return True
|
| 54 |
-
return False
|
| 55 |
-
|
| 56 |
-
def get_nsfw_detector(model_path='nsfwmodel_281.pth', device="cpu"):
|
| 57 |
-
#load base model
|
| 58 |
-
nsfw_model = NSFWClassifier()
|
| 59 |
-
nsfw_model = nsfw_model.eval()
|
| 60 |
-
#load linear weights
|
| 61 |
-
linear_pth = model_path
|
| 62 |
-
linear_state_dict = torch.load(linear_pth, map_location='cpu')
|
| 63 |
-
nsfw_model.linear_probe.load_state_dict(linear_state_dict)
|
| 64 |
-
nsfw_model = nsfw_model.to(device)
|
| 65 |
-
return nsfw_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|