dlsmallw committed
Commit a189dd1 · Parent(s): 4f70c5d

Task-292 Implement method for deploying models

Files changed (6)
  1. Pipfile +2 -0
  2. Pipfile.lock +20 -42
  3. app.py +48 -49
  4. config.py +22 -0
  5. scripts/__init__.py +1 -0
  6. scripts/predict.py +32 -8
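
Taken together, these changes replace the locally bundled model directories with models pulled from the Hugging Face Hub at runtime, gated behind a user-supplied API token entered in the app sidebar. A minimal sketch of the resulting call flow, based on the `scripts/predict.py` diff below (the token value is a placeholder, and the result shape follows the removed test helper in `app.py`):

```python
# Sketch of the flow this commit introduces (see scripts/predict.py below).
# InferenceHandler now downloads model configs/weights from the HF Hub
# using the caller's API token instead of reading local model folders.
from scripts.predict import InferenceHandler

handler = InferenceHandler(api_token="hf_...")      # placeholder token
result = handler.classify_text("text to classify")  # dict of sentiment + category scores
print(result["text_sentiment"])
```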
Pipfile CHANGED
```diff
@@ -10,6 +10,8 @@ numpy = "*"
 st-annotated-text = "*"
 transformers = "*"
 torch = "*"
+huggingface-hub = "*"
+joblib = "*"
 
 [dev-packages]
 
```
Pipfile.lock CHANGED
```diff
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "d44f8f17557914a1bc97b5e9ce219979a85e81b74eb603b3c0c6920cac065c91"
+            "sha256": "c52664113fb789224f8338560a034a86739fe4d813ded69beb069fdc571c1fd4"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -211,11 +211,12 @@
         },
         "huggingface-hub": {
             "hashes": [
-                "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5",
-                "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250"
+                "sha256:590b29c0dcbd0ee4b7b023714dc1ad8563fe4a68a91463438b74e980d28afaf3",
+                "sha256:c56f20fca09ef19da84dcde2b76379ecdaddf390b083f59f166715584953307d"
             ],
+            "index": "pypi",
             "markers": "python_full_version >= '3.8.0'",
-            "version": "==0.29.1"
+            "version": "==0.29.2"
         },
         "idna": {
             "hashes": [
@@ -227,11 +228,20 @@
         },
         "jinja2": {
             "hashes": [
-                "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb",
-                "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"
+                "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d",
+                "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"
             ],
             "markers": "python_version >= '3.7'",
-            "version": "==3.1.5"
+            "version": "==3.1.6"
+        },
+        "joblib": {
+            "hashes": [
+                "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6",
+                "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"
+            ],
+            "index": "pypi",
+            "markers": "python_version >= '3.8'",
+            "version": "==1.4.2"
         },
         "jsonschema": {
             "hashes": [
@@ -249,14 +259,6 @@
             "markers": "python_version >= '3.9'",
             "version": "==2024.10.1"
         },
-        "markdown-it-py": {
-            "hashes": [
-                "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1",
-                "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"
-            ],
-            "markers": "python_version >= '3.8'",
-            "version": "==3.0.0"
-        },
         "markupsafe": {
             "hashes": [
                 "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4",
@@ -324,14 +326,6 @@
             "markers": "python_version >= '3.9'",
             "version": "==3.0.2"
         },
-        "mdurl": {
-            "hashes": [
-                "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8",
-                "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"
-            ],
-            "markers": "python_version >= '3.7'",
-            "version": "==0.1.2"
-        },
         "mpmath": {
             "hashes": [
                 "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f",
@@ -624,14 +618,6 @@
             "markers": "python_version >= '3.8'",
             "version": "==0.9.1"
         },
-        "pygments": {
-            "hashes": [
-                "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f",
-                "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"
-            ],
-            "markers": "python_version >= '3.8'",
-            "version": "==2.19.1"
-        },
         "python-dateutil": {
             "hashes": [
                 "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3",
@@ -730,14 +716,6 @@
             "markers": "python_version >= '3.8'",
             "version": "==2.32.3"
         },
-        "rich": {
-            "hashes": [
-                "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098",
-                "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"
-            ],
-            "markers": "python_full_version >= '3.8.0'",
-            "version": "==13.9.4"
-        },
         "rpds-py": {
             "hashes": [
                 "sha256:09cd7dbcb673eb60518231e02874df66ec1296c01a4fcd733875755c02014b19",
@@ -903,12 +881,12 @@
         },
         "streamlit": {
             "hashes": [
-                "sha256:62026dbdcb482790933f658b096d7dd58fa70da89c1f06fbc3658b91dcd4dab2",
-                "sha256:e2516c7fcd17a11a85cc1999fae58ace0a6458e2b4c1a411ed3d75b1aee2eb93"
+                "sha256:c10c09f9d1251fa7f975dd360572f03cabc82b174f080e323bf7e556103c22e0",
+                "sha256:cf94b1e9f1de75e4e383df53745230feaac4ac7a7e1f14a3ea362df134db8510"
             ],
             "index": "pypi",
             "markers": "python_version >= '3.9' and python_full_version != '3.9.7'",
-            "version": "==1.42.2"
+            "version": "==1.43.0"
         },
         "sympy": {
             "hashes": [
```
app.py CHANGED
```diff
@@ -1,19 +1,17 @@
 import streamlit as st
 import pandas as pd
-from annotated_text import annotated_text, annotation
+from annotated_text import annotated_text
 import time
-from random import randint, uniform
 from scripts.predict import InferenceHandler
-from pathlib import Path
-
-ROOT = Path(__file__).resolve().parents[0]
-st.write(ROOT)
-MODELS_DIR = ROOT / 'models'
-BIN_MODEL_PATH = MODELS_DIR / 'binary_classification'
-ML_MODEL_PATH = MODELS_DIR / 'multilabel_regression'
 
 history_df = pd.DataFrame(data=[], columns=['Text', 'Classification', 'Gender', 'Race', 'Sexuality', 'Disability', 'Religion', 'Unspecified'])
-ih = InferenceHandler(BIN_MODEL_PATH, ML_MODEL_PATH)
+rc = None
+ih = None
+entry = None
+
+@st.cache_data
+def load_inference_handler(api_token):
+    ih = InferenceHandler(api_token)
 
 def extract_data(json_obj):
     row_data = []
@@ -58,40 +56,38 @@ def output_results(res):
 
     if len(at_list) > 0:
         annotated_text(at_list)
-
-
-# def test_results(text):
-#     test_val = int(randint(0, 1))
-#     res_obj = {
-#         'raw_text': text,
-#         'text_sentiment': 'Discriminatory' if test_val == 1 else 'Non-Discriminatory',
-#         'numerical_sentiment': test_val,
-#         'category_sentiments': {
-#             'Gender': None if test_val == 0 else uniform(0.0, 1.0),
-#             'Race': None if test_val == 0 else uniform(0.0, 1.0),
-#             'Sexuality': None if test_val == 0 else uniform(0.0, 1.0),
-#             'Disability': None if test_val == 0 else uniform(0.0, 1.0),
-#             'Religion': None if test_val == 0 else uniform(0.0, 1.0),
-#             'Unspecified': None if test_val == 0 else uniform(0.0, 1.0)
-#         }
-#     }
-#     return res_obj
-
 
+@st.cache_data
 def analyze_text(text):
-    res = None
-    with rc:
-        with st.spinner("Processing...", show_time=True) as spnr:
-            time.sleep(5)
-            res = ih.classify_text(text)
-            del spnr
-
-    if res is not None:
-        st.session_state.results.append(res)
-        history_df.loc[-1] = extract_data(res)
-        output_results(res)
+    if ih:
+        res = None
+        with rc:
+            with st.spinner("Processing...", show_time=True) as spnr:
+                time.sleep(5)
+                res = ih.classify_text(text)
+            del spnr
+
+        if res is not None:
+            st.session_state.results.append(res)
+            history_df.loc[-1] = extract_data(res)
+            output_results(res)
 
 st.title('NLPinitiative Text Classifier')
+
+st.sidebar.write("")
+API_KEY = st.sidebar.text_input(
+    "Enter your HuggingFace API Token",
+    help="You can get your free API token in your settings page: https://huggingface.co/settings/tokens",
+    type="password",
+)
+
+try:
+    if API_KEY is not None and len(API_KEY) > 0:
+        ih = InferenceHandler(API_KEY)
+except:
+    ih = None
+    st.error('Invalid Token')
+
 tab1, tab2 = st.tabs(['Classifier', 'About This App'])
 
 if "results" not in st.session_state:
@@ -102,20 +98,23 @@ load_history()
 with tab1:
     "Text Classifier for determining if entered text is discriminatory (and the categories of discrimination) or Non-Discriminatory."
 
-    with st.container():
-        with st.expander('History'):
-            st.write(history_df)
-
-    rc = st.container()
-
+    hist_container = st.container()
+    hist_expander = hist_container.expander('History')
+    rc = st.container()
+
     text_form = st.form(key='classifier', clear_on_submit=True, enter_to_submit=True)
     with text_form:
-        text_area = st.text_area('Enter text to classify')
-        form_btn = st.form_submit_button('submit')
+        text_area = st.text_area('Enter text to classify', value='', disabled=True if ih is None else False)
+        form_btn = st.form_submit_button('submit', disabled=True if ih is None else False)
 
     if entry := text_area:
-        analyze_text(entry)
+        st.write(f'TEXT AREA: {entry}')
+        if entry and len(entry) > 0:
+            analyze_text(entry)
+            entry = None
 
+    with hist_expander:
+        st.dataframe(history_df)
 
 with tab2:
     st.markdown(
```
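
A note on the caching above: the commit decorates `load_inference_handler` and `analyze_text` with `@st.cache_data`, and as committed `load_inference_handler` assigns only a local `ih` and is never called; the handler is actually constructed in the `try` block. For an unserializable object like a model wrapper, Streamlit's docs recommend `st.cache_resource`. A hedged sketch of that alternative (not what this commit does):

```python
import streamlit as st
from scripts.predict import InferenceHandler

# st.cache_resource keeps one live handler per distinct token across reruns,
# rather than pickling it (which st.cache_data would attempt).
@st.cache_resource
def get_inference_handler(api_token: str) -> InferenceHandler:
    return InferenceHandler(api_token)
```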
config.py ADDED
```diff
@@ -0,0 +1,22 @@
+# Used for setting some constants for the project codebase
+
+from pathlib import Path
+
+# Root Path
+ROOT = Path(__file__).resolve().parents[0]
+
+# Model Directory
+MODELS_DIR = ROOT / 'models'
+
+# Binary Classification Model Path
+BIN_MODEL_PATH = MODELS_DIR / 'binary_classification'
+
+# Multilabel Regression Model Path
+ML_MODEL_PATH = MODELS_DIR / 'multilabel_regression'
+
+# HF Hub Repositories
+BIN_REPO = 'dlsmallw/Binary-Classification-testing'
+ML_REPO = 'dlsmallw/Multilabel-Regression-testing'
+
+BIN_API_URL = f"https://api-inference.huggingface.co/models/{BIN_REPO}"
+ML_API_URL = f"https://api-inference.huggingface.co/models/{ML_REPO}"
```
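
`BIN_API_URL` and `ML_API_URL` are defined here but not consumed anywhere else in this commit. For reference, a hosted Inference API endpoint like these is typically queried with a bearer token and a JSON `inputs` payload; a minimal sketch (`token` is a placeholder supplied by the caller):

```python
import requests

from config import BIN_API_URL

def query_binary_model(text: str, token: str) -> dict:
    # Standard Hugging Face hosted Inference API call shape.
    resp = requests.post(
        BIN_API_URL,
        headers={"Authorization": f"Bearer {token}"},
        json={"inputs": text},
    )
    resp.raise_for_status()
    return resp.json()
```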
scripts/__init__.py ADDED
```diff
@@ -0,0 +1 @@
+import config
```
scripts/predict.py CHANGED
```diff
@@ -5,30 +5,42 @@ Script file used for performing inference with an existing model.
 from pathlib import Path
 import torch
 import json
+import huggingface_hub
+import joblib
+from huggingface_hub.inference_api import InferenceApi
 
 from transformers import (
     AutoTokenizer,
     AutoModelForSequenceClassification
 )
 
+BIN_REPO = 'dlsmallw/Binary-Classification-testing'
+ML_REPO = 'dlsmallw/Multilabel-Regression-testing'
 
 ## Class used to encapsulate and handle the logic for inference
 class InferenceHandler:
-    def __init__(self, bin_model_path: Path, ml_regr_model_path: Path):
-        self.bin_tokenizer, self.bin_model = self.init_model_and_tokenizer(bin_model_path)
-        self.ml_regr_tokenizer, self.ml_regr_model = self.init_model_and_tokenizer(ml_regr_model_path)
+    def __init__(self, api_token):
+        self.api_token = api_token
+        self.bin_tokenizer, self.bin_model = self.init_model_and_tokenizer(BIN_REPO)
+        self.ml_regr_tokenizer, self.ml_regr_model = self.init_model_and_tokenizer(ML_REPO)
+
+    def get_config(self, repo_id):
+        config = None
+        if repo_id and self.api_token:
+            config = huggingface_hub.hf_hub_download(repo_id, filename='config.json', token=self.api_token)
+        return config
 
     ## Initializes a model and tokenizer for use in inference using the models path
-    def init_model_and_tokenizer(self, model_path: Path):
-        with open(model_path / 'config.json') as config_file:
+    def init_model_and_tokenizer(self, repo_id):
+        config = self.get_config(repo_id)
+        with open(config) as config_file:
             config_json = json.load(config_file)
             model_name = config_json['_name_or_path']
-            model_type = config_json['model_type']
 
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForSequenceClassification.from_pretrained(model_path, model_type=model_type)
-        model.eval()
+        model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=self.api_token)
 
+        model.eval()
         return tokenizer, model
 
     ## Handles logic used to encode the text for use in binary classification
@@ -103,3 +115,15 @@ class InferenceHandler:
             results.append(max(0.0, item))
 
         return results
+
+ih = InferenceHandler('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')
+discr_text = """Islam is hell . We Asian hate isalm...Islam bring only destruction..killing. rape... rufugee.... Islam is hell .. Asian thought 😂😂sorry...but is true"""
+
+nondiscr_text = """Absolutely excellent as ever, and yes, cycling is fabulous for mental health.. bike when you're happy,
+bike when you're sad, shout at the rain, scream inwardly at a complex hill climb and get that endorphin rush from being outside..
+and smile at your fellow cyclists (and passing chickens, in my experience, a wee bwark at a chook makes many things better)"""
+
+result1 = ih.classify_text(discr_text)
+result2 = ih.classify_text(nondiscr_text)
+print(json.dumps(result1, indent=4))
+print(json.dumps(result2, indent=4))
```
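
Note that the self-test block above commits a live `hf_` token to source control; a token published this way should be treated as leaked and revoked. A common alternative, sketched here on the assumption that the token is exported as an `HF_TOKEN` environment variable:

```python
import os

from scripts.predict import InferenceHandler

# Read the API token from the environment instead of hardcoding it.
# Assumes `export HF_TOKEN=hf_...` has been run in the shell.
api_token = os.environ.get("HF_TOKEN")
if api_token:
    ih = InferenceHandler(api_token)
```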