Spaces:
Runtime error
Runtime error
Commit
·
6554f2c
1
Parent(s):
b6bfb5a
Initial commit
Browse files- app.py +21 -0
- apps/__pycache__/about.cpython-38.pyc +0 -0
- apps/__pycache__/credits.cpython-38.pyc +0 -0
- apps/__pycache__/inference.cpython-38.pyc +0 -0
- apps/about.py +51 -0
- apps/credits.py +51 -0
- apps/inference.py +52 -0
- mlm_custom/mlm_full_text.csv +19 -0
- mlm_custom/mlm_targeted_text.csv +18 -0
- mlm_custom/mlm_test_config.csv +6 -0
- mlm_custom/test_mlm.py +142 -0
- multiapp.py +22 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pandas.io.formats.format import return_docstring
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import AutoTokenizer,AutoModelForMaskedLM
|
5 |
+
from transformers import pipeline
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
from multiapp import MultiApp
|
9 |
+
from apps import about,credits,inference
|
10 |
+
|
11 |
+
|
12 |
+
def main():
|
13 |
+
app = MultiApp()
|
14 |
+
app.add_app("Inference", inference.app)
|
15 |
+
app.add_app("About", about.app)
|
16 |
+
app.add_app("Credits", credits.app)
|
17 |
+
app.run()
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
main()
|
apps/__pycache__/about.cpython-38.pyc
ADDED
Binary file (2.95 kB). View file
|
|
apps/__pycache__/credits.cpython-38.pyc
ADDED
Binary file (1.96 kB). View file
|
|
apps/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (1.68 kB). View file
|
|
apps/about.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Author: prateek
|
3 |
+
# @Date: 2021-03-02 02:23:36
|
4 |
+
# @Last Modified by: prateek
|
5 |
+
# @Last Modified time: 2021-03-02 23:04:21
|
6 |
+
|
7 |
+
import streamlit as st
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
from sklearn import datasets
|
11 |
+
from PIL import Image
|
12 |
+
def app():
|
13 |
+
st.title('About')
|
14 |
+
st.write("""
|
15 |
+
## What is diabetes
|
16 |
+
|
17 |
+
According to the NIH, "Diabetes is a disease that occurs when your **blood glucose**,
|
18 |
+
also called blood sugar, is **too high**. Blood **glucose** is your main source of
|
19 |
+
energy and **comes from the food you eat**. **Insulin**, a hormone made from the pancreas,
|
20 |
+
**helps glucose** from food get into your cells to be used for energy. Sometimes your
|
21 |
+
body doesn’t make enough or any insulin or doesn’t use insulin well. Glucose then stays
|
22 |
+
in your blood and doesn’t reach your cells.
|
23 |
+
Over time, **having too much glucose in your blood** can cause health problems. """)
|
24 |
+
|
25 |
+
st.write(
|
26 |
+
"""
|
27 |
+
### Health impact
|
28 |
+
Over time, diabetes can damage the heart, blood vessels, eyes, kidneys, and nerves.
|
29 |
+
|
30 |
+
* Adults with diabetes have a two- to three-fold increased risk of heart attacks and strokes(1).
|
31 |
+
* Combined with reduced blood flow, neuropathy (nerve damage) in the feet increases the chance of foot ulcers, infection and eventual need for limb amputation.
|
32 |
+
* Diabetic retinopathy is an important cause of blindness, and occurs as a result of long-term accumulated damage to the small blood vessels in the retina. Diabetes is the cause of 2.6% of global blindness(2).
|
33 |
+
* Diabetes is among the leading causes of kidney failure(3).
|
34 |
+
""")
|
35 |
+
|
36 |
+
st.write(
|
37 |
+
"""
|
38 |
+
### Key facts
|
39 |
+
* The number of people with diabetes rose from 108 million in 1980 to 422 million in 2014.
|
40 |
+
* The global prevalence of diabetes* among adults over 18 years of age rose from 4.7% in 1980 to 8.5% in 2014 (1).
|
41 |
+
* Between 2000 and 2016, there was a 5% increase in premature mortality from diabetes.
|
42 |
+
* Diabetes prevalence has been rising more rapidly in low- and middle-income countries than in high-income countries.
|
43 |
+
* Diabetes is a major cause of blindness, kidney failure, heart attacks, stroke and lower limb amputation.
|
44 |
+
* In 2016, an estimated 1.6 million deaths were directly caused by diabetes. Another 2.2 million deaths were attributable to high blood glucose in 2012.
|
45 |
+
* Almost half of all deaths attributable to high blood glucose occur before the age of 70 years. WHO estimates that diabetes was the seventh leading cause of death in 2016.
|
46 |
+
* A healthy diet, regular physical activity, maintaining a normal body weight and avoiding tobacco use are ways to prevent or delay the onset of type 2 diabetes.
|
47 |
+
* Diabetes can be treated and its consequences avoided or delayed with diet, physical activity, medication and regular screening and treatment for complications.
|
48 |
+
""")
|
49 |
+
|
50 |
+
|
51 |
+
|
apps/credits.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Author: prateek
|
3 |
+
# @Date: 2021-03-02 22:37:41
|
4 |
+
# @Last Modified by: prateek
|
5 |
+
# @Last Modified time: 2021-03-02 23:38:33
|
6 |
+
|
7 |
+
import streamlit as st
|
8 |
+
|
9 |
+
def app():
|
10 |
+
st.title(' Credits')
|
11 |
+
|
12 |
+
st.write("""The following web application is built and maintained by **Prateek Agrawal** for the sole purpose of learning and displaying the power and usage of machine learning in the field of healthcare. He believes that the Artificial Intelligence and Machine Learning can truly help in making the world a better place to live in.""")
|
13 |
+
|
14 |
+
st.write("""
|
15 |
+
|
16 |
+
## Data
|
17 |
+
|
18 |
+
The datasets consist of several medical predictor (independent) variables and one target (dependent)
|
19 |
+
variable, Outcome. Independent variables include the number of pregnancies the patient has had,
|
20 |
+
their BMI, insulin level, age, and so on.
|
21 |
+
[link of data in kaggle](https://www.kaggle.com/uciml/pima-indians-diabetes-database)""")
|
22 |
+
st.write("""
|
23 |
+
|
24 |
+
## Columns
|
25 |
+
|
26 |
+
|Columns|Description|
|
27 |
+
|-------|------------|
|
28 |
+
|Pregnancies|Number of times pregnant|
|
29 |
+
|Glucose|Plasma glucose concentration for 2 hours in an oral glucose tolerance test|
|
30 |
+
|BloodPressure|Diastolic blood pressure (mm Hg)|
|
31 |
+
|SkinThickness|Triceps skin fold thickness (mm)|
|
32 |
+
|Insulin|2-Hour serum insulin (mu U/ml)|
|
33 |
+
|BMI|Body mass index (weight in kg/(height in m)^2)|
|
34 |
+
|DiabetesPedigreeFunction|Diabetes pedigree function|
|
35 |
+
|Age|Age (years)|
|
36 |
+
|Outcome|Class variable (0 or 1) 268 of 768 are 1, the others are 0|
|
37 |
+
|
38 |
+
|
39 |
+
## Information
|
40 |
+
|
41 |
+
### WHO Website
|
42 |
+
* https://www.who.int/health-topics/diabetes#tab=tab_1
|
43 |
+
* https://www.who.int/news-room/fact-sheets/detail/diabetes
|
44 |
+
|
45 |
+
### Machine Learning related info
|
46 |
+
|
47 |
+
* https://www.kaggle.com/uciml/pima-indians-diabetes-database/code
|
48 |
+
* https://towardsdatascience.com/streamlit-101-an-in-depth-introduction-fc8aad9492f2
|
49 |
+
|
50 |
+
|
51 |
+
""")
|
apps/inference.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pandas.io.formats.format import return_docstring
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import AutoTokenizer,AutoModelForMaskedLM
|
5 |
+
from transformers import pipeline
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
|
9 |
+
@st.cache(show_spinner=False,persist=True)
|
10 |
+
def load_model(masked_text,model_name):
|
11 |
+
|
12 |
+
model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
14 |
+
tokenizer.save_pretrained('exported_pytorch_model')
|
15 |
+
model.save_pretrained('exported_pytorch_model')
|
16 |
+
nlp = pipeline('fill-mask', model="exported_pytorch_model")
|
17 |
+
|
18 |
+
result_sentence = nlp(masked_text)
|
19 |
+
|
20 |
+
return result_sentence
|
21 |
+
|
22 |
+
def app():
|
23 |
+
st.markdown("<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>", unsafe_allow_html=True)
|
24 |
+
st.markdown(
|
25 |
+
"This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
|
26 |
+
)
|
27 |
+
|
28 |
+
target_text_path = './mlm_custom/mlm_targeted_text.csv'
|
29 |
+
target_text_df = pd.read_csv(target_text_path)
|
30 |
+
|
31 |
+
texts = target_text_df['text']
|
32 |
+
|
33 |
+
st.markdown("""## Select any of the following text : """)
|
34 |
+
masked_text = st.selectbox('',
|
35 |
+
texts)
|
36 |
+
|
37 |
+
st.write('You selected:', masked_text)
|
38 |
+
|
39 |
+
models = st.multiselect(
|
40 |
+
"Choose models",
|
41 |
+
['flax-community/roberta-hindi','mrm8488/HindiBERTa','ai4bharat/indic-bert',\
|
42 |
+
'neuralspace-reverie/indic-transformers-hi-bert',
|
43 |
+
'surajp/RoBERTa-hindi-guj-san'],
|
44 |
+
["flax-community/roberta-hindi"]
|
45 |
+
)
|
46 |
+
|
47 |
+
selected_model = models[0]
|
48 |
+
|
49 |
+
if st.button('Fill the Mask!'):
|
50 |
+
with st.spinner("Filling the Mask..."):
|
51 |
+
filled_sentence = load_model(masked_text,selected_model)
|
52 |
+
st.write(filled_sentence)
|
mlm_custom/mlm_full_text.csv
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
user_id,text
|
2 |
+
dk-crazydiv,हम आपके सुखद यात्रा की कामना करते हैं
|
3 |
+
dk-crazydiv,मुझे उनसे बात करना बहुत अच्छा लगा
|
4 |
+
dk-crazydiv,"बार बार देखो, हज़ार बार देखो, ये देखने की चीज़ है"
|
5 |
+
dk-crazydiv,ट्रंप कल अहमदाबाद में प्रधानमंत्री मोदी से मुलाकात करने जा रहे हैं
|
6 |
+
dk-crazydiv,बॉम्बे से बैंगलोर की दूरी 500 किलोमीटर है
|
7 |
+
amankhandelia,मधु घट फूटा ही करते हैं लघु जीवन लेकर आए हैं प्याले टूटा ही करते हैं
|
8 |
+
amankhandelia,वर्त्तमान के मोहजाल में आने वाला कल न भुलाएं
|
9 |
+
amankhandelia,भारत में हुए अन्यायों के गवाह है मुंशी प्रेमचंद के उपन्यास
|
10 |
+
amankhandelia,"एक लेखक अपनी कलम तभी उठाता है, जब उसकी संवेदनाओं पर किसी ने चोट की हो"
|
11 |
+
amankhandelia,"मेरा कुछ कहना तब उचित है, जब मुझे सुनना तुम्हारी प्राथमिकता हो, औपचारिकता नहीं"
|
12 |
+
amankhandelia,मरना लगा रहेगा यहाँ जी तो लीजिए
|
13 |
+
amankhandelia,बहुत कम लोग जानते हैं कि वो बहुत कम जानते हैं
|
14 |
+
hassiahk,"जल्दी सोना और जल्दी उठना इंसान को स्वस्थ ,समृद्ध और बुद्धिमान बनाता है"
|
15 |
+
hassiahk,बात ये है कि आप इसे पहले से ही जानते हैं
|
16 |
+
hassiahk,"रोज एक सेब खाओ, डॉक्टर से दूर रहो"
|
17 |
+
hassiahk,किसी पुस्तक को उसके आवरण से मत आंकिए
|
18 |
+
hassiahk,जहा चाह वहा राह
|
19 |
+
hassiahk,सभी अच्छी चीजों का एक अंत होता है
|
mlm_custom/mlm_targeted_text.csv
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
user_id,text,output,multi
|
2 |
+
dk-crazydiv,हम आपके <mask> यात्रा की कामना करते हैं,सुखद,
|
3 |
+
dk-crazydiv,मुझे उनसे बात करना बहुत <mask> लगा,अच्छा,
|
4 |
+
dk-crazydiv,"बार बार देखो, हज़ार बार देखो, ये देखने की <mask> है","[""चीज़"",""बात""]",TRUE
|
5 |
+
dk-crazydiv,ट्रंप कल अहमदाबाद में प्रधानमंत्री मोदी से <mask> करने जा रहे हैं,"[""मुलाकात"",""मिल्ने""]",TRUE
|
6 |
+
dk-crazydiv,बॉम्बे से बैंगलोर की <mask> 500 किलोमीटर है,दूरी,
|
7 |
+
dk-crazydiv,कहने को साथ अपने ये <mask> चलती है,दुनिया,
|
8 |
+
dk-crazydiv,"ये इश्क़ नहीं आसान बस इतना समझ लीजिये, एक आग का दरिया है और <mask> के जाना है",डूब,
|
9 |
+
prateekagrawal,आपका दिन <mask> हो,"[""शुभ"",""अच्छा""]",TRUE
|
10 |
+
prateekagrawal,हिंदी भारत में <mask> जाने वाली भाषाओं में से एक है,"[""बोली"",""सिखाई"",""आधिकारिक""]",TRUE
|
11 |
+
prateekagrawal,शुभ <mask>,"[""प्रभात"",रात्रि"",""यात्रा"",""अवसर""]",TRUE
|
12 |
+
prateekagrawal,इंसान को कभी बुरा नहीं <mask> चाहिए,"[""बोलना"",""देखना"",""सुनाना"",""करना""]",TRUE
|
13 |
+
hassiahk,बात ये है कि आप इसे <mask> से ही जानते हैं,पहले,
|
14 |
+
hassiahk,<mask> पूर्व में उगता है,सूरज,
|
15 |
+
hassiahk,"जल्दी सोना और जल्दी उठना इंसान को स्वस्थ ,समृद्ध और बुद्धिमान <mask> है",बनाता,
|
16 |
+
hassiahk,"रोज एक सेब खाओ, <mask> से दूर रहो",डॉक्टर,
|
17 |
+
hassiahk,किसी पुस्तक को उसके <mask> से मत आंकिए,आवरण,
|
18 |
+
hassiahk,सभी <mask> चीजों का एक अंत होता है,"[""अच्छी"", ""बुरी""]",TRUE
|
mlm_custom/mlm_test_config.csv
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name,display_name,revision,from_flax,use_fast,add_prefix_space
|
2 |
+
flax-community/roberta-hindi,,,TRUE,TRUE,TRUE
|
3 |
+
mrm8488/HindiBERTa,,,FALSE,TRUE,TRUE
|
4 |
+
ai4bharat/indic-bert,,,FALSE,FALSE,FALSE
|
5 |
+
neuralspace-reverie/indic-transformers-hi-bert,,,FALSE,TRUE,TRUE
|
6 |
+
surajp/RoBERTa-hindi-guj-san,,,FALSE,TRUE,TRUE
|
mlm_custom/test_mlm.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from transformers import AutoTokenizer, RobertaModel, AutoModel, AutoModelForMaskedLM
|
4 |
+
from transformers import pipeline
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
|
8 |
+
|
9 |
+
class MLMTest():
|
10 |
+
|
11 |
+
def __init__(self, config_file="mlm_test_config.csv", full_text_file="mlm_full_text.csv", targeted_text_file="mlm_targeted_text.csv"):
|
12 |
+
|
13 |
+
self.config_df = pd.read_csv(os.path.join(os.path.dirname(os.path.realpath(__file__)), config_file))
|
14 |
+
self.config_df.fillna("", inplace=True)
|
15 |
+
self.full_text_df = pd.read_csv(os.path.join(os.path.dirname(os.path.realpath(__file__)), full_text_file))
|
16 |
+
self.targeted_text_df = pd.read_csv(os.path.join(os.path.dirname(os.path.realpath(__file__)), targeted_text_file))
|
17 |
+
self.full_text_results = []
|
18 |
+
self.targeted_text_results = []
|
19 |
+
|
20 |
+
def _run_full_test_row(self, text, print_debug=False):
|
21 |
+
return_data = []
|
22 |
+
data = text.split()
|
23 |
+
for i in range(0, len(data)):
|
24 |
+
masked_text = " ".join(data[:i]) + " "+self.nlp.tokenizer.mask_token+" " + " ".join(data[i+1:])
|
25 |
+
expected_result = data[i]
|
26 |
+
result = self.nlp(masked_text)
|
27 |
+
self.full_text_results.append({"text": masked_text, "result": result[0]["token_str"], "true_output": expected_result})
|
28 |
+
if print_debug:
|
29 |
+
print(masked_text)
|
30 |
+
print([x["token_str"] for x in result])
|
31 |
+
print("-"*20)
|
32 |
+
return_data.append({"prediction": result[0]["token_str"], "true_output": expected_result})
|
33 |
+
return return_data
|
34 |
+
|
35 |
+
def _run_targeted_test_row(self, text, expected_result, print_debug=False):
|
36 |
+
return_data = []
|
37 |
+
result = self.nlp(text.replace("<mask>", self.nlp.tokenizer.mask_token))
|
38 |
+
self.targeted_text_results.append({"text": text, "result": result[0]["token_str"], "true_output": expected_result})
|
39 |
+
if print_debug:
|
40 |
+
print(text)
|
41 |
+
print([x["token_str"] for x in result])
|
42 |
+
print("-"*20)
|
43 |
+
return_data.append({"prediction": result[0]["token_str"], "true_output": expected_result})
|
44 |
+
return return_data
|
45 |
+
|
46 |
+
def _compute_acc(self, results):
|
47 |
+
ctr = 0
|
48 |
+
for row in results:
|
49 |
+
try:
|
50 |
+
z = json.loads(row["true_output"])
|
51 |
+
if isinstance(z, list):
|
52 |
+
if row["prediction"] in z:
|
53 |
+
ctr+=1
|
54 |
+
except:
|
55 |
+
if row["prediction"] == row["true_output"]:
|
56 |
+
ctr+=1
|
57 |
+
|
58 |
+
return float(ctr/len(results))
|
59 |
+
|
60 |
+
def run_full_test(self, exclude_user_ids=[], print_debug=False):
|
61 |
+
df = pd.DataFrame()
|
62 |
+
for idx, row in self.config_df.iterrows():
|
63 |
+
self.full_text_results = []
|
64 |
+
|
65 |
+
model_name = row["model_name"]
|
66 |
+
display_name = row["display_name"] if row["display_name"] else row["model_name"]
|
67 |
+
revision = row["revision"] if row["revision"] else "main"
|
68 |
+
from_flax = row["from_flax"]
|
69 |
+
if from_flax:
|
70 |
+
model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True, revision=revision)
|
71 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
72 |
+
tokenizer.save_pretrained('exported_pytorch_model')
|
73 |
+
model.save_pretrained('exported_pytorch_model')
|
74 |
+
self.nlp = pipeline('fill-mask', model="exported_pytorch_model")
|
75 |
+
else:
|
76 |
+
self.nlp = pipeline('fill-mask', model=model_name)
|
77 |
+
accs = []
|
78 |
+
try:
|
79 |
+
for idx, row in self.full_text_df.iterrows():
|
80 |
+
if row["user_id"] in exclude_user_ids:
|
81 |
+
continue
|
82 |
+
results = self._run_full_test_row(row["text"], print_debug=print_debug)
|
83 |
+
|
84 |
+
acc = self._compute_acc(results)
|
85 |
+
accs.append(acc)
|
86 |
+
except:
|
87 |
+
print("Error for", display_name)
|
88 |
+
continue
|
89 |
+
|
90 |
+
print(display_name, " Average acc:", sum(accs)/len(accs))
|
91 |
+
if df.empty:
|
92 |
+
df = pd.DataFrame(self.full_text_results)
|
93 |
+
df.rename(columns={"result": display_name}, inplace=True)
|
94 |
+
else:
|
95 |
+
preds = [x["result"] for x in self.full_text_results]
|
96 |
+
df[display_name] = preds
|
97 |
+
df.to_csv("full_text_results.csv", index=False)
|
98 |
+
print("Results saved to full_text_results.csv")
|
99 |
+
|
100 |
+
def run_targeted_test(self, exclude_user_ids=[], print_debug=False):
|
101 |
+
|
102 |
+
df = pd.DataFrame()
|
103 |
+
for idx, row in self.config_df.iterrows():
|
104 |
+
self.targeted_text_results = []
|
105 |
+
|
106 |
+
model_name = row["model_name"]
|
107 |
+
display_name = row["display_name"] if row["display_name"] else row["model_name"]
|
108 |
+
revision = row["revision"] if row["revision"] else "main"
|
109 |
+
from_flax = row["from_flax"]
|
110 |
+
if from_flax:
|
111 |
+
model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True, revision=revision)
|
112 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
113 |
+
tokenizer.save_pretrained('exported_pytorch_model')
|
114 |
+
model.save_pretrained('exported_pytorch_model')
|
115 |
+
self.nlp = pipeline('fill-mask', model="exported_pytorch_model")
|
116 |
+
else:
|
117 |
+
self.nlp = pipeline('fill-mask', model=model_name)
|
118 |
+
accs = []
|
119 |
+
try:
|
120 |
+
for idx, row2 in self.targeted_text_df.iterrows():
|
121 |
+
if row2["user_id"] in exclude_user_ids:
|
122 |
+
continue
|
123 |
+
results = self._run_targeted_test_row(row2["text"], row2["output"], print_debug=print_debug)
|
124 |
+
|
125 |
+
acc = self._compute_acc(results)
|
126 |
+
accs.append(acc)
|
127 |
+
except:
|
128 |
+
import traceback
|
129 |
+
print(traceback.format_exc())
|
130 |
+
print("Error for", display_name)
|
131 |
+
continue
|
132 |
+
|
133 |
+
print(display_name, " Average acc:", sum(accs)/len(accs))
|
134 |
+
if df.empty:
|
135 |
+
df = pd.DataFrame(self.targeted_text_results)
|
136 |
+
df.rename(columns={"result": display_name}, inplace=True)
|
137 |
+
else:
|
138 |
+
preds = [x["result"] for x in self.targeted_text_results]
|
139 |
+
df[display_name] = preds
|
140 |
+
df.to_csv("targeted_text_results.csv", index=False)
|
141 |
+
print("Results saved to targeted_text_results.csv")
|
142 |
+
|
multiapp.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Frameworks for running multiple Streamlit applications as a single app.
|
2 |
+
"""
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
class MultiApp:
|
6 |
+
def __init__(self):
|
7 |
+
self.apps = []
|
8 |
+
|
9 |
+
def add_app(self, title, func):
|
10 |
+
self.apps.append({
|
11 |
+
"title": title,
|
12 |
+
"function": func
|
13 |
+
})
|
14 |
+
|
15 |
+
def run(self):
|
16 |
+
st.sidebar.header('Navigation')
|
17 |
+
app = st.sidebar.radio(
|
18 |
+
'',
|
19 |
+
self.apps,
|
20 |
+
format_func=lambda app: app['title'])
|
21 |
+
|
22 |
+
app['function']()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
jax
|