import gradio as gr
from transformers import pipeline
import pandas as pd
import json
import nltk
from sentence_transformers import SentenceTransformer, util
import numpy as np
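# Local helper modules: LexRank is expected to provide degree_centrality_scores(),
# and text the app copy and example inputs (sum_app_text_tab_*, *_example_text_*) used below.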
from LexRank import *
from text import *

# nltk.sent_tokenize below needs NLTK's 'punkt' sentence tokenizer models
nltk.download('punkt')


def lex_rank(in_text, threshold=None, ex_sent=4, model_in='KBLab/sentence-bert-swedish-cased', language='swedish'):
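    """Extractive summarization via LexRank: embed each sentence, build a cosine-similarity
    graph, score sentences by degree centrality and return the `ex_sent` most central ones."""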
    # The UI dropdown may deliver its choices as strings; normalise to None or a float.
    threshold = None if threshold in (None, 'None') else float(threshold)

    model = SentenceTransformer(model_in)
    #Split the document into sentences
    sentences = nltk.sent_tokenize(in_text, language=language)
    
    #Compute the sentence embeddings
    embeddings = model.encode(sentences, convert_to_tensor=True)
    cos_scores = util.cos_sim(embeddings, embeddings).cpu().numpy()

    #Compute the centrality for each sentence
    centrality_scores = degree_centrality_scores(cos_scores, threshold=threshold)

    most_central_sentence_indices = np.argsort(-centrality_scores)
    sent_list= []
    for idx in most_central_sentence_indices[0:ex_sent]:
        sent_list.append(sentences[idx])
    return ' '.join(sent_list)


def generate(in_text, num_beams, min_len, max_len, model_in):
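    """Abstractive summarization via the Hugging Face `summarization` pipeline.

    The pipeline (and model weights) is loaded on every call, so the first call with a
    new model is slow while the weights are downloaded.
    """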
    print(in_text)
    pipe = pipeline("summarization", model=model_in)
    answer = pipe(in_text, num_beams=num_beams, min_length=min_len, max_length=max_len)
    print(answer)
    return answer[0]["summary_text"]

  
def update_history(df, in_text, gen_text, sum_typ, model_in, parameters):
    # Append the latest generation (input, output, method, model, parameters) to the
    # history dataframe shown at the bottom of the UI.
    new_row = [{"In_text": in_text,
                "Gen_text": gen_text,
                "Sum_type": sum_typ,
                "Gen_model": model_in,
                "Parameters": json.dumps(parameters)}]
    return pd.concat([df, pd.DataFrame(new_row)])
    
def generate_transformer(in_text, num_beams, min_len, max_len, model_in, history):
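    """Wrapper for the abstractive tab: generate a summary and append it to the history table."""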
    gen_text= generate(in_text,num_beams, min_len, max_len, model_in)
    return gen_text, update_history(history, in_text, gen_text, "Abstractive" ,model_in, {"num_beams": num_beams, 
                                                                                          "min_len": min_len,
                                                                                          "max_len": max_len})

def generate_lexrank(in_text, threshold, model_in, ex_sent ,language, history):
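    """Wrapper for the extractive tab: run LexRank and append the result to the history table."""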
    gen_text= lex_rank(in_text, threshold, ex_sent ,model_in, language)
    return gen_text, update_history(history, in_text, gen_text, "Extractive" ,model_in, {"threshold": threshold, 
                                                                                         "Nr_sent": ex_sent,
                                                                                         "language": language})

with gr.Blocks(css=".gr-prose img {margin-bottom: 0em !important;}") as demo:
    gr.Markdown("<h1><center> Swedish Summarization Engine! </center></h1>")
    with gr.Accordion("Read here for details about the app", open=False):
        with gr.Row():
            with gr.Column():
                gr.Markdown(sum_app_text_tab_1)
            with gr.Column():
                gr.Markdown(sum_app_text_tab_2)

    with gr.Tabs():
        with gr.TabItem("Abstractive Generation for Summarization"):
            gr.Markdown(
                      """The default parameters for this transformer based model work well to generate summarization.
                         Use this tab to experiment summarization task of text for different types Abstractive models.""")
            with gr.Row():
                with gr.Column(scale=4):
                    text_baseline_transformer = gr.TextArea(label="Input text to summarize", placeholder="Text to summarize")
                    
                    with gr.Row():
                        transformer_button_clear = gr.Button("Clear", variant='secondary')
                        transformer_button = gr.Button("Summarize!", variant='primary')
                    
                with gr.Column(scale=3):
                    with gr.Row():
                        num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1, label="Number of Beams")
                        min_len = gr.Slider(minimum=10, maximum=50, value=25, step=5, label="Min length")
                        max_len = gr.Slider(minimum=50, maximum=130, value=120, step=10, label="Max length")
                    model_in = gr.Dropdown(["Gabriel/bart-base-cnn-swe", "Gabriel/bart-base-cnn-xsum-swe", "Gabriel/bart-base-cnn-xsum-wiki-swe"], value="Gabriel/bart-base-cnn-xsum-swe", label="Model")
                    output_baseline_transformer = gr.Textbox(label="Output Text")

            with gr.Row():
                with gr.Accordion("Here are some examples you can use:", open=False): 
                    gr.Markdown("<h3>Press one of the test examples below.<h3>")
                    gr.Markdown("NOTE: First time inference for a new model will take time, since a new model has to downloaded before inference.")
                    gr.Examples(
                        [[abstractive_example_text_1, 5, 25, 120, "Gabriel/bart-base-cnn-swe"],
                         [abstractive_example_text_2, 5, 25, 120, "Gabriel/bart-base-cnn-xsum-swe"]],
                        [text_baseline_transformer, num_beams, min_len, max_len, model_in])

        with gr.TabItem("Extractive Ranking Graph for Summarization"):
            gr.Markdown(
                      """Use this tab to experiment summarization task of text with a graph based method (LexRank).""")
            with gr.Row():
                with gr.Column(scale=4):
                    text_extract= gr.TextArea(label="Input text to summarize", placeholder="Input text")
                    with gr.Row():
                        extract_button_clear = gr.Button("Clear", variant='secondary')
                        extract_button = gr.Button("Summarize!", variant='primary')
                with gr.Column(scale=3):
                    with gr.Row():
                        ex_sent = gr.Slider(minimum=1, maximum=7, value=4, step=1, label="Sentences to return")
                        ex_threshold = gr.Dropdown(['None', 0.1, 0.2, 0.3, 0.4, 0.5], value='None', label="Similarity Threshold")
                        ex_language = gr.Dropdown(["swedish","english"], value="swedish", label="Language")
                    model_in_ex = gr.Dropdown(["KBLab/sentence-bert-swedish-cased","sentence-transformers/all-MiniLM-L6-v2"], value="KBLab/sentence-bert-swedish-cased", label="Model")
                    output_extract = gr.Textbox(label="Output Text")

            with gr.Row():
              with gr.Accordion("Here are some examples you can use:", open=False): 
                  gr.Markdown("<h3>Press one of the test examples below.<h3>")
                  gr.Markdown("NOTE: First time inference for a new model will take time, since a new model has to downloaded before inference.")
                  gr.Examples(
                      [[extractive_example_text_1, 'None', 4, 'swedish', "KBLab/sentence-bert-swedish-cased"]],
                      [text_extract, ex_threshold, ex_sent, ex_language, model_in_ex])

    with gr.Box():
        gr.Markdown("<h3> Generation History <h3>")
        # Displays a dataframe with the history of moves generated, with parameters
        history = gr.Dataframe(headers=["In_text", "Gen_text","Sum_type" ,"Gen_model", "Parameters"], overflow_row_behaviour="show_ends", wrap=True)

    transformer_button.click(generate_transformer, inputs=[text_baseline_transformer, num_beams, min_len, max_len, model_in, history], outputs=[output_baseline_transformer, history], api_name="summarize")
    extract_button.click(generate_lexrank, inputs=[text_extract, ex_threshold, model_in_ex, ex_sent ,ex_language ,history], outputs=[output_extract , history] )
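
    # The Clear buttons are created above but not wired up in this snippet; a minimal sketch
    # of the presumably intended behaviour is to reset the corresponding input/output boxes.
    transformer_button_clear.click(lambda: ("", ""), inputs=None, outputs=[text_baseline_transformer, output_baseline_transformer])
    extract_button_clear.click(lambda: ("", ""), inputs=None, outputs=[text_extract, output_extract])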


demo.launch()