dlsmallw commited on
Commit
82221ca
·
1 Parent(s): 377b74f

Task-290 Integrate use of models within the application for use in inference

Browse files
Files changed (4) hide show
  1. Pipfile +2 -0
  2. Pipfile.lock +212 -4
  3. app.py +31 -25
  4. scripts/predict.py +105 -0
Pipfile CHANGED
@@ -8,6 +8,8 @@ streamlit = "*"
8
  pandas = "*"
9
  numpy = "*"
10
  st-annotated-text = "*"
 
 
11
 
12
  [dev-packages]
13
 
 
8
  pandas = "*"
9
  numpy = "*"
10
  st-annotated-text = "*"
11
+ transformers = "*"
12
+ torch = "*"
13
 
14
  [dev-packages]
15
 
Pipfile.lock CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_meta": {
3
  "hash": {
4
- "sha256": "b114c11edb911ee8c918ef91522f2a4e6895019b452b37a052da3b4deb4c456e"
5
  },
6
  "pipfile-spec": 6,
7
  "requires": {
@@ -170,6 +170,22 @@
170
  "markers": "platform_system == 'Windows'",
171
  "version": "==0.4.6"
172
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  "gitdb": {
174
  "hashes": [
175
  "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571",
@@ -193,6 +209,14 @@
193
  "markers": "python_version >= '3.7'",
194
  "version": "==0.9.0"
195
  },
 
 
 
 
 
 
 
 
196
  "idna": {
197
  "hashes": [
198
  "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9",
@@ -308,13 +332,28 @@
308
  "markers": "python_version >= '3.7'",
309
  "version": "==0.1.2"
310
  },
 
 
 
 
 
 
 
311
  "narwhals": {
312
  "hashes": [
313
- "sha256:45d909ad6240944d447b0dae38074c5a919830dff3868d57b05a5526c1f06fe4",
314
- "sha256:a2213fa44a039f724278fb15609889319e7c240403413f2606cc856c8d8f708d"
315
  ],
316
  "markers": "python_version >= '3.8'",
317
- "version": "==1.28.0"
 
 
 
 
 
 
 
 
318
  },
319
  "numpy": {
320
  "hashes": [
@@ -608,6 +647,65 @@
608
  ],
609
  "version": "==2025.1"
610
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  "referencing": {
612
  "hashes": [
613
  "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa",
@@ -616,6 +714,14 @@
616
  "markers": "python_version >= '3.9'",
617
  "version": "==0.36.2"
618
  },
 
 
 
 
 
 
 
 
619
  "requests": {
620
  "hashes": [
621
  "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760",
@@ -741,6 +847,35 @@
741
  "markers": "python_version >= '3.9'",
742
  "version": "==0.23.1"
743
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
  "six": {
745
  "hashes": [
746
  "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274",
@@ -775,6 +910,14 @@
775
  "markers": "python_version >= '3.9' and python_full_version != '3.9.7'",
776
  "version": "==1.42.2"
777
  },
 
 
 
 
 
 
 
 
778
  "tenacity": {
779
  "hashes": [
780
  "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b",
@@ -783,6 +926,27 @@
783
  "markers": "python_version >= '3.8'",
784
  "version": "==9.0.0"
785
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
  "toml": {
787
  "hashes": [
788
  "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
@@ -791,6 +955,33 @@
791
  "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'",
792
  "version": "==0.10.2"
793
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
  "tornado": {
795
  "hashes": [
796
  "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803",
@@ -808,6 +999,23 @@
808
  "markers": "python_version >= '3.8'",
809
  "version": "==6.4.2"
810
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  "typing-extensions": {
812
  "hashes": [
813
  "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d",
 
1
  {
2
  "_meta": {
3
  "hash": {
4
+ "sha256": "d44f8f17557914a1bc97b5e9ce219979a85e81b74eb603b3c0c6920cac065c91"
5
  },
6
  "pipfile-spec": 6,
7
  "requires": {
 
170
  "markers": "platform_system == 'Windows'",
171
  "version": "==0.4.6"
172
  },
173
+ "filelock": {
174
+ "hashes": [
175
+ "sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338",
176
+ "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e"
177
+ ],
178
+ "markers": "python_version >= '3.9'",
179
+ "version": "==3.17.0"
180
+ },
181
+ "fsspec": {
182
+ "hashes": [
183
+ "sha256:1c24b16eaa0a1798afa0337aa0db9b256718ab2a89c425371f5628d22c3b6afd",
184
+ "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b"
185
+ ],
186
+ "markers": "python_version >= '3.8'",
187
+ "version": "==2025.2.0"
188
+ },
189
  "gitdb": {
190
  "hashes": [
191
  "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571",
 
209
  "markers": "python_version >= '3.7'",
210
  "version": "==0.9.0"
211
  },
212
+ "huggingface-hub": {
213
+ "hashes": [
214
+ "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5",
215
+ "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250"
216
+ ],
217
+ "markers": "python_full_version >= '3.8.0'",
218
+ "version": "==0.29.1"
219
+ },
220
  "idna": {
221
  "hashes": [
222
  "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9",
 
332
  "markers": "python_version >= '3.7'",
333
  "version": "==0.1.2"
334
  },
335
+ "mpmath": {
336
+ "hashes": [
337
+ "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f",
338
+ "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"
339
+ ],
340
+ "version": "==1.3.0"
341
+ },
342
  "narwhals": {
343
  "hashes": [
344
+ "sha256:1021c345d56c66ff0cc8e6d03ca8c543d01ffc411630973a5cb69ee86824d823",
345
+ "sha256:653aa8e5eb435816e7b50c8def17e7e5e3324c2ffd8a3eec03fef85792e9cf5e"
346
  ],
347
  "markers": "python_version >= '3.8'",
348
+ "version": "==1.29.0"
349
+ },
350
+ "networkx": {
351
+ "hashes": [
352
+ "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1",
353
+ "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f"
354
+ ],
355
+ "markers": "python_version >= '3.10'",
356
+ "version": "==3.4.2"
357
  },
358
  "numpy": {
359
  "hashes": [
 
647
  ],
648
  "version": "==2025.1"
649
  },
650
+ "pyyaml": {
651
+ "hashes": [
652
+ "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff",
653
+ "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48",
654
+ "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086",
655
+ "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e",
656
+ "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133",
657
+ "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5",
658
+ "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484",
659
+ "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee",
660
+ "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5",
661
+ "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68",
662
+ "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a",
663
+ "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf",
664
+ "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99",
665
+ "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8",
666
+ "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85",
667
+ "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19",
668
+ "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc",
669
+ "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a",
670
+ "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1",
671
+ "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317",
672
+ "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c",
673
+ "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631",
674
+ "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d",
675
+ "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652",
676
+ "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5",
677
+ "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e",
678
+ "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b",
679
+ "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8",
680
+ "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476",
681
+ "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706",
682
+ "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563",
683
+ "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237",
684
+ "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b",
685
+ "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083",
686
+ "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180",
687
+ "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425",
688
+ "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e",
689
+ "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f",
690
+ "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725",
691
+ "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183",
692
+ "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab",
693
+ "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774",
694
+ "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725",
695
+ "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e",
696
+ "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5",
697
+ "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d",
698
+ "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290",
699
+ "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44",
700
+ "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed",
701
+ "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4",
702
+ "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba",
703
+ "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12",
704
+ "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"
705
+ ],
706
+ "markers": "python_version >= '3.8'",
707
+ "version": "==6.0.2"
708
+ },
709
  "referencing": {
710
  "hashes": [
711
  "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa",
 
714
  "markers": "python_version >= '3.9'",
715
  "version": "==0.36.2"
716
  },
717
+ "regex": {
718
+ "hashes": [
719
+ "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519",
720
+ "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"
721
+ ],
722
+ "markers": "python_version >= '3.8'",
723
+ "version": "==2024.11.6"
724
+ },
725
  "requests": {
726
  "hashes": [
727
  "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760",
 
847
  "markers": "python_version >= '3.9'",
848
  "version": "==0.23.1"
849
  },
850
+ "safetensors": {
851
+ "hashes": [
852
+ "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d",
853
+ "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467",
854
+ "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7",
855
+ "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135",
856
+ "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04",
857
+ "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9",
858
+ "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e",
859
+ "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b",
860
+ "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11",
861
+ "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d",
862
+ "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965",
863
+ "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073",
864
+ "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a",
865
+ "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace",
866
+ "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff"
867
+ ],
868
+ "markers": "python_version >= '3.7'",
869
+ "version": "==0.5.3"
870
+ },
871
+ "setuptools": {
872
+ "hashes": [
873
+ "sha256:4880473a969e5f23f2a2be3646b2dfd84af9028716d398e46192f84bc36900d2",
874
+ "sha256:558e47c15f1811c1fa7adbd0096669bf76c1d3f433f58324df69f3f5ecac4e8f"
875
+ ],
876
+ "markers": "python_version >= '3.12'",
877
+ "version": "==75.8.2"
878
+ },
879
  "six": {
880
  "hashes": [
881
  "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274",
 
910
  "markers": "python_version >= '3.9' and python_full_version != '3.9.7'",
911
  "version": "==1.42.2"
912
  },
913
+ "sympy": {
914
+ "hashes": [
915
+ "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f",
916
+ "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"
917
+ ],
918
+ "markers": "python_version >= '3.9'",
919
+ "version": "==1.13.1"
920
+ },
921
  "tenacity": {
922
  "hashes": [
923
  "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b",
 
926
  "markers": "python_version >= '3.8'",
927
  "version": "==9.0.0"
928
  },
929
+ "tokenizers": {
930
+ "hashes": [
931
+ "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b",
932
+ "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2",
933
+ "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273",
934
+ "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff",
935
+ "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193",
936
+ "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e",
937
+ "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c",
938
+ "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e",
939
+ "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74",
940
+ "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba",
941
+ "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04",
942
+ "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a",
943
+ "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e",
944
+ "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4",
945
+ "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"
946
+ ],
947
+ "markers": "python_version >= '3.7'",
948
+ "version": "==0.21.0"
949
+ },
950
  "toml": {
951
  "hashes": [
952
  "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
 
955
  "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'",
956
  "version": "==0.10.2"
957
  },
958
+ "torch": {
959
+ "hashes": [
960
+ "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628",
961
+ "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c",
962
+ "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9",
963
+ "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7",
964
+ "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf",
965
+ "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc",
966
+ "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341",
967
+ "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a",
968
+ "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961",
969
+ "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1",
970
+ "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239",
971
+ "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21",
972
+ "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989",
973
+ "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053",
974
+ "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b",
975
+ "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb",
976
+ "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e",
977
+ "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab",
978
+ "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d",
979
+ "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"
980
+ ],
981
+ "index": "pypi",
982
+ "markers": "python_full_version >= '3.9.0'",
983
+ "version": "==2.6.0"
984
+ },
985
  "tornado": {
986
  "hashes": [
987
  "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803",
 
999
  "markers": "python_version >= '3.8'",
1000
  "version": "==6.4.2"
1001
  },
1002
+ "tqdm": {
1003
+ "hashes": [
1004
+ "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2",
1005
+ "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"
1006
+ ],
1007
+ "markers": "python_version >= '3.7'",
1008
+ "version": "==4.67.1"
1009
+ },
1010
+ "transformers": {
1011
+ "hashes": [
1012
+ "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03",
1013
+ "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"
1014
+ ],
1015
+ "index": "pypi",
1016
+ "markers": "python_full_version >= '3.9.0'",
1017
+ "version": "==4.49.0"
1018
+ },
1019
  "typing-extensions": {
1020
  "hashes": [
1021
  "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d",
app.py CHANGED
@@ -3,8 +3,17 @@ import pandas as pd
3
  from annotated_text import annotated_text, annotation
4
  import time
5
  from random import randint, uniform
 
 
 
 
 
 
 
 
6
 
7
  history_df = pd.DataFrame(data=[], columns=['Text', 'Classification', 'Gender', 'Race', 'Sexuality', 'Disability', 'Religion', 'Unspecified'])
 
8
 
9
  def extract_data(json_obj):
10
  row_data = []
@@ -40,8 +49,8 @@ def output_results(res):
40
  if res['numerical_sentiment'] == 1:
41
  # st.markdown('##### Category Results:')
42
  for entry in res['category_sentiments'].keys():
43
- if randint(0, 1) == 1:
44
- val = res['category_sentiments'][entry]
45
  perc = val * 100
46
  at_list.append((entry, f'{perc:.2f}%', label_dict[entry]))
47
 
@@ -49,28 +58,25 @@ def output_results(res):
49
  st.markdown(f"#### Classification - {':red' if res['numerical_sentiment'] == 1 else ':green'}[{res['text_sentiment']}]")
50
 
51
  if len(at_list) > 0:
52
- st.markdown('#### Categories: ')
53
- cols = st.columns([1, 15])
54
- with cols[1]:
55
- for cat in at_list:
56
- annotated_text(cat)
57
-
58
- def test_results(text):
59
- test_val = int(randint(0, 1))
60
- res_obj = {
61
- 'raw_text': text,
62
- 'text_sentiment': 'Discriminatory' if test_val == 1 else 'Non-Discriminatory',
63
- 'numerical_sentiment': test_val,
64
- 'category_sentiments': {
65
- 'Gender': None if test_val == 0 else uniform(0.0, 1.0),
66
- 'Race': None if test_val == 0 else uniform(0.0, 1.0),
67
- 'Sexuality': None if test_val == 0 else uniform(0.0, 1.0),
68
- 'Disability': None if test_val == 0 else uniform(0.0, 1.0),
69
- 'Religion': None if test_val == 0 else uniform(0.0, 1.0),
70
- 'Unspecified': None if test_val == 0 else uniform(0.0, 1.0)
71
- }
72
- }
73
- return res_obj
74
 
75
 
76
  def analyze_text(text):
@@ -78,7 +84,7 @@ def analyze_text(text):
78
  with rc:
79
  with st.spinner("Processing...", show_time=True) as spnr:
80
  time.sleep(5)
81
- res = test_results(text)
82
  del spnr
83
 
84
  if res is not None:
 
3
  from annotated_text import annotated_text, annotation
4
  import time
5
  from random import randint, uniform
6
+ from scripts.predict import InferenceHandler
7
+ from pathlib import Path
8
+
9
+ ROOT = Path(__file__).resolve().parents[0]
10
+ st.write(ROOT)
11
+ MODELS_DIR = ROOT / 'models'
12
+ BIN_MODEL_PATH = MODELS_DIR / 'binary_classification'
13
+ ML_MODEL_PATH = MODELS_DIR / 'multilabel_regression'
14
 
15
  history_df = pd.DataFrame(data=[], columns=['Text', 'Classification', 'Gender', 'Race', 'Sexuality', 'Disability', 'Religion', 'Unspecified'])
16
+ ih = InferenceHandler(BIN_MODEL_PATH, ML_MODEL_PATH)
17
 
18
  def extract_data(json_obj):
19
  row_data = []
 
49
  if res['numerical_sentiment'] == 1:
50
  # st.markdown('##### Category Results:')
51
  for entry in res['category_sentiments'].keys():
52
+ val = res['category_sentiments'][entry]
53
+ if val > 0.0:
54
  perc = val * 100
55
  at_list.append((entry, f'{perc:.2f}%', label_dict[entry]))
56
 
 
58
  st.markdown(f"#### Classification - {':red' if res['numerical_sentiment'] == 1 else ':green'}[{res['text_sentiment']}]")
59
 
60
  if len(at_list) > 0:
61
+ annotated_text(at_list)
62
+
63
+
64
+ # def test_results(text):
65
+ # test_val = int(randint(0, 1))
66
+ # res_obj = {
67
+ # 'raw_text': text,
68
+ # 'text_sentiment': 'Discriminatory' if test_val == 1 else 'Non-Discriminatory',
69
+ # 'numerical_sentiment': test_val,
70
+ # 'category_sentiments': {
71
+ # 'Gender': None if test_val == 0 else uniform(0.0, 1.0),
72
+ # 'Race': None if test_val == 0 else uniform(0.0, 1.0),
73
+ # 'Sexuality': None if test_val == 0 else uniform(0.0, 1.0),
74
+ # 'Disability': None if test_val == 0 else uniform(0.0, 1.0),
75
+ # 'Religion': None if test_val == 0 else uniform(0.0, 1.0),
76
+ # 'Unspecified': None if test_val == 0 else uniform(0.0, 1.0)
77
+ # }
78
+ # }
79
+ # return res_obj
 
 
 
80
 
81
 
82
  def analyze_text(text):
 
84
  with rc:
85
  with st.spinner("Processing...", show_time=True) as spnr:
86
  time.sleep(5)
87
+ res = ih.classify_text(text)
88
  del spnr
89
 
90
  if res is not None:
scripts/predict.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script file used for performing inference with an existing model.
3
+ """
4
+
5
+ from pathlib import Path
6
+ import torch
7
+ import json
8
+
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ AutoModelForSequenceClassification
12
+ )
13
+
14
+
15
+ ## Class used to encapsulate and handle the logic for inference
16
+ class InferenceHandler:
17
+ def __init__(self, bin_model_path: Path, ml_regr_model_path: Path):
18
+ self.bin_tokenizer, self.bin_model = self.init_model_and_tokenizer(bin_model_path)
19
+ self.ml_regr_tokenizer, self.ml_regr_model = self.init_model_and_tokenizer(ml_regr_model_path)
20
+
21
+ ## Initializes a model and tokenizer for use in inference using the models path
22
+ def init_model_and_tokenizer(self, model_path: Path):
23
+ with open(model_path / 'config.json') as config_file:
24
+ config_json = json.load(config_file)
25
+ model_name = config_json['_name_or_path']
26
+ model_type = config_json['model_type']
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ model = AutoModelForSequenceClassification.from_pretrained(model_path, model_type=model_type)
30
+ model.eval()
31
+
32
+ return tokenizer, model
33
+
34
+ ## Handles logic used to encode the text for use in binary classification
35
+ def encode_binary(self, text):
36
+ bin_tokenized_input = self.bin_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
37
+ return bin_tokenized_input
38
+
39
+ ## Handles logic used to encode the text for use in multilabel regression
40
+ def encode_multilabel(self, text):
41
+ ml_tokenized_input = self.ml_regr_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
42
+ return ml_tokenized_input
43
+
44
+ ## Handles text encoding for both binary classification and multilabel regression
45
+ def encode_input(self, text):
46
+ bin_inputs = self.encode_binary(text)
47
+ ml_inputs = self.encode_multilabel(text)
48
+ return bin_inputs, ml_inputs
49
+
50
+ ## Handles performing the full sentiment analysis (binary classification and multilabel regression)
51
+ def classify_text(self, text):
52
+ res_obj = {
53
+ 'raw_text': text,
54
+ 'text_sentiment': None,
55
+ 'numerical_sentiment': None,
56
+ 'category_sentiments': {
57
+ 'Gender': None,
58
+ 'Race': None,
59
+ 'Sexuality': None,
60
+ 'Disability': None,
61
+ 'Religion': None,
62
+ 'Unspecified': None
63
+ }
64
+ }
65
+
66
+ text_prediction, pred_class = self.discriminatory_inference(text)
67
+ res_obj['text_sentiment'] = text_prediction
68
+ res_obj['numerical_sentiment'] = pred_class
69
+
70
+ if pred_class == 1:
71
+ ml_infer_results = self.category_inference(text)
72
+
73
+ for idx, key in enumerate(res_obj['category_sentiments'].keys()):
74
+ res_obj['category_sentiments'][key] = ml_infer_results[idx]
75
+
76
+ return res_obj
77
+
78
+ ## Handles logic for checking the binary classfication of the text
79
+ def discriminatory_inference(self, text):
80
+ bin_inputs = self.encode_binary(text)
81
+
82
+ with torch.no_grad():
83
+ bin_logits = self.bin_model(**bin_inputs).logits
84
+
85
+ probs = torch.nn.functional.softmax(bin_logits, dim=-1)
86
+ pred_class = torch.argmax(probs).item()
87
+ bin_label_map = {0: "Non-Discriminatory", 1: "Discriminatory"}
88
+ bin_text_pred = bin_label_map[pred_class]
89
+
90
+ return bin_text_pred, pred_class
91
+
92
+ ## Handles logic for assessing the categories of discrimination
93
+ def category_inference(self, text):
94
+ ml_inputs = self.encode_multilabel(text)
95
+
96
+ with torch.no_grad():
97
+ ml_outputs = self.ml_regr_model(**ml_inputs).logits
98
+
99
+ ml_op_list = ml_outputs.squeeze().tolist()
100
+
101
+ results = []
102
+ for item in ml_op_list:
103
+ results.append(max(0.0, item))
104
+
105
+ return results