import gradio as gr
from bs4 import BeautifulSoup
import json
import time
import os
from transformers import AutoTokenizer, pipeline

# Hugging Face model IDs used for zero-shot NLI classification, keyed by a
# short alias; process_files runs the whole pipeline once per entry.
# Other checkpoints that were tried and can be re-enabled by adding them here:
#   "roberta-large-mnli", "facebook/bart-large-mnli",
#   "cross-encoder/nli-deberta-v3-xsmall"
models = dict(
    model_n1="sileod/deberta-v3-base-tasksource-nli",
)
def open_html(file):
    """Return the full text of an HTML file.

    Args:
        file: Either a path (``str`` or ``os.PathLike``) or an upload object
            exposing the file's path on its ``name`` attribute (as the Gradio
            upload passed to ``process_files`` does).

    Returns:
        The file contents, decoded as UTF-8.
    """
    # Check for path types first: ``pathlib.Path.name`` is only the basename,
    # so it must not be mistaken for an upload object's full-path ``name``.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def find_form_fields(html_content):
    """Return the ``<input>`` and ``<select>`` tags found inside ``<form>`` elements.

    Args:
        html_content: Raw HTML markup to parse.

    Returns:
        A list of the matching tags serialized back to HTML strings, in the
        order they appear within each form (inputs and selects interleaved).
    """
    soup = BeautifulSoup(html_content, 'html.parser')

    form_fields = []
    for form in soup.find_all('form'):
        # A list of tag names matches either tag in a single pass and keeps
        # document order, instead of listing every input before every select.
        form_fields.extend(str(tag) for tag in form.find_all(['input', 'select']))
    return form_fields

def load_json(json_file):
    """Parse and return the JSON document stored at ``json_file``.

    The file is decoded as UTF-8, the encoding RFC 8259 mandates for JSON,
    rather than the platform's locale default.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        return json.load(f)

def classify_lines(text, candidate_labels, model_name):
    """Zero-shot classify the visible ``<input>``/``<select>`` lines of ``text``.

    Args:
        text: Either a list of lines (e.g. the output of ``find_form_fields``)
            or a single string, which is split on newlines.
        candidate_labels: Labels passed to the zero-shot classifier.
        model_name: Model identifier given to ``transformers.pipeline``.

    Returns:
        A tuple ``(classified_lines, execution_time)``. ``classified_lines`` is
        a list of ``(line, [(label, score), ...])`` pairs holding the first
        (up to) two labels/scores returned by the classifier for each
        classified line. ``execution_time`` is the elapsed seconds, including
        pipeline construction.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    classifier = pipeline('zero-shot-classification', model=model_name)

    lines = text if isinstance(text, list) else text.split('\n')

    classified_lines = []
    for line in lines:
        # Only classify lines that start with an <input or <select tag. Skip any
        # line mentioning 'hidden' (case-insensitive). That covers
        # type="hidden" inputs, but also any other occurrence of the word.
        if not line.strip().startswith(("<input", "<select")) or 'hidden' in line.lower():
            continue
        results = classifier(line, candidate_labels=candidate_labels)
        top_two = list(zip(results['labels'][:2], results['scores'][:2]))
        classified_lines.append((line, top_two))

    execution_time = time.perf_counter() - start_time
    return classified_lines, execution_time

def classify_lines_json(text, json_content, candidate_labels, model_name, output_file_path):
    """Annotate visible form-field lines with the best-matching JSON answer.

    Each line that starts (after stripping) with ``<input`` or ``<select`` and
    does not contain 'hidden' (case-insensitive) is zero-shot classified
    against ``candidate_labels``. An HTML comment with
    ``json_content[best_label]`` and its score is appended to that line. All
    lines, annotated or not, are written to ``output_file_path`` and returned.

    Args:
        text: The HTML as a list of lines, or a single string split on newlines.
        json_content: Mapping from label to the value to report. Every label
            the classifier can return must be a key of it.
        candidate_labels: Labels for the zero-shot classifier.
        model_name: Model identifier given to ``transformers.pipeline``.
        output_file_path: Destination file. It is overwritten and written as UTF-8.

    Returns:
        ``(output_content, execution_time)``: the list of output lines, each
        terminated by a newline, and the elapsed seconds.

    Raises:
        KeyError: if a predicted label is not a key of ``json_content``.
    """
    start_time = time.perf_counter()
    classifier = pipeline('zero-shot-classification', model=model_name)

    lines = text if isinstance(text, list) else text.split('\n')

    output_content = []
    for line in lines:
        if line.strip().startswith(("<input", "<select")) and 'hidden' not in line.lower():
            results = classifier(line, candidate_labels=candidate_labels)
            best_label = results['labels'][0]
            best_score = results['scores'][0]
            # NOTE(review): the JSON value is inserted verbatim; a value that
            # contains "-->" would close the comment early.
            line = line + f"<!-- Input: {json_content[best_label]} with this certainty: {best_score} -->"
        output_content.append(line + '\n')

    # Write once, after classification succeeds, so that a failure mid-way
    # does not leave a partially written file. UTF-8 avoids a
    # locale-dependent UnicodeEncodeError on non-ASCII markup.
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        output_file.writelines(output_content)

    execution_time = time.perf_counter() - start_time
    return output_content, execution_time
    
