import gradio as gr
import urllib
import base64
import random
import requests
import bs4
import lxml
import os
from huggingface_hub import InferenceClient,HfApi
import random
import json
import datetime
from pypdf import PdfReader
import uuid
from PIL import Image
from screenshot import create_ss

from agent import (
    PREFIX,
    GET_CHART,
    COMPRESS_DATA_PROMPT,
    COMPRESS_DATA_PROMPT_SMALL,
    LOG_PROMPT,
    LOG_RESPONSE,
)
# Hugging Face Hub API handle.
# NOTE(review): not referenced elsewhere in the visible part of this file — confirm it is still needed.
api=HfApi()
# Shared inference client; run_gpt and run_gpt_no_prefix stream completions from it.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

############  Document Functions #################
def find_all(url):
    """Download a web page and return its text content.

    Args:
        url: Address of the page to fetch. ``""`` and ``None`` are rejected
            without making a request.

    Returns:
        A ``(success, payload)`` tuple. On success ``payload`` is the page
        text prefixed with ``"RAW TEXT RETURNED: "``; on failure it is a
        human-readable error message.
    """
    print (url)
    print (f"trying URL:: {url}")
    if url is None or url == "":
        return False, "Enter Valid URL"
    try:
        source = requests.get(url)
        soup = bs4.BeautifulSoup(source.content, 'lxml')
        rawp = f'RAW TEXT RETURNED: {soup.text}'
        print(rawp)
        return True, rawp
    except Exception as e:
        # Boundary handler: any download or parse failure is reported to the
        # caller as a message instead of propagating.
        print (e)
        return False, f'Error: {e}'

def read_txt(txt_path):
    """Return the full contents of a text file, echoing them to stdout.

    Args:
        txt_path: Path of the file to read (opened in text mode with the
            platform default encoding).

    Returns:
        The file contents as a single string.
    """
    with open(txt_path, "r") as handle:
        contents = handle.read()
    print (contents)
    return contents

def read_pdf(pdf_path):
    """Extract the text of every page of a local PDF, echoing it to stdout.

    Args:
        pdf_path: Location of the PDF; converted to ``str`` before being
            handed to ``PdfReader``.

    Returns:
        The page texts in page order, each one preceded by a newline.
    """
    reader = PdfReader(str(pdf_path))
    text = "".join(f'\n{page.extract_text()}' for page in reader.pages)
    print (text)
    return text

error_box=[]
def read_pdf_online(url):
    """Download a PDF and return the text of all of its pages.

    The downloaded bytes are parsed from memory rather than from a shared
    ``test.pdf`` on disk, so simultaneous requests cannot overwrite each
    other's document.

    Args:
        url: Address of the PDF to download.

    Returns:
        * ``str`` - the page texts in page order, each preceded by a newline,
          when the server answers 200;
        * ``int`` - the HTTP status code for any other answer (the URL is
          also appended to the module-level ``error_box``);
        * the caught exception object when the download or the parsing fails.
    """
    import io  # function-scope import: only this function needs it

    print(f"reading {url}")
    try:
        response = requests.get(url)
        print(response.status_code)
        if response.status_code != 200:
            error_box.append(url)
            print(response.status_code)
            return response.status_code

        reader = PdfReader(io.BytesIO(response.content))
        print(len(reader.pages))
        text = "".join(f'\n{page.extract_text()}' for page in reader.pages)
        # Log the extracted text once instead of re-printing the growing
        # string after every page.
        print(f"PDF_TEXT:: {text}")
        return text
    except Exception as e:
        # Best-effort contract kept from the original: failures are handed
        # back to the caller rather than raised.
        print (e)
        return e


# When true, run_gpt prints the model response after generation.
VERBOSE = True
# NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
MAX_HISTORY = 100
# Size limit compress_data uses when splitting input text into slices for the model.
MAX_DATA = 20000

