import streamlit as st
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
import string
from time import time
from PIL import Image


# --- Page header: logo, title, and the user's input area ---
logo = Image.open('./Image_AraThon.PNG')
st.image(logo.resize((150, 150)))

st.title("المساعدة اللغوية في التنبؤ بالمتلازمات والمتصاحبات وتصحيحها")
default_value = "أستاذ التعليم"

# sent is the variable holding the user's input
sent = st.text_area('المدخل', default_value)

# Tokenizer from the base AraBART model; the fill-mask weights come from a
# fine-tuned checkpoint of the same architecture.
tokenizer = AutoTokenizer.from_pretrained(
    "moussaKam/AraBART",
    max_length=128,
    padding=True,
    pad_to_max_length=True,
    truncation=True,
)
model = AutoModelForMaskedLM.from_pretrained("Hamda/test-1-finetuned-AraBART")
pipe = pipeline("fill-mask", tokenizer=tokenizer, model=model, top_k=10)

def next_word(text, pipe):
    """Run the fill-mask pipeline on *text* and collect word suggestions.

    Parameters
    ----------
    text : str
        Input sentence, expected to end with the model's mask token.
    pipe : callable
        A ``fill-mask`` pipeline; called as ``pipe(text)`` and expected to
        yield dicts carrying ``token_str`` and ``score`` keys.

    Returns
    -------
    dict
        Two parallel lists — suggested words and their scores — with any
        token containing ASCII punctuation filtered out.
    """
    # Build the punctuation set once; the original rebuilt
    # list(string.punctuation) for every character of every token.
    punctuation = set(string.punctuation)
    res_dict = {
      'الكلمة المقترحة':[],
      'العلامة':[],
    }
    for suggestion in pipe(text):
        # keep only tokens that share no character with the punctuation set
        if punctuation.isdisjoint(suggestion['token_str']):
            res_dict['الكلمة المقترحة'].append(suggestion['token_str'])
            res_dict['العلامة'].append(suggestion['score'])
    return res_dict

# On click: append the mask token, query the language model, show suggestions.
if st.button('بحث', disabled=False):
    masked_input = sent + ' <mask>'
    suggestions = next_word(masked_input, pipe)
    st.dataframe(pd.DataFrame.from_dict(suggestions))
# --- Knowledge-graph based suggestions ---
# Optional second strategy: rank candidate words with a probabilistic
# knowledge graph (BM25 edge scores) instead of the language model alone.
if (st.checkbox('الاستعانة بالرسم البياني المعرفي الاحتمالي', value=False)):
    a = time()  # start timestamp; elapsed time is displayed at the end
    VocMap = './voc.csv'  # word <-> id vocabulary table (tab-separated)
    ScoreMap = './BM25.csv'  # graph edges (ID1, ID2, score) (tab-separated)
    
    #@st.cache
    def reading_df(path1, path2):
        """Load the vocabulary and graph-score tables.

        Parameters
        ----------
        path1 : str
            Path to the tab-separated vocabulary file (``word`` column).
        path2 : str
            Path to the tab-separated score file (``ID1``, ``ID2``, ``score``).

        Returns
        -------
        tuple
            ``(df_voc, df_graph, df_gr)`` where ``df_graph`` is indexed by
            ``(ID1, ID2)`` and ``df_gr`` by ``ID1`` alone.
        """
        df_voc = pd.read_csv(path1, delimiter='\t')
        # Read the score file once; the original read the same file twice and
        # used the global ScoreMap instead of the *path2* parameter.
        df_scores = pd.read_csv(path2, delimiter='\t')
        df_graph = df_scores.set_index(['ID1', 'ID2'])
        df_gr = df_scores.set_index(['ID1'])
        return df_voc, df_graph, df_gr
        
    # df3: vocabulary; df_g: (ID1, ID2)-indexed edge scores; df_in: ID1-indexed edges
    df3, df_g, df_in = reading_df(VocMap, ScoreMap)
    

    def Query2id(voc, query):
        """Map each whitespace-separated word of *query* to its row index in *voc*.

        Words that are missing from the vocabulary are reported through the UI
        and skipped.
        """
        ids = []
        for token in query.split():
            try:
                ids.append(voc.index[voc['word'] == token].values[0])
            except (IndexError, KeyError):
                st.write('Token not found')
        return ids
    
    # ids of the input words found in the vocabulary
    id_list = Query2id(df3, sent)

    def setQueriesVoc(df, id_list):
        res = []
        for e in id_list:
            try:
                res.extend(list(df.loc[e]['ID2'].values)) 
            except (KeyError, AttributeError) as f:
                st.write('Token not found')
                continue
        return list(set(res))
    
    # candidate ids: every graph neighbour of the query words
    L = setQueriesVoc(df_in, id_list)
    @st.cache
    def compute_score(L_terms, id_l):
        """Score every candidate term against the query ids.

        For each candidate id in *L_terms*, sums the scores of the graph edges
        ``(query_id, candidate_id)`` found in the global ``df_g`` table, then
        maps the candidate id back to its word via the global ``df3``
        vocabulary.

        Returns
        -------
        dict
            ``{word: accumulated_score}``.
        """
        tmt = {}
        for nc in L_terms:
            score = 0.0
            for ni in id_l:
                try:
                    score += df_g.loc[(ni, nc), 'score']
                except KeyError:
                    # no edge between this query id and the candidate
                    continue
            key = df3.loc[nc].values[0]
            tmt[key] = score
        return tmt
    tmt = compute_score(L, id_list)
    # Normalise scores to [0, 1]. min/max are hoisted out of the loop (the
    # original recomputed both on every iteration, O(n^2)) and the span is
    # guarded against zero to avoid ZeroDivisionError when every candidate
    # has the same score (or there is a single candidate).
    scores = list(tmt.values())
    lo = min(scores) if scores else 0.0
    hi = max(scores) if scores else 0.0
    span = (hi - lo) or 1.0
    dict_res = {'الكلمة المقترحة':[], 
    'العلامة':[]}
    # keep only the 10 best-scoring candidates
    for key, value in sorted(tmt.items(), key=lambda x: x[1], reverse=True)[:10]:
        new_score = ((value - lo) / span) - 0.0001
        dict_res['العلامة'].append(str(new_score)[:6])
        dict_res['الكلمة المقترحة'].append(key)
    res_df = pd.DataFrame.from_dict(dict_res)
    res_df.index += 1  # show 1-based ranks
    b = time()
    exec_time = (b - a)
    # Language-model suggestions for the same input, shown side by side with
    # the graph-based ranking.
    text_st = sent + ' <mask>'
    dict_next_words = next_word(text_st, pipe)
    df = pd.DataFrame.from_dict(dict_next_words)
    df.index += 1
    str_time = str(exec_time)[:3]

    st.markdown("""---""")
    st.header("الكلمات المقترحة باستعمال النموذج اللغوي")
    st.dataframe(df)
    st.markdown("""---""")
    st.header("الكلمات المقترحة باستعمال الرسم البياني")
    st.dataframe(res_df)
    st.markdown("""---""")
    st.write(f'{str_time} s :الوقت المستغرق باستعمال الرسم البياني')