File size: 11,747 Bytes
fe5c39d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, List, Literal, Optional, Union

from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.logs import logger
from metagpt.rag.interface import RAGObject


class BaseRetrieverConfig(BaseModel):
    """Settings shared by every retriever config.

    When a new sub-config is added, a matching instance implementation must also be added in
    rag.factories.retriever.
    """

    # Lets subclasses declare fields whose types pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    similarity_top_k: int = Field(
        description="Number of top-k similar results to return during retrieval.",
        default=5,
    )


class IndexRetrieverConfig(BaseRetrieverConfig):
    """Config for index-based retrievers.

    Attributes:
        index: Index used by the retriever; None (the default) when no index is supplied.
    """

    # Annotated as Optional to match the None default, so passing index=None explicitly validates
    # the same way as omitting the argument.
    index: Optional[BaseIndex] = Field(default=None, description="Index for retriver.")


class FAISSRetrieverConfig(IndexRetrieverConfig):
    """Retriever config for FAISS.

    When ``dimensions`` is left at 0 it is resolved after validation: the globally configured embedding
    dimensions win, then the per-embedding-type table below, and finally 1536 (with a warning).
    """

    dimensions: int = Field(
        description="Dimensionality of the vectors for FAISS index construction.",
        default=0,
    )

    # Fallback vector sizes per embedding type; consulted only when config.embedding.dimensions is unset.
    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        """Fill in ``dimensions`` when the caller left it at 0, then return the model."""
        if self.dimensions != 0:
            # An explicit value was given; keep it.
            return self

        configured = config.embedding.dimensions
        if configured:
            self.dimensions = configured
            return self

        known_sizes = self._embedding_type_to_dimensions
        if config.embedding.api_type in known_sizes:
            self.dimensions = known_sizes[config.embedding.api_type]
        else:
            # Neither the config nor the table knows the size: use 1536 and tell the user.
            self.dimensions = 1536
            logger.warning(
                f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
            )

        return self


class BM25RetrieverConfig(IndexRetrieverConfig):
    """Config for BM25-based retrievers.

    Adds no public fields beyond those inherited from IndexRetrieverConfig.
    """

    # Private attribute (not a model field), defaulting to True. It is not read in this file;
    # presumably other RAG code checks it to skip embedding setup -- TODO confirm against callers.
    _no_embedding: bool = PrivateAttr(default=True)


class MilvusRetrieverConfig(IndexRetrieverConfig):
    """Config for Milvus-based retrievers.

    When ``dimensions`` is left at 0 it is resolved after validation: the globally configured embedding
    dimensions win, then the per-embedding-type table below, and finally 1536 (with a warning).
    """

    uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    # Optional to match the None default (and MilvusIndexConfig.token), so token=None validates.
    token: Optional[str] = Field(default=None, description="The token for Milvus")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )
    dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")

    # Fallback vector sizes per embedding type; consulted only when config.embedding.dimensions is unset.
    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        """Fill in ``dimensions`` when the caller left it at 0.

        Returns:
            The model itself, as required of an "after" model validator.
        """
        if self.dimensions != 0:
            # An explicit value was given; keep it.
            return self

        configured = config.embedding.dimensions
        if configured:
            self.dimensions = configured
            return self

        known_sizes = self._embedding_type_to_dimensions
        if config.embedding.api_type in known_sizes:
            self.dimensions = known_sizes[config.embedding.api_type]
        else:
            # Neither the config nor the table knows the size: use 1536 and tell the user.
            self.dimensions = 1536
            logger.warning(
                f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
            )

        return self


class ChromaRetrieverConfig(IndexRetrieverConfig):
    """Retriever config for Chroma: persistence path, collection name and optional collection metadata."""

    persist_path: Union[str, Path] = Field(
        description="The directory to save data.",
        default="./chroma_db",
    )
    collection_name: str = Field(
        description="The name of the collection.",
        default="metagpt",
    )
    metadata: Optional[CollectionMetadata] = Field(
        description="Optional metadata to associate with the collection",
        default=None,
    )


class ElasticsearchStoreConfig(BaseModel):
    """Connection and indexing settings for an Elasticsearch store.

    The URL, cloud ID and credential fields all default to None and are declared Optional, so passing
    None explicitly is accepted just like omitting them.
    """

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    es_url: Optional[str] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Optional[str] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Optional[str] = Field(default=None, description="Elasticsearch API key.")
    es_user: Optional[str] = Field(default=None, description="Elasticsearch username.")
    es_password: Optional[str] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")


class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
    """Retriever config for Elasticsearch, usable for both vector and text queries."""

    # Required: there is no default store configuration.
    store_config: ElasticsearchStoreConfig = Field(
        ...,
        description="ElasticsearchStore config.",
    )
    vector_store_query_mode: VectorStoreQueryMode = Field(
        description="default is vector query.",
        default=VectorStoreQueryMode.DEFAULT,
    )


