"""
Contains utility functions for the LLM and Database modules, along with some other miscellaneous functions.
"""

from turtle import clear
from pymupdf import pymupdf
#from docx import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
#import tiktoken
import base64
import hashlib
from typing import List
from openai import OpenAI
#from dotenv import load_dotenv
import os
import hashlib
from datetime import datetime
from typing import List, Optional, Dict, Any, Tuple

def generate_file_id(file_bytes: bytes) -> str:
    """Return a short, 4-character hexadecimal ID derived from the file content.

    Only the leading 4096 bytes of ``file_bytes`` are hashed (SHA-256), and
    the ID is the first two digest bytes rendered as lowercase hex. Files that
    share the same first 4096 bytes therefore receive the same ID.
    """
    head = file_bytes[:4096]
    # Two digest bytes == four hex characters (leading zeros are preserved).
    return hashlib.sha256(head).hexdigest()[:4]


def _detect_columns(blocks, page_width):
    """Estimate how many text columns a page has from its text blocks.

    The left x coordinate (``block[0]``) of every block is sorted, and each
    gap between neighbouring values wider than 15% of ``page_width`` is
    counted as a column boundary.
    """
    if not blocks:
        return 1

    x_starts = sorted(block[0] for block in blocks)
    gap_threshold = page_width * 0.15
    wide_gaps = sum(
        1
        for left, right in zip(x_starts, x_starts[1:])
        if right - left > gap_threshold
    )
    return wide_gaps + 1


