KevinHuSh committed · Commit 2436df2 · Parent: 368b624

add raptor (#899)

### What problem does this PR solve?

#882

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Files changed:

- api/apps/system_app.py +2 -1
- api/db/services/document_service.py +27 -3
- api/db/services/llm_service.py +4 -0
- api/db/services/task_service.py +2 -1
- rag/llm/chat_model.py +2 -3
- rag/raptor.py +114 -0
- rag/svr/task_executor.py +82 -27
- rag/utils/redis_conn.py +11 -9
api/apps/system_app.py CHANGED

```diff
@@ -60,7 +60,8 @@ def status():
     st = timer()
     try:
         qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
-        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
+        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.),
+                        "pending": qinfo.get("pending", 0)}
     except Exception as e:
         res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
```
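
After this change, a healthy Redis entry reported by the `status()` endpoint also carries the task queue backlog. A sketch of the resulting `res["redis"]` value, assuming `REDIS_CONN.health(SVR_QUEUE_NAME)` reports a `pending` count for the stream (values here are illustrative):

```python
# Illustrative shape of res["redis"] after this change (values made up).
res["redis"] = {
    "status": "green",   # "red" if the health check raised an exception
    "elapsed": "1.3",    # health-check latency in milliseconds, one decimal
    "pending": 0,        # new: pending entries in the task queue stream
}
```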
api/db/services/document_service.py CHANGED

```diff
@@ -18,8 +18,10 @@ from datetime import datetime
 from elasticsearch_dsl import Q
 from peewee import fn
 
+from api.db.db_utils import bulk_insert_into_db
 from api.settings import stat_logger
-from api.utils import current_timestamp, get_format_time
+from api.utils import current_timestamp, get_format_time, get_uuid
+from rag.settings import SVR_QUEUE_NAME
 from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils.minio_conn import MINIO
 from rag.nlp import search
@@ -30,6 +32,7 @@ from api.db.db_models import Document
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db import StatusEnum
+from rag.utils.redis_conn import REDIS_CONN
 
 
 class DocumentService(CommonService):
@@ -110,7 +113,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_unfinished_docs(cls):
-        fields = [cls.model.id, cls.model.process_begin_at]
+        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg]
         docs = cls.model.select(*fields) \
             .where(
                 cls.model.status == StatusEnum.VALID.value,
@@ -260,7 +263,12 @@ class DocumentService(CommonService):
             prg = -1
             status = TaskStatus.FAIL.value
         elif finished:
-            status = TaskStatus.DONE.value
+            if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0:
+                queue_raptor_tasks(d)
+                prg *= 0.98
+                msg.append("------ RAPTOR -------")
+            else:
+                status = TaskStatus.DONE.value
 
         msg = "\n".join(msg)
         info = {
@@ -282,3 +290,19 @@ class DocumentService(CommonService):
         return len(cls.model.select(cls.model.id).where(
             cls.model.kb_id == kb_id).dicts())
 
+
+def queue_raptor_tasks(doc):
+    def new_task():
+        nonlocal doc
+        return {
+            "id": get_uuid(),
+            "doc_id": doc["id"],
+            "from_page": 0,
+            "to_page": -1,
+            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
+        }
+
+    task = new_task()
+    bulk_insert_into_db(Task, [task], True)
+    task["type"] = "raptor"
+    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
```
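
Note that `queue_raptor_tasks` fires at most once per document: only when `parser_config` carries a truthy `raptor` entry and the progress log does not yet mention RAPTOR. Going by the keys that `run_raptor` in `rag/svr/task_executor.py` reads, a minimal sketch of such a config might look like this (key names are taken from the code below; the values are illustrative only):

```python
# Hypothetical parser_config sketch; only the key names come from run_raptor().
parser_config = {
    "raptor": {
        "prompt": "Summarize the following passages:\n{cluster_content}",  # must contain {cluster_content}
        "max_token": 256,    # budget for each generated summary
        "threshold": 0.1,    # GMM soft-assignment probability cutoff
        "max_cluster": 64,   # upper bound on clusters per layer (64 is also the fallback)
        "random_seed": 0,    # random_state handed to UMAP/GaussianMixture
    }
}
```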
api/db/services/llm_service.py CHANGED

```diff
@@ -155,6 +155,10 @@ class LLMBundle(object):
             tenant_id, llm_type, llm_name, lang=lang)
         assert self.mdl, "Can't find mole for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
+        self.max_length = 512
+        for lm in LLMService.query(llm_name=llm_name):
+            self.max_length = lm.max_tokens
+            break
 
     def encode(self, texts: list, batch_size=32):
         emd, used_tokens = self.mdl.encode(texts, batch_size)
```
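
The looked-up `max_length` (falling back to 512) is what `rag/raptor.py` uses to budget how much of each chunk survives truncation before summarization: `len_per_chunk = int((max_length - max_token) / len(texts))`. A quick worked example with illustrative numbers:

```python
# Truncation budget as computed in Raptor.summarize(); numbers are examples only.
max_length = 8192   # e.g. an LLM record whose max_tokens is 8192
max_token = 256     # reserved for the generated summary
num_texts = 10      # chunks in the cluster being summarized
len_per_chunk = int((max_length - max_token) / num_texts)
assert len_per_chunk == 793  # each chunk is truncated to at most 793 tokens
```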
api/db/services/task_service.py CHANGED

```diff
@@ -53,6 +53,7 @@ class TaskService(CommonService):
             Knowledgebase.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
+            Tenant.llm_id,
             cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Document, on=(cls.model.doc_id == Document.id)) \
@@ -159,4 +160,4 @@ def queue_tasks(doc, bucket, name):
     DocumentService.begin2parse(doc["id"])
 
     for t in tsks:
-        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
+        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
```
rag/llm/chat_model.py CHANGED

```diff
@@ -57,8 +57,7 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content:continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -379,7 +378,7 @@ class VolcEngineChat(Base):
                 ans += resp.choices[0].message.content
                 yield ans
                 if resp.choices[0].finish_reason == "stop":
-
+                    yield resp.usage.total_tokens
 
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
```
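
Both fixes tighten the streaming contract: the generator yields the growing answer string while tokens arrive and, when the provider reports `finish_reason == "stop"`, yields a final integer token count. A consumer-side sketch, assuming the surrounding method is the streaming `chat_streamly` generator with that yields-strings-then-an-int contract:

```python
# Minimal consumer sketch for the assumed streaming contract.
ans, used_tokens = "", 0
for chunk in chat_mdl.chat_streamly(system_prompt, history, gen_conf):
    if isinstance(chunk, int):   # final yield: total token usage
        used_tokens = chunk
        break
    ans = chunk                  # each str yield is the accumulated answer so far
```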
rag/raptor.py ADDED

```python
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import re
import traceback
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
from typing import Tuple
import umap
import numpy as np
from sklearn.mixture import GaussianMixture

from rag.utils import num_tokens_from_string, truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=256, threshold=0.1):
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        max_clusters = min(self._max_cluster, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        optimal_clusters = n_clusters[np.argmin(bics)]
        return optimal_clusters

    def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)
        if len(chunks) <= 1: return

        def summarize(ck_idx, lock):
            nonlocal chunks
            try:
                texts = [chunks[i][0] for i in ck_idx]
                len_per_chunk = int((self._llm_model.max_length - self._max_token)/len(texts))
                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
                cnt = self._llm_model.chat("You're a helpful assistant.",
                                           [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
                                           {"temperature": 0.3, "max_tokens": self._max_token}
                                           )
                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
                print("SUM:", cnt)
                embds, _ = self._embd_model.encode([cnt])
                with lock:
                    chunks.append((cnt, embds[0]))
            except Exception as e:
                print(e, flush=True)
                traceback.print_stack(e)
                return e

        labels = []
        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start: end]]
            if len(embeddings) == 2:
                summarize([start, start+1], Lock())
                if callback:
                    callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
                labels.extend([0,0])
                layers.append((end, len(chunks)))
                start = end
                end = len(chunks)
                continue

            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings)-2), metric="cosine"
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                lbls = [np.where(prob > self._threshold)[0] for prob in probs]
            lock = Lock()
            with ThreadPoolExecutor(max_workers=12) as executor:
                threads = []
                for c in range(n_clusters):
                    ck_idx = [i+start for i in range(len(lbls)) if lbls[i] == c]
                    threads.append(executor.submit(summarize, ck_idx, lock))
                wait(threads, return_when=ALL_COMPLETED)
                print([t.result() for t in threads])

            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
            start = end
            end = len(chunks)
```
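
The class builds the RAPTOR tree bottom-up: each pass reduces the current layer's embeddings with UMAP, picks a cluster count by BIC over Gaussian mixtures, summarizes every cluster with the chat model (in up to 12 threads), embeds the summaries, and appends them as the next layer, stopping once a layer shrinks to a single node. A minimal driving sketch, where `my_chunks`, `chat_mdl`, and `embd_mdl` are placeholders for your data and for `LLMBundle` instances exposing the `chat(...)`/`encode(...)` interfaces used above:

```python
# Minimal usage sketch; my_chunks, chat_mdl and embd_mdl are placeholders.
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor

chunks = [(text, vec) for text, vec in my_chunks]  # list of (str, np.ndarray) pairs
raptor = Raptor(
    max_cluster=64,
    llm_model=chat_mdl,
    embd_model=embd_mdl,
    prompt="Summarize the following passages:\n{cluster_content}",
    max_token=256,
    threshold=0.1,
)
raptor(chunks, random_state=0, callback=lambda msg: print(msg))
# __call__ mutates `chunks` in place: summary nodes are appended after the originals.
summary_nodes = chunks[len(my_chunks):]
```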
rag/svr/task_executor.py CHANGED

```diff
@@ -26,20 +26,22 @@ import traceback
 from functools import partial
 
 from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.utils.minio_conn import MINIO
 from api.db.db_models import close_connection
 from rag.settings import database_logger, SVR_QUEUE_NAME
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from multiprocessing import Pool
 import numpy as np
-from elasticsearch_dsl import Q
+from elasticsearch_dsl import Q, Search
 from multiprocessing.context import TimeoutError
 from api.db.services.task_service import TaskService
 from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm
+from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
 
-from rag.nlp import search
+from rag.nlp import search, rag_tokenizer
 from io import BytesIO
 import pandas as pd
 
@@ -114,6 +116,8 @@ def collect():
     tasks = TaskService.get_tasks(msg["id"])
     assert tasks, "{} empty task!".format(msg["id"])
     tasks = pd.DataFrame(tasks)
+    if msg.get("type", "") == "raptor":
+        tasks["task_type"] = "raptor"
     return tasks
 
 
@@ -245,6 +249,47 @@ def embedding(docs, mdl, parser_config={}, callback=None):
     return tk_count
 
 
+def run_raptor(row, chat_mdl, embd_mdl, callback=None):
+    vts, _ = embd_mdl.encode(["ok"])
+    vctr_nm = "q_%d_vec"%len(vts[0])
+    chunks = []
+    for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
+        chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
+
+    raptor = Raptor(
+        row["parser_config"]["raptor"].get("max_cluster", 64),
+        chat_mdl,
+        embd_mdl,
+        row["parser_config"]["raptor"]["prompt"],
+        row["parser_config"]["raptor"]["max_token"],
+        row["parser_config"]["raptor"]["threshold"]
+    )
+    original_length = len(chunks)
+    raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
+    doc = {
+        "doc_id": row["doc_id"],
+        "kb_id": [str(row["kb_id"])],
+        "docnm_kwd": row["name"],
+        "title_tks": rag_tokenizer.tokenize(row["name"])
+    }
+    res = []
+    tk_count = 0
+    for content, vctr in chunks[original_length:]:
+        d = copy.deepcopy(doc)
+        md5 = hashlib.md5()
+        md5.update((content + str(d["doc_id"])).encode("utf-8"))
+        d["_id"] = md5.hexdigest()
+        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d[vctr_nm] = vctr.tolist()
+        d["content_with_weight"] = content
+        d["content_ltks"] = rag_tokenizer.tokenize(content)
+        d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+        res.append(d)
+        tk_count += num_tokens_from_string(content)
+    return res, tk_count
+
+
 def main():
     rows = collect()
     if len(rows) == 0:
@@ -259,35 +304,45 @@ def main():
             cron_logger.error(str(e))
             continue
 
-        st = timer()
-        cks = build(r)
-        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
-        if cks is None:
-            continue
-        if not cks:
-            callback(1., "No chunk! Done!")
-            continue
-        # TODO: exception handler
-        ## set_progress(r["did"], -1, "ERROR: ")
-        callback(
-            msg="Finished slicing files(%d). Start to embedding the content." %
-            len(cks))
-        st = timer()
-        try:
-            tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
-        except Exception as e:
-            callback(-1, "Embedding error:{}".format(str(e)))
-            cron_logger.error(str(e))
-            tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+        if r.get("task_type", "") == "raptor":
+            try:
+                chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
+                cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
+            except Exception as e:
+                callback(-1, msg=str(e))
+                cron_logger.error(str(e))
+                continue
+        else:
+            st = timer()
+            cks = build(r)
+            cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
+            if cks is None:
+                continue
+            if not cks:
+                callback(1., "No chunk! Done!")
+                continue
+            # TODO: exception handler
+            ## set_progress(r["did"], -1, "ERROR: ")
+            callback(
+                msg="Finished slicing files(%d). Start to embedding the content." %
+                len(cks))
+            st = timer()
+            try:
+                tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
+            except Exception as e:
+                callback(-1, "Embedding error:{}".format(str(e)))
+                cron_logger.error(str(e))
+                tk_count = 0
+            cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+            callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
 
-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
-        for b in range(0, len(cks), es_bulk_size):
-            es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
+        es_bulk_size = 16
+        for b in range(0, len(cks), es_bulk_size):
+            es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
```
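
End to end, the RAPTOR path works like this: `queue_raptor_tasks` in `document_service.py` publishes a task message typed `"raptor"`, `collect()` tags the matching rows with `task_type = "raptor"`, and `main()` routes them to `run_raptor` instead of the regular build/embedding pipeline, after which both paths share the same bulk-indexing tail. The queued message, as assembled above (only `id` and `doc_id` vary per document):

```python
# Shape of the message queue_raptor_tasks() publishes on SVR_QUEUE_NAME.
task = {
    "id": "<uuid>",            # get_uuid()
    "doc_id": "<document id>",
    "from_page": 0,
    "to_page": -1,             # whole document
    "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval).",
    "type": "raptor",          # set after the DB insert, just before queueing
}
```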
rag/utils/redis_conn.py CHANGED

```diff
@@ -97,15 +97,17 @@ class RedisDB:
             return False
 
     def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
-        try:
-            payload = {"message": json.dumps(message)}
-            pipeline = self.REDIS.pipeline()
-            pipeline.xadd(queue, payload)
-            pipeline.expire(queue, exp)
-            pipeline.execute()
-            return True
-        except Exception as e:
-            logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
+        for _ in range(3):
+            try:
+                payload = {"message": json.dumps(message)}
+                pipeline = self.REDIS.pipeline()
+                pipeline.xadd(queue, payload)
+                pipeline.expire(queue, exp)
+                pipeline.execute()
+                return True
+            except Exception as e:
+                print(e)
+                logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
         return False
 
     def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
```
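
The producer now makes up to three attempts before giving up, and each attempt pairs `XADD` with an `EXPIRE` in one pipeline so the stream's retention window is refreshed on every publish. Callers treat a `False` return as fatal, as in `task_service.py` and `document_service.py` above:

```python
# How callers consume the boolean result (verbatim pattern from this commit).
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), \
    "Can't access Redis. Please check the Redis' status."
```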