class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
    """Retriever config for Elasticsearch restricted to text queries."""

    # Private attribute (not a model field); not read in this file.
    _no_embedding: bool = PrivateAttr(default=True)
    # The Literal annotation narrows the inherited field to the single TEXT_SEARCH mode.
    vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
        description="text query only.",
        default=VectorStoreQueryMode.TEXT_SEARCH,
    )


class BaseRankerConfig(BaseModel):
    """Settings shared by every ranker config.

    When a new sub-config is added, a matching instance implementation must also be added in
    rag.factories.ranker.
    """

    # Lets subclasses declare fields whose types pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    top_n: int = Field(
        description="The number of top results to return.",
        default=5,
    )


class LLMRankerConfig(BaseRankerConfig):
    """Ranker config for reranking with an LLM."""

    # Typed as Any on purpose; the field description explains why.
    llm: Any = Field(
        description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
        default=None,
    )


class ColbertRerankConfig(BaseRankerConfig):
    """Ranker config for Colbert reranking: model name, device and score-retention flag."""

    model: str = Field(
        description="Colbert model name.",
        default="colbert-ir/colbertv2.0",
    )
    device: str = Field(
        description="Device to use for sentence transformer.",
        default="cpu",
    )
    keep_retrieval_score: bool = Field(
        description="Whether to keep the retrieval score in metadata.",
        default=False,
    )


class CohereRerankConfig(BaseRankerConfig):
    """Ranker config for Cohere reranking: model name and API key."""

    model: str = Field(
        default="rerank-english-v3.0",
    )
    # The default is a placeholder string, not a usable key.
    api_key: str = Field(
        default="YOUR_COHERE_API",
    )


class BGERerankConfig(BaseRankerConfig):
    """Ranker config for BAAI (BGE) reranking: model name and fp16 switch."""

    model: str = Field(
        description="BAAI Reranker model name.",
        default="BAAI/bge-reranker-large",
    )
    use_fp16: bool = Field(
        description="Whether to use fp16 for inference.",
        default=True,
    )


class ObjectRankerConfig(BaseRankerConfig):
    """Ranker config that orders objects by one of their fields."""

    # Required: there is no default field to rank by.
    field_name: str = Field(
        ...,
        description="field name of the object, field's value must can be compared.",
    )
    order: Literal["desc", "asc"] = Field(
        description="the direction of order.",
        default="desc",
    )


class BaseIndexConfig(BaseModel):
    """Settings shared by every index config.

    When a new sub-config is added, a matching instance implementation must also be added in
    rag.factories.index.
    """

    # Lets subclasses declare fields whose types pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # No default is given, so this field is required unless a subclass redeclares it.
    persist_path: Union[str, Path] = Field(
        description="The directory of saved data.",
    )


class VectorIndexConfig(BaseIndexConfig):
    """Config for vector-based index.

    Attributes:
        embed_model: Embedding model for the index; None (the default) when none is supplied.
    """

    # Annotated as Optional to match the None default, so passing embed_model=None explicitly
    # validates the same way as omitting the argument.
    embed_model: Optional[BaseEmbedding] = Field(default=None, description="Embed model.")


class FAISSIndexConfig(VectorIndexConfig):
    """Config for faiss-based index.

    Declares no fields of its own; everything is inherited from VectorIndexConfig.
    """


class ChromaIndexConfig(VectorIndexConfig):
    """Index config for Chroma: collection name and optional collection metadata."""

    collection_name: str = Field(
        description="The name of the collection.",
        default="metagpt",
    )
    metadata: Optional[CollectionMetadata] = Field(
        description="Optional metadata to associate with the collection",
        default=None,
    )


class MilvusIndexConfig(VectorIndexConfig):
    """Index config for Milvus: collection name, uri, optional token and optional collection metadata."""

    collection_name: str = Field(
        description="The name of the collection.",
        default="metagpt",
    )
    uri: str = Field(
        description="The uri of the index.",
        default="./milvus_local.db",
    )
    token: Optional[str] = Field(
        description="The token of the index.",
        default=None,
    )
    metadata: Optional[CollectionMetadata] = Field(
        description="Optional metadata to associate with the collection",
        default=None,
    )


class BM25IndexConfig(BaseIndexConfig):
    """Config for bm25-based index.

    Adds no public fields beyond those inherited from BaseIndexConfig.
    """

    # Private attribute (not a model field), defaulting to True. It is not read in this file;
    # presumably other RAG code checks it to skip embedding setup -- TODO confirm against callers.
    _no_embedding: bool = PrivateAttr(default=True)


class ElasticsearchIndexConfig(VectorIndexConfig):
    """Index config for Elasticsearch.

    ``store_config`` is required. ``persist_path`` is redeclared here with an empty-string default, so
    unlike in BaseIndexConfig it does not have to be supplied.
    """

    store_config: ElasticsearchStoreConfig = Field(
        ...,
        description="ElasticsearchStore config.",
    )
    persist_path: Union[str, Path] = ""


