File size: 6,418 Bytes
2436df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bc2fc9
2436df2
 
 
 
 
 
 
758538f
22fe41e
2436df2
 
 
0d756a3
2436df2
 
 
 
 
 
 
758538f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c8f077
2436df2
 
 
 
 
 
 
 
 
 
9c8f077
2436df2
 
0404a52
 
ab87187
2436df2
 
 
 
 
9c8f077
2436df2
758538f
9c8f077
 
 
 
 
 
8bc2fc9
2436df2
 
758538f
2436df2
8bc2fc9
2436df2
 
 
 
 
 
9c8f077
2436df2
9c8f077
 
2436df2
 
 
 
 
 
 
9c8f077
2436df2
 
 
 
 
 
 
 
 
75a07ce
2436df2
 
 
 
9c8f077
2436df2
 
4ac524c
 
 
8bc2fc9
2436df2
 
 
 
 
9c8f077
2436df2
 
 
9c8f077
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
import umap
import numpy as np
from sklearn.mixture import GaussianMixture

from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
from rag.utils import truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    """RAPTOR-style bottom-up summarization tree.

    Each pass clusters the current layer's chunk embeddings (UMAP dimension
    reduction + Gaussian-mixture soft clustering), asks the LLM to summarize
    every cluster, and appends each (summary, embedding) pair to ``chunks`` as
    the next layer.  Passes repeat until a layer collapses to a single chunk.
    """

    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
        """
        Args:
            max_cluster: upper bound on GMM components tried per layer.
            llm_model: chat model exposing ``chat(system, history, gen_conf)``,
                ``llm_name`` and ``max_length`` (context window, in tokens).
            embd_model: embedding model exposing ``encode(list[str])`` and
                ``llm_name``.
            prompt: summary prompt template with a ``{cluster_content}`` slot.
            max_token: token budget for each generated summary.
            threshold: minimum soft-assignment probability for a point to be
                counted as a member of a cluster.
        """
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    def _chat(self, system, history, gen_conf):
        """Chat with the LLM, consulting and updating the shared LLM cache.

        Raises:
            Exception: when the model reply contains the "**ERROR**" marker.
        """
        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        if response:
            return response
        response = self._llm_model.chat(system, history, gen_conf)
        if "**ERROR**" in response:
            raise Exception(response)
        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    def _embedding_encode(self, txt):
        """Embed ``txt``, consulting and updating the shared embedding cache.

        Raises:
            Exception: when the embedding model returns an empty vector.
        """
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response:
            return response
        embds, _ = self._embd_model.encode([txt])
        if len(embds) < 1 or len(embds[0]) < 1:
            raise Exception("Embedding error: ")
        embds = embds[0]
        set_embed_cache(self._embd_model.llm_name, txt, embds)
        return embds

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        """Return the GMM component count in [1, max) with the lowest BIC."""
        max_clusters = min(self._max_cluster, len(embeddings))
        if max_clusters <= 1:
            # np.arange(1, 1) would be empty and np.argmin would raise.
            return 1
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        return n_clusters[np.argmin(bics)]

    def __call__(self, chunks, random_state, callback=None):
        """Grow the summary tree over ``chunks``.

        Args:
            chunks: list of (text, embedding) pairs forming the leaf layer.
            random_state: seed forwarded to UMAP/GMM for reproducibility.
            callback: optional progress callback taking ``msg=...``.

        Returns:
            The chunk list extended in place with one summary chunk per
            cluster per layer, or None when there is nothing to cluster.
        """
        # Drop chunks with empty embeddings FIRST so every bound computed
        # below refers to the filtered list.  (Filtering after computing
        # `end` could leave `end` past the end of the list.)
        chunks = [(s, a) for s, a in chunks if len(a) > 0]
        if len(chunks) <= 1:
            return
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        def summarize(ck_idx, lock):
            """Summarize chunks at ``ck_idx``; append (summary, embedding)
            to ``chunks`` under ``lock``.  Returns the exception on failure
            so callers can re-raise it on the main thread."""
            nonlocal chunks
            try:
                texts = [chunks[i][0] for i in ck_idx]
                # Split the remaining context window evenly so the cluster
                # content plus the summary both fit.
                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
                cnt = self._chat("You're a helpful assistant.",
                                 [{"role": "user",
                                   "content": self._prompt.format(cluster_content=cluster_content)}],
                                 {"temperature": 0.3, "max_tokens": self._max_token})
                # Strip the model's "answer truncated, continue?" boilerplate
                # (Chinese / English).  The original literal was UTF-8 Chinese
                # mis-decoded as GBK and could never match.
                cnt = re.sub("(······\n由于长度的原因，回答被截断了，请继续吗？|For the content length reason, it stopped, continue?)", "",
                             cnt)
                logging.debug(f"SUM: {cnt}")
                # _embedding_encode caches; no separate encode() call needed.
                with lock:
                    chunks.append((cnt, self._embedding_encode(cnt)))
            except Exception as e:
                logging.exception("summarize got exception")
                return e

        labels = []
        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start:end]]
            if len(embeddings) == 2:
                # Too few points for UMAP/GMM: summarize the pair directly,
                # and surface any failure just like the threaded path does.
                err = summarize([start, start + 1], Lock())
                if isinstance(err, Exception):
                    raise err
                if callback:
                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                labels.extend([0, 0])
                layers.append((end, len(chunks)))
                start = end
                end = len(chunks)
                continue

            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                # Label each point with the first cluster whose soft
                # probability clears the threshold.
                lbls = [np.where(prob > self._threshold)[0] for prob in probs]
                lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
            lock = Lock()
            with ThreadPoolExecutor(max_workers=12) as executor:
                threads = []
                for c in range(n_clusters):
                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                    threads.append(executor.submit(summarize, ck_idx, lock))
                wait(threads, return_when=ALL_COMPLETED)
                for th in threads:
                    if isinstance(th.result(), Exception):
                        raise th.result()
                logging.debug(str([t.result() for t in threads]))

            # Each cluster must have appended exactly one summary chunk.
            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
            start = end
            end = len(chunks)

        return chunks