import configparser
import glob
import json
import logging
import os

import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from transformers import AutoTokenizer
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.docstore.document import Document

from auditqa.reports import files, report_list

# module-level settings shared by the functions below
# directory prefix under which the chunk files are looked up
path_to_data = "./reports/"

# run the embedding model on the GPU when torch reports one, otherwise on CPU
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


##---------------------functions -------------------------------------------##
def getconfig(configfile_path:str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    The populated configparser.ConfigParser, or None when the file cannot be
    opened or parsed (a warning is logged in that case).
    """

    config = configparser.ConfigParser()

    try:
        # context manager so the file handle is always closed
        with open(configfile_path) as config_file:
            config.read_file(config_file)
    except OSError as err:
        # missing file, permission problem, path is a directory, ...
        logging.warning("config file not found: %s (%s)", configfile_path, err)
        return None
    except (configparser.Error, UnicodeDecodeError) as err:
        # the file exists but is not a valid config file
        logging.warning("config file could not be parsed: %s (%s)", configfile_path, err)
        return None
    return config
        
def open_file(filepath):
    """
    Load a JSON file and return its parsed content.

    Params
    ----------------
    filepath: path of the JSON file to read

    Returns
    ----------------
    The Python object produced by json.load (dict, list, ...).
    """
    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's default locale encoding
    with open(filepath, encoding="utf-8") as file:
        return json.load(file)

def load_chunks():
    """
    Create the 'docling' vector collection from the pre-chunked documents.

    Reads "docling_chunks.json" under `path_to_data`, wraps every entry
    (its 'content' and 'metadata' keys) in a Document, embeds all of them
    with the model named in ./model_params.cfg and writes them to the
    'docling' collection at /data/local_qdrant.

    Returns
    ----------------
    dict with the single key 'docling' holding the Qdrant vector store.
    """
    settings = getconfig("./model_params.cfg")

    # one Document per stored chunk; the chunk metadata is passed through as-is
    raw_chunks = open_file(path_to_data + "docling_chunks.json" )
    documents = [
        Document(page_content=entry['content'], metadata=entry['metadata'])
        for entry in raw_chunks
    ]

    # embedding model, configured from the [retriever] section
    normalize = bool(int(settings.get('retriever','NORMALIZE')))
    model_name = settings.get('retriever','MODEL')
    embedder = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': normalize, 'batch_size': 100},
        show_progress=True,
    )

    print("embeddings started")
    # NOTE: all chunks are embedded in a single from_documents call; an earlier
    # variant that ingested the documents in batches of 1000 was left disabled here.
    docling_store = Qdrant.from_documents(
        documents,
        embedder,
        path="/data/local_qdrant",
        collection_name='docling',
    )
    qdrant_collections = {'docling': docling_store}
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections

def load_old_chunks():
    """
    Create the 'allreports' vector collection from the older per-report chunk files.

    Reads the report index ./axa_processed_chunks_update.json. For every row
    the chunk file named by 'chunks_filepath' (only its basename is used) is
    loaded from the "chunks" folder under `path_to_data`, and each entry of
    its 'paragraphs' list becomes a Document carrying the report's
    'category' (as "source"), filename without extension (as "subtype"),
    'year', 'filename', and the chunk's 'page' and 'headings'. These metadata
    fields are used in the UI for document selection / filtering.

    A report whose chunk file cannot be loaded is reported on stdout and
    skipped.

    Returns
    ----------------
    dict with the single key 'allreports' holding the Qdrant vector store.
    """
    config = getconfig("./model_params.cfg")
    # named report_index so the module-level `files` import is not shadowed
    report_index = pd.read_json("./axa_processed_chunks_update.json")
    all_documents = []

    # iterate over the rows themselves rather than assuming a 0..n-1 index
    for _, report in report_index.iterrows():
        # load the chunks of this report
        try:
            chunk_path = os.path.join(
                path_to_data, "chunks", os.path.basename(report['chunks_filepath']))
            paragraphs = open_file(chunk_path)['paragraphs']
        except Exception as e:
            print("Exception: ", e)
            # skip this report: continuing would either fail on an undefined
            # variable or re-add the previous report's chunks under this
            # report's metadata
            continue
        print("chunks in subtype:", report['filename'], "are:", len(paragraphs))

        # add metadata information
        for doc in paragraphs:
            all_documents.append(Document(
                page_content=str(doc['content']),
                metadata={"source": report['category'],
                          "subtype": os.path.splitext(report['filename'])[0],
                          "year": str(report['year']),
                          # filename of the current report (not of the first row)
                          "filename": report['filename'],
                          "page": doc['metadata']['page'],
                          "headings": doc['metadata']['headings']}))

    print("length of chunks:",len(all_documents))

    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs = {'device': device},
        encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
        model_name=config.get('retriever','MODEL')
    )
    # placeholder for collection
    qdrant_collections = {}
    qdrant_collections['allreports'] = Qdrant.from_documents(
                all_documents,
                embeddings,
                path="/data/local_qdrant",
                collection_name='allreports',
            )
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections

def get_local_qdrant(): 
    """
    Connect to the existing local Qdrant storage once it has been created.

    Opens the Qdrant data at /data/local_qdrant, prints the collections found
    there and wraps the 'docling' collection in a LangChain Qdrant vector
    store that uses the embedding model named in ./model_params.cfg.

    Returns
    ----------------
    dict with the single key 'docling' holding the Qdrant vector store.
    """
    config = getconfig("./model_params.cfg")
    embeddings = HuggingFaceEmbeddings(
        model_kwargs = {'device': device},
        # take the normalisation flag from the config, as load_chunks does when
        # building the collection, so queries are embedded the same way as the
        # indexed documents
        encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
        model_name=config.get('retriever','MODEL'))
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:",client.get_collections())
    qdrant_collections = {
        'docling': Qdrant(client=client, collection_name='docling', embeddings=embeddings),
    }
    return qdrant_collections