import datetime
import logging
import time
import uuid

import click
from celery import shared_task  # type: ignore
from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Dataset, Document, DocumentSegment
from services.vector_service import VectorService


@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
    job_id: str,
    content: list,
    dataset_id: str,
    document_id: str,
    tenant_id: str,
    user_id: str,
):
    """
    Async batch create segment to index
    :param job_id:
    :param content:
    :param dataset_id:
    :param document_id:
    :param tenant_id:
    :param user_id:

    Usage: batch_create_segment_to_index_task.delay(segment_id)
    """
    logging.info(click.style("Start batch create segment job: {}".format(job_id), fg="green"))
    start_at = time.perf_counter()

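    # Redis key used to report the job status back to the caller; the task writes
    # "completed" or "error" to it below with a 600-second TTL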
    indexing_cache_key = "segment_batch_import_{}".format(job_id)

    try:
        with Session(db.engine) as session:
            dataset = session.get(Dataset, dataset_id)
            if not dataset:
                raise ValueError("Dataset not exist.")

            dataset_document = session.get(Document, document_id)
            if not dataset_document:
                raise ValueError("Document not exist.")

            if (
                not dataset_document.enabled
                or dataset_document.archived
                or dataset_document.indexing_status != "completed"
            ):
                raise ValueError("Document is not available.")
            document_segments = []
            embedding_model = None
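            # high-quality datasets use an embedding model; resolve it once here so
            # the per-segment token counts in the loop below can be computed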
            if dataset.indexing_technique == "high_quality":
                model_manager = ModelManager()
                embedding_model = model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            word_count_change = 0
            # append new segments after the current highest position in the document
            max_position_stmt = select(func.max(DocumentSegment.position)).where(
                DocumentSegment.document_id == dataset_document.id
            )
            max_position = (session.scalar(max_position_stmt) or 0) + 1
            for segment in content:
                content_str = segment["content"]
                doc_id = str(uuid.uuid4())
                segment_hash = helper.generate_text_hash(content_str)
                # calculate the number of embedding tokens this segment will use
                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
                segment_document = DocumentSegment(
                    tenant_id=tenant_id,
                    dataset_id=dataset_id,
                    document_id=document_id,
                    index_node_id=doc_id,
                    index_node_hash=segment_hash,
                    position=max_position,
                    content=content_str,
                    word_count=len(content_str),
                    tokens=tokens,
                    created_by=user_id,
                    indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    status="completed",
                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                )
                max_position += 1
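                # qa_model documents store an answer alongside each question segment;
                # the answer length is added to the segment's word count as well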
                if dataset_document.doc_form == "qa_model":
                    segment_document.answer = segment["answer"]
                    segment_document.word_count += len(segment["answer"])
                word_count_change += segment_document.word_count
                session.add(segment_document)
                document_segments.append(segment_document)
            # update document word count
            dataset_document.word_count += word_count_change
            session.add(dataset_document)
            # add the new segments to the vector index
            VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
            session.commit()

        redis_client.setex(indexing_cache_key, 600, "completed")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                "Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
                fg="green",
            )
        )
    except Exception:
        logging.exception("Segment batch create index failed")
        redis_client.setex(indexing_cache_key, 600, "error")
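

# Illustrative usage sketch only; the real caller lives elsewhere in the codebase.
# All ids below are hypothetical placeholders, and the payload simply mirrors the
# keys this task reads from each item ("content", plus "answer" for qa_model
# documents):
#
#     job_id = str(uuid.uuid4())
#     batch_create_segment_to_index_task.delay(
#         job_id=job_id,
#         content=[{"content": "What is a dataset?", "answer": "A collection of documents."}],
#         dataset_id="<dataset-id>",
#         document_id="<document-id>",
#         tenant_id="<tenant-id>",
#         user_id="<user-id>",
#     )
#
#     # callers can then watch the Redis status key this task writes to:
#     # redis_client.get("segment_batch_import_{}".format(job_id))  # b"completed" / b"error"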