import functools
import os
import os
import re
import tempfile
import tempfile

import fitz
import gradio as gr
import gradio as gr
import pytesseract
import pytesseract
import spacy
from gradio.inputs import Dropdown
from gradio.inputs import File
from pdf2image import convert_from_bytes
from pdf2image import convert_from_path
from pdf2image import convert_from_path
from spacy.lang.it import Italian
from spacy.language import Language
from transformers import pipeline


def preprocess_punctuation(text):
    """Return the distinct abbreviation-like tokens found in *text*.

    A candidate is a short run of 1-4 letters/dots, optionally followed by
    more dot-separated runs, and ending in a dot (e.g. ``art.`` in
    ``art. 3``). Candidates preceded by a lowercase letter are skipped, as
    are candidates followed by optional whitespace and an uppercase letter
    (such a dot looks like a real sentence end).

    Returns:
        The matches without duplicates, in order of first occurrence.
    """
    pattern = r'(?<![a-z])[a-zA-Z\.]{1,4}(?:\.[a-zA-Z\.]{1,4})*\.(?!\s*[A-Z])'
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic (a set of str iterates in hash-seed order).
    return list(dict.fromkeys(re.findall(pattern, text)))


def preprocess_text(text):
    """Collapse blank lines in *text*.

    Any newline / whitespace / newline run is first squeezed to a single
    newline; then any remaining run of two or more newlines is squeezed the
    same way. Other whitespace is left untouched.
    """
    collapsed = text
    for blank_run in (r'\n\s*\n', r'\n{2,}'):
        collapsed = re.sub(blank_run, '\n', collapsed)
    return collapsed



@Language.component('custom_tokenizer')
def custom_tokenizer(doc):
    """spaCy pipeline component that keeps a colon from ending a sentence.

    For every ":" token that is not the last token of *doc*, the token right
    after it gets ``is_sent_start = False``. When the component is inserted
    before the parser (as get_sentences does), that preset boundary is meant
    to stop a split after the colon. Returns *doc* itself.
    """
    final_index = len(doc) - 1
    for tok in doc:
        if tok.i < final_index and tok.text == ":":
            doc[tok.i + 1].is_sent_start = False
    return doc



def get_sentences(text, dictionary = None):
    """Split Italian *text* into cleaned sentences with spaCy.

    The ``it_core_news_lg`` pipeline is loaded and ``custom_tokenizer`` is
    inserted before its parser. Every abbreviation returned by
    ``preprocess_punctuation`` is registered as a tokenizer special case.

    Args:
        text: raw document text.
        dictionary: unused; kept for backward compatibility.

    Returns:
        The non-empty sentences, each stripped of leading/trailing spaces and
        newlines, with runs of spaces collapsed to a single space.
    """
    strip_chars = ' \n'
    # NOTE(review): the large model is reloaded on every call, which is slow.
    # Caching it would also need resetting the special cases added below,
    # since they persist on the tokenizer.
    nlp = spacy.load("it_core_news_lg")
    nlp.add_pipe("custom_tokenizer", before="parser")

    for abbreviation in preprocess_punctuation(text):
        nlp.tokenizer.add_special_case(
            abbreviation,
            [{spacy.symbols.ORTH: abbreviation, spacy.symbols.NORM: abbreviation}],
        )

    cleaned = []
    for span in nlp(text).sents:
        pieces = span.text.strip(strip_chars).split(' ')
        sentence = ' '.join(piece for piece in pieces if piece)
        if sentence:
            cleaned.append(sentence)
    return cleaned




def _parse_number(word):
    """Convert one word to a number using extract_numbers' parsing rules.

    Commas are read as decimal separators. A word whose digits and dots
    form a valid decimal becomes a ``float``. Otherwise, if it contains any
    digits, all its digits are concatenated into an ``int`` (so
    ``"1.234,56"`` becomes ``123456``). A word without digits gives ``None``.
    """
    decimal = re.sub(r'[^0-9.,]+', '', word).replace(',', '.')
    if decimal.replace('.', '', 1).isdigit():
        return float(decimal)
    digits = re.sub(r'[^0-9]+', '', word)
    if digits.isdigit():
        return int(digits)
    return None