def format_prompt(message, history):
    """Render a chat history plus a new message as an ``[INST]`` prompt.

    Args:
        message: The new user message to place at the end of the prompt.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest
            first.

    Returns:
        ``"<s>"`` followed by one ``[INST] .. [/INST] ..</s> `` segment per
        history pair and a final ``[INST] message [/INST]``.
    """
    past_turns = [
        f"[INST] {user_prompt} [/INST] {bot_response}</s> "
        for user_prompt, bot_response in history
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"

def run_gpt_no_prefix(
    prompt_template,
    stop_tokens,
    max_tokens,
    seed,
    **prompt_kwargs,
):
    """Fill a prompt template and stream a completion for it, without PREFIX.

    Args:
        prompt_template: ``str.format`` template for the prompt.
        stop_tokens: Accepted but not used by this function.
        max_tokens: Upper bound passed as ``max_new_tokens``.
        seed: Sampling seed passed to the model.
        **prompt_kwargs: Values substituted into ``prompt_template``.

    Returns:
        The concatenated generated text, or the literal string ``"Error"``
        when anything in the process raises.
    """
    print(seed)
    try:
        sampling = {
            "temperature": 0.9,
            "max_new_tokens": max_tokens,
            "top_p": 0.95,
            "repetition_penalty": 1.0,
            "do_sample": True,
            "seed": seed,
        }
        content = prompt_template.format(**prompt_kwargs)
        print(LOG_PROMPT.format(content))

        token_stream = client.text_generation(
            content,
            **sampling,
            stream=True,
            details=True,
            return_full_text=False,
        )
        # Each streamed item carries one generated token.
        resp = "".join(piece.token.text for piece in token_stream)

        print(LOG_RESPONSE.format(resp))
        return resp
    except Exception as e:
        print(f'no_prefix_error:: {e}')
        return "Error"
def run_gpt(
    prompt_template,
    stop_tokens,
    max_tokens,
    seed,
    **prompt_kwargs,
):
    """Fill a prompt template, prepend PREFIX, and stream a completion.

    Args:
        prompt_template: ``str.format`` template appended after PREFIX.
        stop_tokens: Accepted but not used by this function.
        max_tokens: Upper bound passed as ``max_new_tokens``.
        seed: Sampling seed passed to the model.
        **prompt_kwargs: Values substituted into ``prompt_template``.

    Returns:
        The concatenated generated text. Exceptions are not caught here.
    """
    print(seed)
    timestamp = datetime.datetime.now()

    sampling = {
        "temperature": 0.9,
        "max_new_tokens": max_tokens,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "do_sample": True,
        "seed": seed,
    }

    header = PREFIX.format(
        timestamp=timestamp,
        purpose="Compile the provided data and complete the users task",
    )
    content = header + prompt_template.format(**prompt_kwargs)
    print(LOG_PROMPT.format(content))

    token_stream = client.text_generation(
        content,
        **sampling,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # Each streamed item carries one generated token.
    resp = "".join(piece.token.text for piece in token_stream)

    if VERBOSE:
        print(LOG_RESPONSE.format(resp))
    return resp

    
def compress_data(c, instruct, history, seed):
    """Compress ``history`` piecewise with the model and collect the replies.

    ``history`` is cut into consecutive windows of ``MAX_DATA`` characters and
    each window is sent to ``run_gpt`` with ``COMPRESS_DATA_PROMPT_SMALL``.
    The number of windows is derived from ``len(history)``: the previous
    version derived it from ``c`` (a word count in ``summarize``) while still
    slicing by character, which silently dropped the tail of long inputs and
    divided by zero when ``c`` was 0.

    Args:
        c: Caller-supplied size estimate. It is only logged; the parameter is
            kept so existing callers keep working.
        instruct: Instruction passed to the prompt as ``direction``.
        history: The text to compress.
        seed: Sampling seed forwarded to ``run_gpt``.

    Returns:
        A list with one model response per window, in input order. At least
        one request is made even when ``history`` is empty.
    """
    print (c)
    chunk = MAX_DATA
    total = len(history)
    # Ceiling division; always issue at least one request, as before.
    divi = max(1, (total + chunk - 1) // chunk)
    print(f'chunk:: {chunk}')
    print (f'divi:: {divi}')

    out = []
    for z in range(divi):
        s = z * chunk
        e = s + chunk
        print(f's:e :: {s}:{e}')

        resp = run_gpt(
            COMPRESS_DATA_PROMPT_SMALL,
            stop_tokens=["observation:", "task:", "action:", "thought:"],
            max_tokens=8192,
            seed=seed,
            direction=instruct,
            knowledge="",
            history=history[s:e],
        ).strip("\n")
        out.append(resp)
        print (resp)
    return out

  
def get_chart(inp,seed):
    """Ask the model for chart code describing ``inp``.

    Args:
        inp: Text substituted into the ``GET_CHART`` prompt as ``inp``.
        seed: Sampling seed forwarded to ``run_gpt_no_prefix``.

    Returns:
        The model response with surrounding newlines removed. If the request
        raises, the error message is returned as a string so that callers can
        keep treating the result as text.
    """
    try:
        resp = run_gpt_no_prefix(
            GET_CHART,
            stop_tokens=[],
            max_tokens=8192,
            seed=seed,
            inp=inp,
        ).strip("\n")
        print(resp)
    except Exception as e:
        print(f'Error:: {e}')
        # Return text, not the exception object: the caller splits the
        # result on newlines.
        resp = str(e)
    return resp

def format_json(inp):
    """Join the pieces of ``inp``, trim markup characters, and evaluate them.

    Args:
        inp: Iterable of strings (a list of fragments, or a single string,
            which is then processed character by character).

    Returns:
        Whatever Python object the cleaned-up text evaluates to.

    Raises:
        Any exception raised while evaluating the text (for example
        ``SyntaxError`` when it is not a valid expression).
    """
    banner = "###########"
    print("FORMATTING:::")
    print(type(inp))
    print(banner)
    print(inp)
    print(banner)
    print(banner)

    pieces = []
    for fragment in inp:
        trimmed = fragment.strip()
        print(trimmed)
        # str.strip takes a character set: these calls drop leading/trailing
        # backticks, '#' and '/' characters from each fragment.
        pieces.append(trimmed.strip("\n").strip("```").strip("#").strip("//"))
    print(banner)
    print(banner)

    # Same character-set semantics: trims any of '<', '/', 's', '>' at the ends.
    candidate = "".join(pieces).strip("</s>")
    # NOTE(security): eval() executes arbitrary expressions. In this file the
    # text comes from model output built on user-supplied documents, so it
    # should be treated as untrusted; consider json.loads or
    # ast.literal_eval instead.
    out_json = eval(candidate)
    print(out_json)
    print(banner)
    print(banner)

    return out_json


# One-element list holding a zoom factor as text; zoom_update replaces its
# content in place.
this=["1.25"]
# Stylesheet template for the chart wrapper and iframe; $ZOOM is a placeholder.
# NOTE(review): not referenced elsewhere in the visible part of this file
# (gr.Blocks() is created without it) — confirm it is still needed.
css="""
#wrap { width: 100%; height: 100%; padding: 0; overflow: auto; }
#frame { width: 100%; border: 1px solid black; }
#frame { zoom: $ZOOM; -moz-transform: scale($ZOOM); -moz-transform-origin: 0 0; }
"""


def mm(graph,zoom):
    """Build the viewer URL and the iframe HTML for a block of chart code.

    Args:
        graph: Chart source text. Each line is stripped of surrounding
            whitespace and the lines are concatenated without separators.
        zoom: Zoom factor inserted into the iframe's inline style.

    Returns:
        ``(out_html, url)``: the HTML snippet embedding the viewer in an
        iframe, and the viewer URL with the chart code URL-encoded into its
        ``mermaid`` query parameter.
    """
    # ``import urllib`` alone does not guarantee the ``parse`` submodule is
    # loaded; import it explicitly instead of relying on another package
    # having done so.
    import urllib.parse

    code_out = "".join(line.strip() for line in graph.split("\n"))
    url = f"https://omnibus-mermaid-script.static.hf.space/index.html?mermaid={urllib.parse.quote_plus(code_out)}"
    zoom_css = str(zoom)
    out_html = (
        '<div id="wrap" style="width: 100%; height: 100%;max-height:600px; padding: 0; overflow: auto;">'
        f'<iframe id="frame" src="{url}" style="width:100%; height:600px; border: 1px solid black; '
        f'zoom: {zoom_css}; -moz-transform: scale({zoom_css}); -moz-transform-origin: 0 0;" '
        'allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; '
        'encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; '
        'magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; '
        'publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" '
        'sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads">'
        '</iframe></div>'
    )
    return out_html,url
    
def _read_pdf_batch(pdf_batch):
    """Download each PDF in a comma-separated URL list and merge their text."""
    merged = ""
    for batch_url in pdf_batch.split(","):
        batch_url = batch_url.strip()
        if not batch_url:
            continue
        try:
            pdf_text = read_pdf_online(batch_url)
        except Exception as e:
            # One unreachable URL must not discard the rest of the batch.
            print(e)
            merged = f'{merged}\nError reading URL ({batch_url})'
            continue
        merged = f'{merged}\nFile Name URL ({batch_url}):\n{pdf_text}'
    return merged


def _read_uploaded_files(files, data):
    """Append the text of every uploaded .pdf/.txt file to ``data``."""
    for file in files:
        try:
            print (file)
            if file.endswith(".pdf"):
                contents = read_pdf(file)
            elif file.endswith(".txt"):
                contents = read_txt(file)
            else:
                # Other file types are ignored.
                continue
            data = f'{data}\nFile Name ({file}):\n{contents}'
        except Exception as e:
            data = f'{data}\nError opening File Name ({file})'
            print (e)
    return data


def _extract_chart_code(chart_text):
    """Return the sanitised lines found inside the first ``` fence.

    Lines after the opening fence are collected until the closing fence or,
    if the fence is never closed, until the end of the text. Every line is
    trimmed, has quotes and ``#`` removed and ``/``, ``.``, ``:`` replaced by
    spaces, and is terminated with ``;``.
    """
    lines = chart_text.split("\n")
    fence = next((idx for idx, line in enumerate(lines) if "```" in line), None)
    if fence is None:
        return ""

    code = ""
    for raw in lines[fence + 1:]:
        if "```" in raw:
            break
        code += raw.strip().replace("\n"," ").replace('"',"").replace("/"," ").replace("."," ").replace(":"," ").replace("#","")
        # Checked on the accumulated text, so a blank line that follows a
        # terminated line does not get a stray ';' of its own.
        if not code.strip().endswith(";"):
            code += ";"
        code += "\n"
    return code


def summarize(inp,history,seed,data=None,files=None,directory=None,url=None,pdf_url=None,pdf_batch=None):
    """Load the selected data source, compress it, and build a chart from it.

    This is a generator: every yield is a 6-tuple matching the outputs wired
    to the submit button in this file (prompt text, chatbot history, chart
    HTML, chart code, JSON data, chart URL), so the UI is updated as the
    work progresses.

    Source precedence follows the order of the checks below: a PDF batch
    replaces pasted text, a single PDF URL replaces that, a web page URL
    replaces that, and uploaded files are appended to whatever was loaded.

    Args:
        inp: Instruction for the model; ``"Process this data"`` when empty.
        history: Incoming chat history. It is not read or modified; a fresh
            one-entry history is produced for every yield.
        seed: Sampling seed forwarded to the model calls.
        data: Pasted text, or ``None``.
        files: Paths of uploaded ``.pdf`` / ``.txt`` files, or ``None``.
        directory: Uploaded folder entries; they are only printed.
        url: Web page address (used when it starts with ``"http"``).
        pdf_url: Address of a single PDF (same rule).
        pdf_batch: Comma-separated PDF addresses (same rule).

    Yields:
        Progress tuples as described above. When no usable data was loaded,
        the last tuple carries an explanatory chat message and empty chart
        fields instead of raising.
    """
    json_box = []
    chart_out = ""
    if not inp:
        inp = "Process this data"
    history = [(inp, "Working on it...")]
    yield "", history, chart_out, chart_out, json_box, ""

    # The keyword defaults are None; normalise so the prefix checks below
    # cannot fail on a missing value.
    data = "" if data is None else data
    url = url or ""
    pdf_url = pdf_url or ""
    pdf_batch = pdf_batch or ""
    load_error = ""

    if pdf_batch.startswith("http"):
        data = _read_pdf_batch(pdf_batch)

    if directory:
        for ea in directory:
            print(ea)

    if pdf_url.startswith("http"):
        print("PDF_URL")
        try:
            data = read_pdf_online(pdf_url)
        except Exception as e:
            print(e)
            data = "Error"
            load_error = f'Error: {e}'

    if url.startswith("http"):
        ok, fetched = find_all(url)
        if ok:
            data = fetched
        else:
            data = "Error"
            load_error = str(fetched)

    if files:
        data = _read_uploaded_files(files, data)

    if data == "Error" or data == "":
        # Nothing to chart: report it in the chat instead of falling through
        # to a final yield whose chart variables were never assigned.
        history = [(inp, load_error or "Provide a valid data source")]
        yield "", history, "", "", json_box, ""
        return

    history = [(inp, "Data: Loaded, processing...")]
    yield "", history, chart_out, chart_out, json_box, ""

    print(inp)
    text = str(data)
    print(f'rl:: {len(text)}')
    # Rough word count: one more than the number of separators.
    c = 1 + sum(text.count(sep) for sep in (" ", ",", "\n"))
    print (f'c:: {c}')
    json_out = compress_data(c, inp, text, seed)
    try:
        json_out = format_json(json_out)
    except Exception as e:
        # Keep the raw model output when it cannot be parsed.
        print (e)
    history = [(inp, "Building Chart...")]
    yield "", history, chart_out, chart_out, json_out, ""

    # str() guards the parsing below against a non-text result.
    chart_out = str(get_chart(str(json_out), seed))
    line_out = _extract_chart_code(chart_out)
    chart_html, chart_url = mm(line_out, 1)

    history = [(inp, chart_out)]
    yield "", history, chart_html, line_out, json_out, chart_url

#################################
def clear_fn():
    """Return the reset values for the prompt box and the chatbot.

    Returns:
        A tuple of an empty prompt string and a history holding a single
        ``(None, None)`` pair.
    """
    empty_prompt = ""
    blank_history = [(None, None)]
    return empty_prompt, blank_history

def create_image(url):
    """Write the given text to ``tmp.svg`` and try to convert it to a PNG.

    NOTE(review): this function is only referenced from a commented-out
    event binding in the visible part of this file.

    Args:
        url: Text written verbatim as the content of ``tmp.svg``.
            NOTE(review): despite the name this is used as file content, not
            as an address — presumably SVG markup; confirm against callers.

    Returns:
        The fixed path ``"tmp.png"``.
    """
    print(url)
    # Persist the incoming text so it can be re-read as bytes and opened as an image.
    with open("tmp.svg","w") as svg:
        svg.write(url)
        
    with open("tmp.svg", "rb") as f:
        encoded_image = base64.b64encode(f.read())        
    # NOTE(review): verify that Pillow can open SVG input here; if it cannot,
    # this call raises and nothing below runs.
    # This local ``this`` shadows the module-level list of the same name.
    this = Image.open("tmp.svg")
    # The return value of save() is stored but never used.
    out_im = this.save("tmp.png")    
    # Writes the base64 text of the SVG (not decoded image data) to image.png.
    with open("image.png","wb") as file:
        #file.write(eval(encoded_image))
        file.write(encoded_image)
    #output = cairosvg.svg2png(
    #    bytestring=open('tmp.svg').read().encode('utf-8'), write_to="output.png")        
    return "tmp.png"


# JavaScript snippet that logs its input, reads the innerHTML of the element
# with id 'chart' inside the iframe with id "frame", and returns it in a list.
# NOTE(review): only referenced from commented-out event bindings in the
# visible part of this file.
score_js="""
function(text_input) {
    console.log(text_input);
    const iframe = document.getElementById("frame").contentWindow.document.getElementById('chart').innerHTML;
    console.log(iframe);
    return [iframe];
}
"""
def zoom_update(inp):
    """Store the given zoom value, as text, in the module-level ``this`` list.

    Args:
        inp: New zoom value; converted with ``str``.

    Returns:
        ``gr.update()`` with no arguments.
    """
    # Slice assignment keeps the same list object while replacing its content.
    this[:] = [str(inp)]
    return gr.update()

# Gradio UI: components are created in layout order, then the event handlers
# are wired, and finally the app is queued and launched.
with gr.Blocks() as app:
    gr.HTML("""<center><h1>Text -to- Chart</h1><h3>Mixtral 8x7B</h3>""")
    chatbot = gr.Chatbot(label="Mixtral 8x7B Chatbot",show_copy_button=True)
    # Instruction box and submit button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt=gr.Textbox(label = "Instructions (optional)")
        with gr.Column(scale=1):
            
            button=gr.Button()
        
        #models_dd=gr.Dropdown(choices=[m for m in return_list],interactive=True)
    with gr.Row():
        stop_button=gr.Button("Stop")
        clear_btn = gr.Button("Clear")
    # One tab per data source; all of them are passed to summarize together.
    with gr.Row():
        with gr.Tab("Text"):
            data=gr.Textbox(label="Input Data (paste text)", lines=6)
        with gr.Tab("File"):
            file=gr.Files(label="Input File(s) (.pdf .txt)")
        with gr.Tab("Folder"):
            directory=gr.File(label="Folder", file_count='directory')            
        with gr.Tab("Raw HTML"):
            url = gr.Textbox(label="URL")
        with gr.Tab("PDF URL"):
            pdf_url = gr.Textbox(label="PDF URL")       
        with gr.Tab("PDF Batch"):
            pdf_batch = gr.Textbox(label="PDF URL Batch (comma separated)")
    # Rendered chart (iframe HTML) and its editable source code.
    m_box=gr.HTML()
    e_box=gr.Textbox(label="Graph Code",interactive=True)
    with gr.Row():
        upd_button=gr.Button("Update Chart")
        create_im=gr.Button("Create Image")
    with gr.Row():
        with gr.Column(scale=3):
            svg_img=gr.Image()
        with gr.Column(scale=1):
            wid=gr.Number(label="Width",value=1000)
            hgt=gr.Number(label="Height",value=4000)
            seed_slider=gr.Slider(label="Seed",step=1,minimum=1,maximum=9999999999999999999,value=1,interactive=True)
            zoom_btn=gr.Slider(label="Zoom",step=0.01,minimum=0.1,maximum=20,value=1,interactive=True)        
    url_box=gr.Textbox(label="Graph URL",interactive=True)

    json_out=gr.JSON()
    #text=gr.JSON()

    #get_score.click(return_score,score,[score],_js=score_js)

    score=gr.Textbox()
    # Echo helper; only referenced from the commented-out bindings nearby.
    def return_score(text):
        print(text)
        return text
    # "Create Image" passes the chart code and the requested size to create_ss
    # and shows the result in the image component.
    create_im.click(create_ss,[e_box,wid,hgt],svg_img)
    #create_im.click(return_score,score,[score],_js=score_js).then(create_image,score,svg_img)
    #zoom_btn.change(zoom_update,zoom_btn,None)
    # "Update Chart" re-renders the edited chart code at the selected zoom.
    upd_button.click(mm,[e_box,zoom_btn],[m_box,url_box])
    #inp_query.change(search_models,inp_query,models_dd)
    clear_btn.click(clear_fn,None,[prompt,chatbot])
    
    #go=button.click(summarize,[prompt,chatbot,report_check,chart_check,data,file,directory,url,pdf_url,pdf_batch],[prompt,chatbot,e_box,json_out])
    # The submit button runs the summarize generator; the event handle is kept
    # so that "Stop" can cancel it.
    go=button.click(summarize,[prompt,chatbot,seed_slider,data,file,directory,url,pdf_url,pdf_batch],[prompt,chatbot,m_box,e_box,json_out,url_box])
    
    stop_button.click(None,None,None,cancels=[go])
#app.queue(default_concurrency_limit=20).launch(show_api=False) 
# Enable the request queue with a default concurrency limit of 20 and start
# the server with the API page hidden.
app.queue(default_concurrency_limit=20).launch(show_api=False)