def _sort_elements_by_position(elements, num_columns, page_width):
    """Order ``(bbox, payload, kind)`` elements in reading order.

    Single-column pages are sorted top to bottom. Multi-column pages are
    sorted by column index (derived from the bbox's left x coordinate) and
    then top to bottom within each column.
    """
    if num_columns == 1:
        return sorted(elements, key=lambda el: el[0][1])

    column_width = page_width / num_columns
    return sorted(
        elements,
        key=lambda el: (int(el[0][0] // column_width), el[0][1]),
    )


def process_pdf_to_chunks(
    pdf_content: bytes,
    file_name: str,
    chunk_size: int = 512,
    chunk_overlap: int = 20
) -> Tuple[List[Dict[str, Any]], str]:
    """
    Process PDF content into chunks with column layout detection and image handling.

    Each page's text blocks and images are placed in reading order, images are
    replaced in the text by ``<img_N>`` markers (and stored base64-encoded),
    and the resulting document text is split into token-sized chunks.

    Args:
        pdf_content: Raw bytes of the PDF file.
        file_name: Name stored in every chunk's metadata.
        chunk_size: Maximum chunk size in ``cl100k_base`` tokens.
        chunk_overlap: Overlap between consecutive chunks, in tokens.

    Returns:
        A tuple ``(chunks, doc_id)``. Each chunk is a dict with ``"text"`` and
        ``"metadata"`` (creation date, file name, base64 images, document id
        and a ``"location"`` dict with character offsets, pages, chunk index,
        total chunk count and per-page layout). If a chunk cannot be located
        in the document text, its ``char_start``/``char_end`` are ``-1`` and
        its pages and images are empty.
    """
    doc_id = generate_file_id(pdf_content)

    document_parts = []
    document_length = 0
    all_images = []
    image_spans = []  # (image index, marker start offset, marker end offset)
    char_to_page_map = []  # page number for every character of document_text
    layout_info = {}

    doc = pymupdf.open(stream=pdf_content, filetype="pdf")
    try:
        for page_num, page in enumerate(doc, 1):
            page_width = page.rect.width
            blocks = page.get_text_blocks()

            num_columns = _detect_columns(blocks, page_width)
            layout_info[page_num] = {
                "columns": num_columns,
                "width": page_width,
                "height": page.rect.height
            }

            # Each element is (bbox, payload, kind); payload is the block text
            # for "text" elements and the image xref for "image" elements.
            elements = [(block[:4], block[4], "text") for block in blocks]

            for img in page.get_images():
                xref = img[0]
                try:
                    img_rects = page.get_image_rects(xref)
                except Exception as e:
                    # Best effort: an image whose placement cannot be
                    # resolved is skipped instead of aborting the document.
                    print(f"Error processing image: {e}")
                    continue
                # Only the first placement of the image on the page is used.
                if img_rects and img_rects[0]:
                    elements.append((img_rects[0], xref, "image"))

            ordered = _sort_elements_by_position(elements, num_columns, page_width)
            for _bbox, payload, kind in ordered:
                if kind == "text":
                    piece = payload
                else:
                    image_bytes = doc.extract_image(payload)["image"]
                    all_images.append(base64.b64encode(image_bytes).decode('utf-8'))
                    image_index = len(all_images) - 1
                    piece = f"\n<img_{image_index}>\n"
                    image_spans.append(
                        (image_index, document_length, document_length + len(piece))
                    )

                document_parts.append(piece)
                document_length += len(piece)
                char_to_page_map.extend([page_num] * len(piece))
    finally:
        # Release the underlying document even if extraction fails.
        doc.close()

    document_text = "".join(document_parts)

    # from_tiktoken_encoder is a classmethod; no throwaway instance is needed.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name="cl100k_base",
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    text_chunks = splitter.split_text(document_text)
    total_chunks = len(text_chunks)

    processed_chunks = []
    search_from = 0
    for chunk_idx, chunk in enumerate(text_chunks):
        # Search forward from just after the previous chunk's start so that
        # repeated text (e.g. page headers) is not always mapped to its first
        # occurrence. Fall back to a full search if the forward search fails.
        chunk_start = document_text.find(chunk, search_from)
        if chunk_start == -1:
            chunk_start = document_text.find(chunk)

        if chunk_start == -1:
            # Chunk text is not a verbatim substring; no location is known.
            chunk_end = -1
            chunk_pages = []
            chunk_images = []
        else:
            chunk_end = chunk_start + len(chunk)
            search_from = chunk_start + 1

            chunk_pages = sorted(set(char_to_page_map[chunk_start:chunk_end]))

            # An image belongs to the chunk when its marker span touches the
            # chunk span. Using the whole marker (not only its first
            # character) keeps the image when the splitter stripped the
            # marker's leading newline from the start of the chunk.
            chunk_images = [
                all_images[img_idx]
                for img_idx, marker_start, marker_end in image_spans
                if marker_start <= chunk_end and marker_end > chunk_start
            ]

        chunk_layouts = {page: layout_info[page] for page in chunk_pages}

        processed_chunks.append({
            "text": chunk,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file_name,
                "images": chunk_images,
                "document_id": doc_id,
                "location": {
                    "char_start": chunk_start,
                    "char_end": chunk_end,
                    "pages": chunk_pages,
                    "chunk_index": chunk_idx,
                    "total_chunks": total_chunks,
                    "layout": chunk_layouts
                }
            }
        })

    return processed_chunks, doc_id



# import re
# import unicodedata
# from typing import Optional

# # Compile regex patterns once
# HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
# MULTIPLE_NEWLINES = re.compile(r'\n\s*\n')
# MULTIPLE_SPACES = re.compile(r'\s+')

# def clean_text_for_llm(text: Optional[str]) -> str:
#     """
#     Efficiently clean and normalize text for LLM processing.
#     """
#     # Early returns
#     if not text:
#         return ""
#     if not isinstance(text, str):
#         try:
#             text = str(text)
#         except Exception:
#             return ""

#     # Single-pass character filtering
#     chars = []
#     prev_char = ''
#     space_pending = False
    
    # for char in text:
    #     # Skip null bytes and most control characters
    #     if char == '\0' or unicodedata.category(char).startswith('C'):
    #         if char not in '\n\t':
    #             continue
        
    #     # Convert escaped sequences
    #     if prev_char == '\\':
    #         if char == 'n':
    #             chars[-1] = '\n'
    #             continue
    #         if char == 't':
    #             chars[-1] = '\t'
    #             continue
            
    #     # Handle whitespace
    #     if char.isspace():
    #         if not space_pending:
    #             space_pending = True
    #         continue
            
    #     if space_pending:
    #         chars.append(' ')
    #         space_pending = False
            
    #     chars.append(char)
    #     prev_char = char

    # # Join characters and perform remaining operations
    # text = ''.join(chars)
    
    # # Remove HTML tags
    # #text = HTML_TAG_PATTERN.sub('', text)
    
    # # Normalize Unicode in a single pass
    # text = unicodedata.normalize('NFKC', text)
    
    # # Clean up newlines
    # text = MULTIPLE_NEWLINES.sub('\n', text)
    
    # Final trim
    # return text.strip()