# AraJARIR / app.py
# Author: Hamda
# Commit: 05a2e1d ("Update app.py")
# Size: 2.38 kB
import streamlit as st
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
import string
# Page title: "Linguistic assistance in predicting and correcting collocations".
st.title("المساعدة اللغوية في التنبؤ بالمتلازمات والمتصاحبات وتصحيحها")
# Pre-filled example sentence ("The weapons were sold in the market").
default_value = "بيعت الأسلحة في السوق"
# sent is the variable holding the user's input.
# Streamlit documents 68 px as the minimum text_area height; recent releases
# raise StreamlitAPIException for smaller values such as the former height=20.
sent = st.text_area("مدخل", default_value, height=68)
# NOTE(review): the checkbox ("use the chart") return value is discarded, so
# toggling it currently has no effect on the page.
st.checkbox('استعمال الرسم البياني', value=False)
# Accumulates {candidate word: aggregated BM25 score}; filled by the scoring loop.
tmt = {}
VocMap = r'/home/voc.csv'
ibra_gr = r'./BM25.csv'
# Vocabulary table: Query2id searches its 'word' column and uses row labels as ids.
df3 = pd.read_csv(VocMap, delimiter='\t')
# The BM25 pair table is read from disk once and indexed two different ways.
_bm25 = pd.read_csv(ibra_gr, delimiter='\t')
# Indexed by the (ID1, ID2) pair, for per-pair 'score' lookups.
df_g = _bm25.set_index(['ID1', 'ID2'])
# Indexed by ID1 only (ID2 stays a column), to list the partners of a query id.
df_in = _bm25.set_index(['ID1'])
def Query2id(voc, query):
    """Map each whitespace-separated word of *query* to its row label in *voc*.

    Args:
        voc: DataFrame with a ``'word'`` column; its index labels are the ids.
        query: Free text; it is split on whitespace.

    Returns:
        List of ids, in query order. When a word occurs on several rows, the
        first matching row's label is used. Words absent from ``voc['word']``
        are skipped rather than raising ``IndexError``, so user input with an
        out-of-vocabulary word no longer crashes the app.
    """
    ids = []
    for word in query.split():
        matches = voc.index[voc['word'] == word]
        if len(matches):  # unknown word -> no match -> skip it
            ids.append(matches.values[0])
    return ids
# Vocabulary ids (df3 row labels) of the words in the user's sentence.
id_list = Query2id(voc=df3, query=sent)
def setQueriesVoc(df, id_list):
    """Collect the distinct ``ID2`` values paired with any id in *id_list*.

    Args:
        df: DataFrame indexed by ``ID1`` (possibly non-unique) with an ``'ID2'``
            column.
        id_list: ``ID1`` labels to look up.

    Returns:
        List of unique ``ID2`` values (unordered). Ids missing from the index
        are skipped.
    """
    found = set()
    for e in id_list:
        if e not in df.index:
            continue
        # Selecting with a one-element list always yields a Series, even when
        # the label occurs on a single row. Scalar selection would return a
        # Series there, whose ['ID2'] is a scalar without .values.
        found.update(df.loc[[e], 'ID2'].tolist())
    return list(found)
# Candidate ids: every ID2 paired with at least one query id in the BM25 table.
L = setQueriesVoc(df=df_in, id_list=id_list)
# Tokenizer from the "moussaKam/AraBART" checkpoint and masked-LM weights from
# "Hamda/test-1-finetuned-AraBART", both loaded via from_pretrained.
# NOTE(review): max_length/padding/pad_to_max_length/truncation are encoding
# options; passed to from_pretrained they appear not to change later
# tokenization calls. Confirm before relying on them.
tokenizer = AutoTokenizer.from_pretrained(
    "moussaKam/AraBART",
    max_length=128,
    padding=True,
    pad_to_max_length=True,
    truncation=True,
)
model = AutoModelForMaskedLM.from_pretrained(
    "Hamda/test-1-finetuned-AraBART"
)
def next_word(text, pipe, *, punctuation=string.punctuation):
    """Run a fill-mask *pipe* on *text*, keeping predictions without punctuation.

    Args:
        text: Input string containing the mask token to fill.
        pipe: Callable (e.g. a transformers ``fill-mask`` pipeline) returning a
            list of dicts with at least ``'token_str'`` and ``'score'`` keys.
        punctuation: Characters that disqualify a predicted token. Defaults to
            ASCII ``string.punctuation``. Pass a larger string (e.g. adding the
            Arabic marks '،' '؛' '؟') to filter those as well.

    Returns:
        Dict with parallel lists ``'Word'`` and ``'Score'``, in the order the
        pipe returned its predictions.
    """
    banned = set(punctuation)  # built once instead of once per character
    res_dict = {
        'Word': [],
        'Score': [],
    }
    for e in pipe(text):
        token = e['token_str']
        if banned.isdisjoint(token):
            res_dict['Word'].append(token)
            res_dict['Score'].append(e['score'])
    return res_dict
# Fill-mask pipeline over the tokenizer/model above, returning the 10 best tokens.
pipe = pipeline("fill-mask", tokenizer=tokenizer, model=model, top_k=10)
# Append a mask after the user's sentence so the model proposes the next word.
text_st = f"{sent} <mask>"
dict_next_words = next_word(text_st, pipe)
df = pd.DataFrame(dict_next_words).reset_index(drop=True)
# For each candidate id, sum its BM25 scores against every query id; pairs
# missing from the (ID1, ID2) table contribute nothing.
for nc in L:
    score = 0.0
    for ni in id_list:
        try:
            score += df_g.loc[(ni, nc), 'score']
        except KeyError:
            continue  # (ni, nc) is not a scored pair
    # NOTE(review): uses the first column of the vocabulary row as the word;
    # this presumably is the 'word' column that Query2id searches. Confirm.
    key = df3.loc[nc].values[0]
    tmt[key] = score
# Rank candidates by aggregated score, highest first, and format at most the
# top 10 as "word | score" strings.
tmexp = sorted(tmt.items(), key=lambda pair: pair[1], reverse=True)
exp_terms = [f"{word!s} | {total!s}" for word, total in tmexp[:10]]
# Show the model's next-word candidates, then the BM25-ranked terms.
st.dataframe(data=df)
st.write(exp_terms)