tanlocc commited on
Commit
d8240fe
Β·
1 Parent(s): 9f52a4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -42
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from functools import lru_cache
3
 
@@ -8,13 +9,6 @@ from huggingface_hub import hf_hub_download
8
  from imgutils.data import load_image
9
  from imgutils.utils import open_onnx_model
10
 
11
- import random
12
- from typing import List
13
- from dataclasses import dataclass
14
- from base import ONNXBaseTask
15
- from utils import prepare_input_wraper
16
-
17
-
18
  _MODELS = [
19
  ('content_moderation.onnx', 224),
20
  ]
@@ -30,39 +24,6 @@ def _onnx_model(name):
30
  f'{name}'
31
  ))
32
 
33
- @dataclass
34
- class NSFWPrediction:
35
- label: str
36
- score: float
37
-
38
- class CheckNSFWTask(ONNXBaseTask):
39
- classes: List[str] = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
40
-
41
- def __init__(self, weight: str):
42
- super().__init__(weight)
43
-
44
- def __call__(self, image) -> NSFWPrediction:
45
- return self.call(image)
46
-
47
- def process_output(self, raw_outputs) -> NSFWPrediction:
48
- probabilities = raw_outputs[0][0]
49
- max_prob_index = np.argmax(probabilities)
50
- max_prob_score = probabilities[max_prob_index]
51
-
52
- predicted_class = self.classes[max_prob_index]
53
-
54
- return NSFWPrediction(label=predicted_class, score=max_prob_score.tolist())
55
-
56
- def setup_prepare_input(self):
57
- return prepare_input_wraper(
58
- inter=1,
59
- color_space="RGB",
60
- mean=None,
61
- std=None,
62
- is_scale=True,
63
- channel_first=False
64
- )
65
-
66
 
67
  def _image_preprocess(image, size: int = 224) -> np.ndarray:
68
  image = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
@@ -74,7 +35,7 @@ _LABELS = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
74
 
75
  def predict(image, model_name):
76
  input_ = _image_preprocess(image, _MODEL_TO_SIZE[model_name]).astype(np.float32)
77
- output_, = _onnx_model(model_name).run(['dense_3'], {'input_1': input_})
78
  return dict(zip(_LABELS, map(float, output_[0])))
79
 
80
 
@@ -90,7 +51,7 @@ if __name__ == '__main__':
90
  gr_ratings = gr.Label(label='Ratings')
91
 
92
  gr_btn_submit.click(
93
- CheckNSFWTask,
94
  inputs=[gr_input_image, gr_model],
95
  outputs=[gr_ratings],
96
  )
 
1
+
2
  import os
3
  from functools import lru_cache
4
 
 
9
  from imgutils.data import load_image
10
  from imgutils.utils import open_onnx_model
11
 
 
 
 
 
 
 
 
12
  _MODELS = [
13
  ('content_moderation.onnx', 224),
14
  ]
 
24
  f'{name}'
25
  ))
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def _image_preprocess(image, size: int = 224) -> np.ndarray:
29
  image = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
 
35
 
36
  def predict(image, model_name):
37
  input_ = _image_preprocess(image, _MODEL_TO_SIZE[model_name]).astype(np.float32)
38
+ output_, = _onnx_model(model_name).run()
39
  return dict(zip(_LABELS, map(float, output_[0])))
40
 
41
 
 
51
  gr_ratings = gr.Label(label='Ratings')
52
 
53
  gr_btn_submit.click(
54
+ predict,
55
  inputs=[gr_input_image, gr_model],
56
  outputs=[gr_ratings],
57
  )