import gradio as gr
import regex as re
import torch
import nltk
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize
import plotly.express as px
import time
nltk.download('punkt_tab')  # sentence tokenizer data required by nltk.sent_tokenize

# Define the model and tokenizer
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
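# The tokenizer and model are loaded once at startup and reused by both prediction functions below.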

# Define the function for preprocessing text
def prep_text(text):
    clean_sents = []
    sent_tokens = sent_tokenize(str(text))
    for sent_token in sent_tokens:
        word_tokens = [word_token.strip().lower() for word_token in sent_token.split()]
        clean_sents.append(' '.join(word_tokens))
    joined = ' '.join(clean_sents).strip(' ')
    joined = re.sub(r'`', "", joined)
    joined = re.sub(r'"', "", joined)
    return joined
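
# Illustrative example: prep_text('He said "Hello World!"') -> 'he said hello world!' (lowercased; quotes and backticks removed)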

# APP INFO
def app_info():
    check = """
    Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text. 
    """
    
    return check

# Create Gradio interface for single text
iface1 = gr.Interface(
    fn=app_info, inputs=None, outputs=['text'], title="General-Information",
    description= '''
    This app, powered by the sdgBERT model (sadickam/sdg-classification-bert), automatically classifies text against 
    the UN Sustainable Development Goals (SDGs). Note that 16 of the 17 SDG labels are covered (SDG 17: Partnerships for the Goals 
    is not included). The app is intended for sustainability assessment and benchmarking and is not limited to a specific industry. 
    The model powering this app was developed using the OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
    
    This app has two analysis modules, summarised below:
    - Single-Text-Prediction - Analyses text pasted into a text box and returns an SDG prediction.
    - Multi-Text-Prediction - Analyses multiple rows of text in an uploaded CSV file and returns a downloadable CSV file with an SDG prediction for each row.
    
    This app runs on a free server and may therefore not be suitable for analysing large CSV or PDF files. 
    If you need assistance with analysing large files, do get in touch using the contact information in the Contact section.
    
    <h3>Contact</h3>
    <p>We would be happy to receive your feedback on this app. If you would like to collaborate with us to explore 
    use cases for the model powering this app, we would also be glad to hear from you.</p>


    ''')

# SINGLE TEXT
# Define the prediction function
def predict_sdg(text):
    # Preprocess the input text
    cleaned_text = prep_text(text)
    if cleaned_text == "":
        raise gr.Error('This model needs some text input to return a prediction')
    else:
        # Tokenize the preprocessed text
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        # Predict
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
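        # 'predictions' is a list of 16 probabilities, one per SDG label (ordered as in label_list below)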
        # SDG labels
        label_list = [
            'GOAL 1: No Poverty',
            'GOAL 2: Zero Hunger',
            'GOAL 3: Good Health and Well-being',
            'GOAL 4: Quality Education',
            'GOAL 5: Gender Equality',
            'GOAL 6: Clean Water and Sanitation',
            'GOAL 7: Affordable and Clean Energy',
            'GOAL 8: Decent Work and Economic Growth',
            'GOAL 9: Industry, Innovation and Infrastructure',
            'GOAL 10: Reduced Inequality',
            'GOAL 11: Sustainable Cities and Communities',
            'GOAL 12: Responsible Consumption and Production',
            'GOAL 13: Climate Action',
            'GOAL 14: Life Below Water',
            'GOAL 15: Life on Land',
            'GOAL 16: Peace, Justice and Strong Institutions'
        ]
        # dictionary mapping each SDG label to its predicted probability
        pred_dict = dict(zip(label_list, predictions))
    
        # sort 'pred_dict' by value and index the highest at [0]
        sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)
    
        # Make dataframe for plotly bar chart
        df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])
    
        # plot graph of predictions
        fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    
        fig.update_layout(
            # barmode='stack', 
            template='seaborn', font=dict(family="Arial", size=12, color="black"),
            autosize=True,
            #width=800,
            #height=500,
            xaxis_title="Likelihood of SDG",
            yaxis_title="Sustainable development goals (SDG)",
            # legend_title="Topics"
        )
    
        fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
        fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
        fig.update_annotations(font_size=12)  # this changes y_axis, x_axis and subplot title font sizes
    
        # Return the top prediction
        top_prediction = sorted_preds[0]

    # Return result
    return {top_prediction[0]: round(top_prediction[1], 3)}, fig

# Create Gradio interface for single text
iface2 = gr.Interface(fn=predict_sdg,
                      inputs=gr.Textbox(lines=7, label="Paste or type text here"), 
                      outputs=[gr.Label(label="Top SDG Predicted", show_label=True), gr.Plot(label="Likelihood of all SDG", show_label=True)], 
                      title="Single Text Prediction",
                      article="**Note:** The quality of model predictions may depend on the quality of information provided."
                     )

# UPLOAD CSV
# Define the prediction function
def predict_sdg_from_csv(file, progress=gr.Progress()):
    # Read the CSV file
    df_docs = pd.read_csv(file)
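    # The uploaded CSV must contain a column named 'text_inputs'; any other columns are carried through to the output file.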
    text_list = df_docs["text_inputs"].tolist()

    # SDG labels list
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]

    # Lists for appending predictions
    predicted_labels = []
    prediction_score = []

    # Preprocess text and make predictions
    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        time.sleep(0.02)  # brief pause so the progress bar has time to update in the UI
        cleaned_text = prep_text(text_input)
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
        pred_dict = dict(zip(label_list, predictions))
        sorted_preds = sorted(pred_dict.items(), key=lambda g: g[1], reverse=True)
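        # Keep only the top-ranked SDG label and its probability for this row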
        predicted_labels.append(sorted_preds[0][0])
        prediction_score.append(sorted_preds[0][1])

    # Append predictions to the DataFrame
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score

    # Write the predictions to a CSV file that is offered for download
    df_docs.to_csv('sdg_predictions.csv', index=False)
    output_csv = 'sdg_predictions.csv'

    # Create the histogram
    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        #width=800,
        #height=500,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    return fig, output_csv

# Define the input component
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Create the Gradio interface
iface3 = gr.Interface(fn=predict_sdg_from_csv, 
                      inputs= file_input, 
                      outputs=[gr.Plot(label='Frequency of SDGs', show_label=True), gr.File(label='Download output CSV', show_label=True)], 
                      title="Multi-text Prediction (CVS)",
                      description='**NOTE:** The column to be analysed must be titled ***text_inputs***')

demo = gr.TabbedInterface(interface_list = [iface1, iface2, iface3], 
                          tab_names = ["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
                          title = "Sustainble Development Goals (SDG) Text Classifier App",
                          theme = 'soft'
                         )

# Run the interface
demo.queue().launch()