class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
    """Config for es-based index. no embedding.

    Adds no public fields beyond those inherited from ElasticsearchIndexConfig.
    """

    # Private attribute (not a model field), defaulting to True. It is not read in this file;
    # presumably other RAG code checks it to skip embedding setup -- TODO confirm against callers.
    _no_embedding: bool = PrivateAttr(default=True)


class ObjectNodeMetadata(BaseModel):
    """Metadata stored on an ObjectNode.

    ``obj_json``, ``obj_cls_name`` and ``obj_mod_name`` are required; ``is_obj`` and ``obj`` have defaults.
    """

    is_obj: bool = Field(
        default=True,
    )
    obj: Any = Field(
        description="When rag retrieve, will reconstruct obj from obj_json",
        default=None,
    )
    obj_json: str = Field(
        ...,
        description="The json of object, e.g. obj.model_dump_json()",
    )
    obj_cls_name: str = Field(
        ...,
        description="The class name of object, e.g. obj.__class__.__name__",
    )
    obj_mod_name: str = Field(
        ...,
        description="The module name of class, e.g. obj.__class__.__module__",
    )


class ObjectNode(TextNode):
    """RAG add object.

    A TextNode whose ``excluded_llm_metadata_keys`` and ``excluded_embed_metadata_keys`` are both set to
    the field names of ObjectNodeMetadata.
    """

    def __init__(self, **kwargs):
        """Build the node from ``kwargs`` and then overwrite both excluded-key lists.

        Args:
            **kwargs: Forwarded unchanged to ``TextNode.__init__``.
        """
        super().__init__(**kwargs)
        metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
        self.excluded_llm_metadata_keys = metadata_keys
        # Give the embed list its own copy so that mutating one exclusion list can never
        # silently change the other.
        self.excluded_embed_metadata_keys = list(metadata_keys)

    @staticmethod
    def get_obj_metadata(obj: RAGObject) -> dict:
        """Return the ObjectNodeMetadata for ``obj`` as a plain dict.

        Args:
            obj: Object to describe; its JSON dump, class name and module name are recorded.

        Returns:
            The ``model_dump()`` of the constructed ObjectNodeMetadata.
        """
        metadata = ObjectNodeMetadata(
            obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
        )

        return metadata.model_dump()


class OmniParseType(str, Enum):
    """Parse types selectable for OmniParse.

    Members are also ``str`` instances, so each compares equal to its value.
    """

    PDF = "PDF"

    DOCUMENT = "DOCUMENT"


class ParseResultType(str, Enum):
    """Result formats the parser can return.

    Members are also ``str`` instances, so each compares equal to its value.
    """

    TXT = "text"

    MD = "markdown"

    JSON = "json"


class OmniParseOptions(BaseModel):
    """Options for OmniParse: result type, parse type, timeout and worker count."""

    result_type: ParseResultType = Field(
        description="OmniParse result_type",
        default=ParseResultType.MD,
    )
    parse_type: OmniParseType = Field(
        description="OmniParse parse_type",
        default=OmniParseType.DOCUMENT,
    )
    max_timeout: Optional[int] = Field(
        description="Maximum timeout for OmniParse service requests",
        default=120,
    )
    # gt=0 and lt=10 are both exclusive bounds, so valid values are 1 through 9.
    num_workers: int = Field(default=5, gt=0, lt=10, description="Number of concurrent requests for multiple files")


class OminParseImage(BaseModel):
    """A single image entry: its data string, its name and an optional info dict."""

    image: str = Field(
        description="image str bytes",
        default="",
    )
    image_name: str = Field(
        description="image name",
        default="",
    )
    image_info: Optional[dict] = Field(
        description="image info",
        default={},
    )


class OmniParsedResult(BaseModel):
    """Result of an OmniParse parse: markdown, plain text, images and metadata.

    Every field has a default, so the model can be built with no arguments. When ``markdown`` is missing
    or empty it is filled from ``text`` before field validation runs.
    """

    markdown: str = Field(default="", description="markdown text")
    text: str = Field(default="", description="plain text")
    images: Optional[List[OminParseImage]] = Field(default=[], description="images")
    metadata: Optional[dict] = Field(default={}, description="metadata")

    @model_validator(mode="before")
    @classmethod
    def set_markdown(cls, values):
        """Fall back to ``text`` when ``markdown`` is missing or empty.

        Args:
            values: Raw input handed to the model before field validation.

        Returns:
            ``values`` unchanged, or a new dict with ``markdown`` set to the ``text`` value.
        """
        from collections.abc import Mapping

        # Non-mapping input (e.g. an existing model instance) has no .get(); leave it for pydantic.
        if not isinstance(values, Mapping):
            return values

        if values.get("markdown"):
            return values

        text = values.get("text")
        if text is None:
            # Nothing to fall back on. Writing None here would make the ``markdown: str`` field
            # fail validation even though it declares a default.
            return values

        # Build a new dict rather than mutating the caller's input.
        return {**values, "markdown": text}