import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
    """
    Async deal dataset from index

    Supported actions:

    * ``"remove"`` - clean the dataset's index.
    * ``"add"``    - load the enabled segments of every completed, enabled,
      non-archived document of the dataset into the index.
    * ``"update"`` - same as ``"add"``, but the index is cleaned first (only
      when at least one document qualifies).

    Any other action is a no-op. Failures are logged and never re-raised, so
    the task itself always finishes without an exception.

    :param dataset_id: dataset_id
    :param action: action
    Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
    """
    logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = Dataset.query.filter_by(id=dataset_id).first()

        if not dataset:
            raise Exception("Dataset not found")
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        if action == "remove":
            index_processor.clean(dataset, None, with_keywords=False)
        elif action in ("add", "update"):
            document_ids = _mark_indexable_documents_as_indexing(dataset_id)
            # NOTE(review): when no document qualifies, "update" leaves the
            # existing index untouched - confirm that stale vectors are acceptable.
            if document_ids:
                if action == "update":
                    # Drop the old index before re-adding every document.
                    index_processor.clean(dataset, None, with_keywords=False)

                for document_id in document_ids:
                    _load_document_into_vector_index(index_processor, dataset, document_id)

        end_at = time.perf_counter()
        logging.info(
            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
        )
    except Exception:
        logging.exception("Deal dataset vector index failed")


def _mark_indexable_documents_as_indexing(dataset_id: str) -> list:
    """
    Flag every completed, enabled, non-archived document of the dataset as
    "indexing" and commit that change.

    :param dataset_id: dataset_id
    :return: ids of the flagged documents (empty when none qualify)
    """
    rows = (
        db.session.query(DatasetDocument.id)
        .filter(
            DatasetDocument.dataset_id == dataset_id,
            DatasetDocument.indexing_status == "completed",
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        )
        .all()
    )
    document_ids = [row.id for row in rows]

    if document_ids:
        db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(document_ids)).update(
            {"indexing_status": "indexing"}, synchronize_session=False
        )
        db.session.commit()

    return document_ids


def _set_document_fields(document_id, values: dict) -> None:
    """Apply ``values`` to a single dataset document row and commit."""
    db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).update(
        values, synchronize_session=False
    )
    db.session.commit()


def _load_document_into_vector_index(index_processor, dataset, document_id) -> None:
    """
    Load the enabled segments of one document into the index and record the
    outcome on the document ("completed", or "error" plus the message).

    A failure for one document is recorded and logged but not re-raised, so the
    caller can continue with the remaining documents.

    :param index_processor: processor returned by IndexProcessorFactory
    :param dataset: dataset the document belongs to
    :param document_id: id of the document to index
    """
    try:
        segments = (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == document_id, DocumentSegment.enabled == True)
            .order_by(DocumentSegment.position.asc())
            .all()
        )
        if segments:
            documents = [
                Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                for segment in segments
            ]
            # save vector index
            index_processor.load(dataset, documents, with_keywords=False)
        _set_document_fields(document_id, {"indexing_status": "completed"})
    except Exception as e:
        logging.exception("Failed to load document %s into the vector index", document_id)
        # A database error leaves the session in a failed transaction; reset it
        # first, otherwise the status write below fails as well and aborts the
        # whole task with the remaining documents stuck in "indexing".
        db.session.rollback()
        _set_document_fields(document_id, {"indexing_status": "error", "error": str(e)})