# -*- coding: utf-8 -*-
"""
Created on Fri May 26 14:07:22 2023

@author: vibin
"""

import streamlit as st
from pandasql import sqldf
import pandas as pd
import re
from typing import List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import re


@st.cache_resource()
def tapas_model():
    """Return the TAPAS table-question-answering pipeline.

    The pipeline wraps the ``google/tapas-base-finetuned-wtq`` checkpoint.
    ``st.cache_resource`` caches the returned object, so later Streamlit
    reruns reuse it instead of reloading the checkpoint.
    """
    qa_pipeline = pipeline(
        task="table-question-answering",
        model="google/tapas-base-finetuned-wtq",
    )
    return qa_pipeline

@st.cache_resource()
def prepare_input(question: str, table: List[str]):
    """Tokenize a question plus table column names into model input ids.

    The prompt fed to the tokenizer has the form
    ``"question: <question> table: <col1>,<col2>,..."``.

    It relies on the module-level name ``tokenizer``. The Text2SQL page binds
    that name from ``tokmod`` before calling ``inference``, which in turn
    calls this function.

    Args:
        question: Natural-language question typed by the user.
        table: Column names of the dataset, joined with commas in the prompt.

    Returns:
        The ``input_ids`` tensor produced with ``return_tensors="pt"``,
        truncated to at most 512 tokens.
    """
    table_prefix = "table:"
    question_prefix = "question:"
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    # Passing max_length without truncation makes transformers log a warning
    # and fall back to "longest_first" truncation; truncation=True requests
    # that same behaviour explicitly.
    input_ids = tokenizer(
        inputs, max_length=512, truncation=True, return_tensors="pt"
    ).input_ids
    return input_ids

@st.cache_resource()
def inference(question: str, table: List[str]) -> str:
    """Generate the model's SQL text for *question* over columns *table*.

    Uses the module-level ``model`` and ``tokenizer`` (bound by the Text2SQL
    page via ``tokmod``). It runs beam search with 10 beams and a 700-token
    length cap, and returns the first (best) sequence decoded without special
    tokens.

    NOTE(review): ``top_k`` only influences sampling-based generation; with
    greedy/beam decoding it presumably has no effect. Kept unchanged.
    """
    encoded = prepare_input(question=question, table=table).to(model.device)
    generated = model.generate(
        inputs=encoded,
        num_beams=10,
        top_k=10,
        max_length=700,
    )
    best_sequence = generated[0]
    return tokenizer.decode(token_ids=best_sequence, skip_special_tokens=True)

@st.cache_resource()
def tokmod(tok_md):
    """Load the tokenizer and seq2seq model for checkpoint *tok_md*.

    The pair is cached by ``st.cache_resource``, so reruns reuse the already
    loaded objects.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(tok_md)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(tok_md)
    return loaded_tokenizer, loaded_model


### Text2SQL post-processing helpers

# The generated query refers to the dataset by the placeholder name "table".
# Only the placeholder that follows FROM is renamed, so the word "table"
# inside a literal value is left untouched.
_TABLE_PLACEHOLDER = re.compile(r"\b(FROM\s+)table\b", re.IGNORECASE)

# A condition value runs from "=" (this also covers ">=", "<=", "!=") up to
# the next standalone AND/OR keyword or the end of the query.
_CONDITION_VALUE = re.compile(r"=\s*(\S.*?)(?=\s+(?:AND|OR)\s+|\s*$)")

# Aggregates may come out without parentheses ("SUM Version",
# "COUNT DISTINCT Type"). SQLite needs "SUM(Version)" and
# "COUNT(DISTINCT Type)".
_BARE_AGGREGATE = re.compile(r"\b(AVG|COUNT|MAX|MIN|SUM) (DISTINCT )?(\w+|\*)")


def _quote_value(value: str) -> str:
    """Return *value* as an SQL string literal.

    A value already wrapped in matching single or double quotes is returned
    unchanged. Otherwise, embedded single quotes are doubled (SQL's escape)
    and the result is wrapped in single quotes.
    """
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        return value
    return "'" + value.replace("'", "''") + "'"


def _postprocess_sql(raw_sql: str, table_name: str) -> str:
    """Make the model's generated SQL runnable by SQLite through pandasql.

    The function does three things:

    - renames the placeholder table to *table_name*;
    - quotes each literal that follows "=" (in place, so the same text
      elsewhere in the query is not touched);
    - parenthesises bare aggregate calls.
    """
    sql = _TABLE_PLACEHOLDER.sub(lambda m: m.group(1) + table_name, raw_sql)
    sql = _CONDITION_VALUE.sub(lambda m: "= " + _quote_value(m.group(1)), sql)
    # An unmatched optional group (\2) is substituted as "" by re.sub.
    return _BARE_AGGREGATE.sub(r"\1(\2\3)", sql)


### Main

nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])
if nav == "TAPAS":

    # Side columns are layout padding only; the title sits in the middle one.
    _, title_col, _ = st.columns(3)
    title_col.title("TAPAS")

    _, subtitle_col = st.columns([3, 12])
    subtitle_col.text("Tabular Data Text Extraction using text")

    # Every cell is converted to str before being handed to the TAPAS
    # table-question-answering pipeline.
    table = pd.read_csv("data.csv")
    table = table.astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)

    st.title("")  # vertical spacer

    sample_questions = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]
    chosen = st.selectbox("Choose your text", sample_questions, index=0)

    st.title("")

    question = st.text_area("TAPAS Input", chosen)

    if st.button("Predict"):
        tqa = tapas_model()
        answer = tqa(table=table, query=question)["answer"]
        st.text("Output - ")
        st.success(f"{answer}")

elif nav == "Text2SQL":

    _, title_col, _ = st.columns(3)
    title_col.title("Text2SQL")

    _, subtitle_col = st.columns([1, 20])
    subtitle_col.text("Text will be converted to SQL Query and can extract the data from DataSet")

    # NOTE(review): 'unicode_escape' is an unusual codec for a CSV file;
    # confirm the file's real encoding before changing it.
    df_qna = pd.read_csv("qnacsv.csv", encoding='unicode_escape')

    st.title("")

    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)

    st.title("")

    sample_questions = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    chosen = st.selectbox("Choose your text", sample_questions, index=0)

    st.title("")

    question = st.text_area("Text for SQL Conversion", chosen)

    if st.button("Predict"):

        # inference()/prepare_input() read the module-level names `tokenizer`
        # and `model`, so they must be bound here at module scope.
        tokenizer, model = tokmod("juierror/flan-t5-text2sql-with-schema")

        table_name = "df_qna"
        table_col = ["Type", "Class_Code", "Version", "Measure_Indicator_Code", "Measure_Indicator_Name", "Description_Definition", "Source", "Interfaces"]

        raw_sql = inference(question=question, table=table_col)
        txt_sql = _postprocess_sql(raw_sql, table_name)

        st.success(f"{txt_sql}")
        try:
            # Explicit env: the query may only reference the dataset table.
            result = sqldf(txt_sql, env={table_name: df_qna})
        except Exception as exc:  # UI boundary: generated SQL may be invalid
            st.error(f"Could not run the generated SQL: {exc}")
        else:
            st.text("Output - ")
            st.write(result)