def extract_numbers(text, given_strings):
    """Collect the values written next to any of *given_strings* in *text*.

    *text* is split on whitespace. A single-word key matches any word that
    contains it as a substring (e.g. ``"mq"`` matches ``"mq."``). A
    multi-word key (e.g. ``"metri quadri"``) matches when it occurs in the
    same number of consecutive words joined by single spaces. For every
    match, the context is one word before the key through one word after it.
    Contexts with a word starting with ``+``, ``*`` or ``/`` are skipped as
    arithmetic.

    Returns:
        A flat list with one entry per context word: a ``float`` or ``int``
        (see ``_parse_number``), or ``None`` for words without digits.
    """
    words = text.split()
    # Number of words each key spans; a 1-word key is a plain substring test.
    key_spans = [(key, max(len(key.split()), 1)) for key in given_strings]
    numbers = []
    for index in range(len(words)):
        hits = [
            span for key, span in key_spans
            if key in ' '.join(words[index:index + span])
        ]
        if not hits:
            continue
        start = max(index - 1, 0)
        end = min(index + max(hits) + 1, len(words))
        context = words[start:end]
        if any(re.match(r'[+\*/]', word) for word in context):
            continue
        numbers.extend(_parse_number(word) for word in context)
    return numbers



def get_text_and_values(text, key_list):
    """Map each sentence of *text* that mentions a key to its nearby values.

    Sentences come from get_sentences. Each one is passed to
    extract_numbers with *key_list*, and only sentences with a non-empty
    result are kept. A repeated sentence overwrites its earlier entry.
    """
    sentence_values = {}
    for sentence in get_sentences(text):
        found = extract_numbers(text=sentence, given_strings=key_list)
        if found:
            sentence_values[sentence] = found
    return sentence_values


def get_useful_text(dictionary):
    """Join the keys of *dictionary* into one string, separated by a dashed divider line."""
    divider = '\n------------------------\n'
    return divider.join(dictionary)

def get_values(dictionary):
    """Return the values of *dictionary* as a new list, in insertion order."""
    return [*dictionary.values()]


@functools.lru_cache(maxsize=2)
def initialize_qa_transformer(model):
    """Return a ``text2text-generation`` pipeline for *model*.

    This is reached on every extraction request, and loading a Hugging Face
    model each time is slow and memory-heavy. The two most recently used
    pipelines are therefore cached (the UI offers two models). *model* must
    be hashable (a model name string is).
    """
    return pipeline("text2text-generation", model=model)


def get_answers_unfiltered(dictionary, question, qa_pipeline):
    """Ask *question* about every sentence key of *dictionary*.

    Each prompt is ``"<sentence> Domanda: <question>"``. The raw pipeline
    outputs are returned unfiltered, in key order.
    """
    return [
        qa_pipeline(f'{sentence} Domanda: {question}')
        for sentence in list(dictionary)
    ]


def _answer_numbers(answered_values, valid_values):
    """Return numbers >= 5 found in the QA answers that also occur in *valid_values*.

    Every string value of every answer dict is scanned, with commas read as
    decimal separators.
    """
    accepted = []
    for answer in answered_values:
        for output in answer:
            for answer_text in output.values():
                for token in re.findall(r'\d+(?:[.,]\d+)?', answer_text.replace(',', '.')):
                    value = float(token)
                    # NOTE(review): the 5.0 floor presumably drops stray small
                    # numbers (e.g. the "2" of "m2"); confirm.
                    if value >= 5.0 and value in valid_values:
                        accepted.append(value)
    return accepted


