tanlocc commited on
Commit
9f52a4a
Β·
1 Parent(s): 4dc7058

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -1
app.py CHANGED
@@ -10,6 +10,7 @@ from imgutils.utils import open_onnx_model
10
 
11
  import random
12
  from typing import List
 
13
  from base import ONNXBaseTask
14
  from utils import prepare_input_wraper
15
 
@@ -29,6 +30,40 @@ def _onnx_model(name):
29
  f'{name}'
30
  ))
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def _image_preprocess(image, size: int = 224) -> np.ndarray:
33
  image = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
34
  return (np.array(image) / 255.0)[None, ...]
@@ -55,7 +90,7 @@ if __name__ == '__main__':
55
  gr_ratings = gr.Label(label='Ratings')
56
 
57
  gr_btn_submit.click(
58
- predict,
59
  inputs=[gr_input_image, gr_model],
60
  outputs=[gr_ratings],
61
  )
 
10
 
11
  import random
12
  from typing import List
13
+ from dataclasses import dataclass
14
  from base import ONNXBaseTask
15
  from utils import prepare_input_wraper
16
 
 
30
  f'{name}'
31
  ))
32
 
33
+ @dataclass
34
+ class NSFWPrediction:
35
+ label: str
36
+ score: float
37
+
38
+ class CheckNSFWTask(ONNXBaseTask):
39
+ classes: List[str] = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
40
+
41
+ def __init__(self, weight: str):
42
+ super().__init__(weight)
43
+
44
+ def __call__(self, image) -> NSFWPrediction:
45
+ return self.call(image)
46
+
47
+ def process_output(self, raw_outputs) -> NSFWPrediction:
48
+ probabilities = raw_outputs[0][0]
49
+ max_prob_index = np.argmax(probabilities)
50
+ max_prob_score = probabilities[max_prob_index]
51
+
52
+ predicted_class = self.classes[max_prob_index]
53
+
54
+ return NSFWPrediction(label=predicted_class, score=max_prob_score.tolist())
55
+
56
+ def setup_prepare_input(self):
57
+ return prepare_input_wraper(
58
+ inter=1,
59
+ color_space="RGB",
60
+ mean=None,
61
+ std=None,
62
+ is_scale=True,
63
+ channel_first=False
64
+ )
65
+
66
+
67
  def _image_preprocess(image, size: int = 224) -> np.ndarray:
68
  image = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
69
  return (np.array(image) / 255.0)[None, ...]
 
90
  gr_ratings = gr.Label(label='Ratings')
91
 
92
  gr_btn_submit.click(
93
+ CheckNSFWTask,
94
  inputs=[gr_input_image, gr_model],
95
  outputs=[gr_ratings],
96
  )