def retrieve_fields(data, path=''):
    """Flatten a JSON structure into a ``{path: value}`` dict.

    Nested dict keys are joined with '.', and list items are addressed as
    ``[index]``; e.g. ``{"a": [{"b": 1}]}`` becomes ``{"a[0].b": 1}``.
    A leaf that is missing (``None``) or an empty string is replaced with a
    placeholder asking for the field to be filled in. Every other leaf,
    including ``0`` and ``False``, is kept as-is. Empty dicts and lists
    contribute no entries.

    Args:
        data: Parsed JSON value (dict, list, or scalar).
        path: Path prefix for ``data``; '' at the top level.

    Returns:
        A flat dict mapping each leaf path to its value or placeholder.
    """
    fields = {}

    if isinstance(data, dict):
        for key, value in data.items():
            new_path = f"{path}.{key}" if path else key
            fields.update(retrieve_fields(value, new_path))
    elif isinstance(data, list):
        for index, item in enumerate(data):
            fields.update(retrieve_fields(item, f"{path}[{index}]"))
    elif data is None or data == "":
        # Only genuinely missing answers get a placeholder. Falsy but valid
        # answers such as 0 or False must be preserved.
        fields[path] = f"Please fill in the {path} field."
    else:
        fields[path] = data

    return fields

def retrieve_fields_from_file(file_path):
    """Load a JSON file and flatten it with ``retrieve_fields``.

    Args:
        file_path: A path (``str`` or ``os.PathLike``) or an upload object that
            exposes the file's path on its ``name`` attribute.

    Returns:
        The flat ``{path: value}`` dict produced by ``retrieve_fields``.
    """
    # Check for path types first: ``pathlib.Path.name`` is only the basename,
    # so it must not be mistaken for an upload object's full-path ``name``.
    path = file_path if isinstance(file_path, (str, os.PathLike)) else file_path.name
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return retrieve_fields(data)


def process_files(html_file, json_file):
    """Annotate the uploaded HTML form's fields with answers from the uploaded JSON.

    For each model in ``models``, this does two things:
    - classifies the form's ``<input>``/``<select>`` tags by input type (the
      result is not used yet);
    - annotates every visible field line of the HTML with the best-matching
      JSON field and writes the result to ``./output.html``.

    Args:
        html_file: Uploaded HTML file (read via ``open_html``).
        json_file: Uploaded JSON answers file (read via
            ``retrieve_fields_from_file``).

    Returns:
        The annotated HTML produced with the last model in ``models``, or an
        empty string when ``models`` is empty.
    """
    output_file_path = "./output.html"
    input_type_labels = ["text", "radio", "checkbox", "button", "date"]

    html_content = open_html(html_file)
    html_inputs = find_form_fields(html_content)
    json_content = retrieve_fields_from_file(json_file)

    # Initialised so that an empty ``models`` dict returns '' instead of
    # raising UnboundLocalError at the return statement.
    json_classified_lines = []
    for model_name in models.values():
        # TODO: use the input-type classification (and classify the JSON
        # inputs by the same types) when filling in the form. For now the
        # result is computed but not returned.
        html_classified_lines, html_execution_time = classify_lines(html_inputs, input_type_labels, model_name)

        json_classified_lines, json_execution_time = classify_lines_json(
            html_content, json_content, list(json_content.keys()), model_name, output_file_path)

    # Every output line already ends with '\n'. Joining with '\n' would
    # double-space the returned text.
    return ''.join(json_classified_lines)

# Two file uploads (HTML form, JSON answers) -> annotated HTML shown as text.
# gr.File replaces the deprecated gr.inputs.File namespace, which was removed
# in Gradio 4.
iface = gr.Interface(
    fn=process_files,
    inputs=[
        gr.File(label="Upload HTML File"),
        gr.File(label="Upload JSON File"),
    ],
    outputs="text",
    examples=[
        ["./public/form1.html", "./public/form1_answer.json"],
        ["./public/form2.html", "./public/form2_answer.json"],
        ["./public/form3.html", "./public/form3_answer.json"],
        ["./public/form4.html", "./public/form4_answer.json"],
    ],
)

# Launch only when run as a script, so importing this module (e.g. from
# tests) does not start a server.
if __name__ == "__main__":
    iface.launch()