def _keyword_totals(text, keywords, accepted):
    """Return the accepted values that *text* states right after a total keyword.

    For each keyword, only its first case-insensitive occurrence that is
    followed by up to three words and then a number is considered (e.g.
    ``"superficie totale mq 120,5"``). A number already claimed by an
    earlier keyword is not counted again.
    """
    totals = []
    claimed_spans = set()
    for keyword in keywords:
        keyword_pattern = rf'{re.escape(keyword)}(\s+\w+){{0,3}}\s+(\d+(?:[.,]\d+)?)'
        match = re.search(keyword_pattern, text, re.IGNORECASE)
        if match is None or match.span(2) in claimed_spans:
            continue
        # Parse to float: the previous code compared the raw string against
        # a list of floats, so this branch could never fire.
        value = float(match.group(2).replace(',', '.'))
        if value in accepted:
            claimed_spans.add(match.span(2))
            totals.append(value)
            print(f"Found a value ({match.group(2)}) for keyword '{keyword}'.")
    return totals


def get_total(answered_values, text, keywords, raw_values, unique_values = False):
    """Select the area values confirmed by the QA model and compute the total.

    Args:
        answered_values: one QA pipeline output per sentence. Each is an
            iterable of dicts whose string values are scanned for numbers.
        text: the full document text, searched for an explicit total.
        keywords: words (e.g. ``"totale"``) that introduce an explicit total.
        raw_values: per-sentence lists from ``extract_numbers``. Only their
            ``int``/``float`` entries count as valid values.
        unique_values: drop duplicate values, keeping the first occurrence.

    Returns:
        ``(numbers, total)``. *numbers* are the answer values (>= 5) that
        also occur in *raw_values*. If keywords introduce accepted values,
        *total* is the sum of those explicit totals (several are read as
        separate lots). Otherwise it is the sum of *numbers*.
    """
    valid_values = {
        num for sublist in raw_values for num in sublist
        if isinstance(num, (int, float))
    }
    numbers = _answer_numbers(answered_values, valid_values)
    if unique_values:
        numbers = list(dict.fromkeys(numbers))
    keyword_totals = _keyword_totals(text, keywords, set(numbers))
    if keyword_totals:
        return numbers, sum(keyword_totals)
    return numbers, sum(numbers)



def extractor_clean(text, k_words, transformer, question, total_kwords, return_text = False):
    """Run the full extraction on *text*.

    Steps:
        1. Select the sentences that mention one of *k_words*, with their
           nearby numbers.
        2. Ask *question* about each sentence with the *transformer* QA
           model.
        3. Combine the answers into de-duplicated values and a total, using
           *total_kwords* to spot explicit totals.

    Returns:
        ``(values, return_text, text_used)`` when *return_text* is truthy,
        otherwise ``(values, return_text)``. *values* is the
        ``(numbers, total)`` pair from ``get_total``, and *text_used* is the
        selected sentences joined by divider lines.
    """
    dictionary = get_text_and_values(text, k_words)
    raw = get_values(dictionary)
    qa = initialize_qa_transformer(transformer)
    answers = get_answers_unfiltered(dictionary, question=question, qa_pipeline=qa)
    values = get_total(answered_values=answers, raw_values=raw, text=text,
                       keywords=total_kwords, unique_values=True)
    if return_text:
        return values, return_text, get_useful_text(dictionary)
    # Any falsy flag lands here. The old `elif return_text == False` let
    # e.g. None fall through and implicitly return None.
    return values, return_text



