File size: 5,403 Bytes
2436df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bc2fc9
2436df2
 
 
 
 
 
 
22fe41e
2436df2
 
 
0d756a3
2436df2
 
 
 
 
 
 
9c8f077
2436df2
 
 
 
 
 
 
 
 
 
9c8f077
2436df2
 
0404a52
 
ab87187
2436df2
 
 
 
 
9c8f077
2436df2
 
9c8f077
 
 
 
 
 
8bc2fc9
2436df2
 
0404a52
 
2436df2
 
8bc2fc9
2436df2
 
 
 
 
 
9c8f077
2436df2
9c8f077
 
2436df2
 
 
 
 
 
 
9c8f077
2436df2
 
 
 
 
 
 
 
 
75a07ce
2436df2
 
 
 
9c8f077
2436df2
 
8bc2fc9
2436df2
 
 
 
 
9c8f077
2436df2
 
 
9c8f077
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
import umap
import numpy as np
from sklearn.mixture import GaussianMixture

from rag.utils import truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    """RAPTOR-style recursive clustering + summarization over embedded chunks.

    Each pass clusters the current layer's chunks (UMAP dimensionality
    reduction followed by Gaussian-mixture clustering), asks the LLM to
    summarize every cluster, embeds each summary, and appends the summaries
    as the next layer.  Iteration stops once a layer collapses to a single
    chunk.
    """

    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
        # max_cluster: upper bound on the number of clusters tried per layer.
        # llm_model: chat model used for summaries; assumed to expose
        #            .chat(system, history, gen_conf) and .max_length — TODO confirm.
        # embd_model: embedding model; assumed to expose .encode(list[str]) -> (vectors, _).
        # prompt: summary prompt template with a {cluster_content} placeholder.
        # max_token: token budget reserved for each LLM summary.
        # threshold: minimum GMM membership probability to accept a cluster label.
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int) -> int:
        """Return the BIC-minimizing GMM component count for `embeddings`."""
        max_clusters = min(self._max_cluster, len(embeddings))
        if max_clusters <= 1:
            # arange(1, 1) would be empty and np.argmin([]) would raise.
            return 1
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        return n_clusters[np.argmin(bics)]

    def __call__(self, chunks, random_state, callback=None):
        """Grow the summary tree in place and return the extended chunk list.

        chunks: list of (text, embedding) pairs; cluster summaries are
            appended to it as new layers.
        random_state: seed forwarded to the GMM for reproducibility.
        callback: optional progress callback invoked with a `msg=` keyword.
        """
        # Drop chunks with empty embeddings BEFORE fixing layer boundaries.
        # (Previously `end`/`layers` were computed first, leaving stale
        # indices whenever the filter removed anything.)
        chunks = [(s, a) for s, a in chunks if len(a) > 0]
        if len(chunks) <= 1:
            # Nothing to cluster; return the (filtered) input for a
            # consistent return type instead of the old implicit None.
            return chunks
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        def summarize(ck_idx, lock):
            """Summarize chunks at `ck_idx`, embed the summary, append under `lock`."""
            nonlocal chunks
            try:
                texts = [chunks[i][0] for i in ck_idx]
                # Split the LLM context budget evenly among cluster members.
                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
                cnt = self._llm_model.chat("You're a helpful assistant.",
                                           [{"role": "user",
                                             "content": self._prompt.format(cluster_content=cluster_content)}],
                                           {"temperature": 0.3, "max_tokens": self._max_token}
                                           )
                # Strip the model's "answer truncated, continue?" boilerplate.
                # The Chinese alternative was mojibake (GBK bytes read as
                # UTF-8) and could never match real output; restored here.
                cnt = re.sub("(······\n由于长度的原因，回答被截断了，请继续吗？|For the content length reason, it stopped, continue?)", "",
                             cnt)
                logging.debug(f"SUM: {cnt}")
                embds, _ = self._embd_model.encode([cnt])
                with lock:
                    if not len(embds[0]):
                        return
                    chunks.append((cnt, embds[0]))
            except Exception as e:
                logging.exception("summarize got exception")
                return e

        labels = []
        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start: end]]
            if len(embeddings) == 2:
                # Too few points for UMAP/GMM: summarize the pair directly.
                summarize([start, start + 1], Lock())
                if callback:
                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                labels.extend([0, 0])
                layers.append((end, len(chunks)))
                start = end
                end = len(chunks)
                continue

            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                # Take the first cluster whose membership probability clears
                # the threshold; fall back to the most likely cluster instead
                # of crashing (the old `lbl[0]` raised IndexError when no
                # probability exceeded the threshold).
                lbls = []
                for prob in probs:
                    above = np.where(prob > self._threshold)[0]
                    lbls.append(int(above[0]) if len(above) else int(np.argmax(prob)))
            lock = Lock()
            with ThreadPoolExecutor(max_workers=12) as executor:
                threads = []
                for c in range(n_clusters):
                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                    threads.append(executor.submit(summarize, ck_idx, lock))
                wait(threads, return_when=ALL_COMPLETED)
                logging.debug(str([t.result() for t in threads]))

            # Runtime validation instead of `assert` (asserts vanish under -O,
            # and a failed/empty summary would otherwise silently corrupt the
            # layer bookkeeping below).
            if len(chunks) - end != n_clusters:
                raise RuntimeError("{} vs. {}".format(len(chunks) - end, n_clusters))
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
            start = end
            end = len(chunks)

        return chunks