def pdf_ocr(file, model_t, question):
    """Extract floor-area values from the PDF at path *file*.

    Text is read with PyMuPDF. If the PDF has no text layer (only
    whitespace), each page is rendered and OCR'd with Tesseract (``ita``).
    The text then goes through ``extractor_clean``.

    Returns:
        ``(immobile_values, total_output, text_output)``:
        - *immobile_values*: a numbered list of the values, sorted
          ascending.
        - *total_output*: ``"<total>  Mq"``.
        - *text_output*: the sentences the values came from.
    """
    with open(file, "rb") as f:
        content = f.read()

    with fitz.open(stream=content, filetype="pdf") as doc:
        num_pages = len(doc)
        text = "".join(page.get_text() for page in doc)

    if not text.strip():
        # No text layer (scanned document). OCR one page at a time so only a
        # single rendered page image is held in memory at once.
        text = "".join(
            pytesseract.image_to_string(image, lang='ita')
            for page_number in range(1, num_pages + 1)
            for image in convert_from_path(file, first_page=page_number, last_page=page_number)
        )

    area_keys = ('mq', 'MQ', 'Mq', 'metri quadri', 'm2')
    total_keys = ['totale', 'complessivo', 'complessiva']

    (numbers, total), _, useful_text = extractor_clean(
        text=text, k_words=area_keys, transformer=model_t, question=question,
        total_kwords=total_keys, return_text=True)

    immobile_values = '\n'.join(
        f'{position}. Immobile :  {value}  Mq\n'
        for position, value in enumerate(sorted(numbers), start=1)
    )
    return immobile_values, f'{total}  Mq', useful_text


def ocr_interface(pdf_file, model_t='it5/it5-base-question-answering', question="Quanti metri quadri misura l'immobile?"):
    """Gradio callback: run pdf_ocr on the uploaded PDF.

    *pdf_file* may be a tempfile-like object exposing ``.name`` (what this
    code originally expected) or a plain path string (which newer Gradio
    versions may pass). Raises ``gr.Error`` when no file was uploaded.

    Returns:
        ``(values, total, text)`` as produced by ``pdf_ocr``.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")
    path = getattr(pdf_file, "name", pdf_file)
    values, total, text = pdf_ocr(path, model_t, question)
    return values, total, text


# Start the UI
# Gradio Blocks app layout:
#   - a PDF upload, plus model and question dropdowns;
#   - three output textboxes;
#   - an "Extract" button wired to ocr_interface;
#   - example files.
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    gr.Markdown(
    '''
    # PDF Mq Extractor
    Demo for ITAL-IA
    ''')
    # Single tab holding inputs, outputs and the trigger button.
    with gr.Tab("Extractor"):
      # Input: the PDF document to analyse.
      with gr.Row():
        pdf_input = gr.components.File(label="PDF File")
     
      # Inputs: QA model name and question. They reach ocr_interface
      # positionally as model_t and question (see the click wiring below).
      with gr.Row():
          model_input = gr.components.Dropdown(['it5/it5-base-question-answering', 'it5/it5-small-question-answering'],
                                               value='it5/it5-base-question-answering', label = 'Select model')
          question_input = gr.components.Dropdown(["Quanti metri quadri misura l'immobile?"],
                                                  value = "Quanti metri quadri misura l'immobile?", label = 'Question')
      
      with gr.Column():
          gr.Markdown(
          '''
          # Output values
          Values extracted from the pdf document
          ''')
      
      # Outputs: supporting sentences, sorted area values, and the total.
      with gr.Row():

          text_output = gr.components.Textbox(label="Ref. Text")
          values_output = gr.components.Textbox(label="Area Values - sorted by value")
          total_output = gr.components.Textbox(label="Total")
          
      with gr.Row():
          extract_button = gr.Button("Extract")


    # The outputs list follows ocr_interface's return order: (values, total, text).
    extract_button.click(fn = ocr_interface,
                         inputs=[pdf_input, model_input, question_input], outputs=[values_output, total_output, text_output])

    # NOTE(review): with cache_examples=True, Gradio is expected to run
    # ocr_interface on each example ahead of time. That requires these PDFs to
    # be present in the working directory and loads the models at startup —
    # confirm for the installed Gradio version.
    gr.Examples(['Example1(scannedDoc).pdf', 'Example2.pdf', 'Example3Large.pdf'], inputs = pdf_input, 
                cache_examples = True, fn = ocr_interface, outputs = [values_output, total_output, text_output])


# Starts the web server as a module-level side effect (also on import).
demo.launch()