# Page header captured with this file (Spaces status: Running; file size: 11,747 Bytes; fe5c39d).
# The viewer's line-number gutter (1-306) is omitted.
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, List, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.logs import logger
from metagpt.rag.interface import RAGObject
class BaseRetrieverConfig(BaseModel):
    """Settings shared by every retriever config.

    When a new sub-config is introduced, a matching instance implementation must
    also be added in rag.factories.retriever.
    """

    # Fields below may hold non-pydantic objects, so arbitrary types are permitted.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    similarity_top_k: int = Field(
        default=5,
        description="Number of top-k similar results to return during retrieval.",
    )
class IndexRetrieverConfig(BaseRetrieverConfig):
    """Config for Index-based retrievers."""

    # The default is None, so the annotation must admit None; otherwise passing
    # ``index=None`` explicitly is rejected by validation even though it is the default.
    index: Optional[BaseIndex] = Field(default=None, description="Index for retriver.")
class FAISSRetrieverConfig(IndexRetrieverConfig):
    """Retriever settings for a FAISS-backed index.

    ``dimensions`` may be left at 0; it is then resolved from the global embedding
    config when the model is validated.
    """

    dimensions: int = Field(
        default=0,
        description="Dimensionality of the vectors for FAISS index construction.",
    )

    # Vector sizes used for specific embedding API types when no size is configured.
    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        """Resolve ``dimensions`` when it was left at 0.

        Resolution order: ``config.embedding.dimensions``, then the per-API-type
        table, then 1536. A warning is logged when the 1536 fallback is used.

        Returns:
            The validated instance itself.
        """
        # An explicit, non-zero value is kept untouched.
        if self.dimensions != 0:
            return self

        configured_dims = config.embedding.dimensions
        known_dims = self._embedding_type_to_dimensions

        if configured_dims:
            self.dimensions = configured_dims
            return self

        self.dimensions = known_dims.get(config.embedding.api_type, 1536)
        if config.embedding.api_type not in known_dims:
            logger.warning(
                f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
            )
        return self
class BM25RetrieverConfig(IndexRetrieverConfig):
    """Retriever settings for BM25."""

    # Private attribute (not a model field). Presumably consulted elsewhere to skip
    # embedding setup -- TODO confirm against the consumers of this flag.
    _no_embedding: bool = PrivateAttr(default=True)
class MilvusRetrieverConfig(IndexRetrieverConfig):
    """Config for Milvus-based retrievers.

    ``dimensions`` may be left at 0; it is then resolved from the global embedding
    config when the model is validated.
    """

    uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    # Optional so that an explicit ``token=None`` validates, matching the None default
    # and the ``Optional[str]`` annotation used by MilvusIndexConfig.token.
    token: Optional[str] = Field(default=None, description="The token for Milvus")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )
    dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")

    # Vector sizes used for specific embedding API types when no size is configured.
    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        """Resolve ``dimensions`` when it was left at 0.

        Resolution order: ``config.embedding.dimensions``, then the per-API-type
        table, then 1536. A warning is logged when the 1536 fallback is used.

        Returns:
            The validated instance itself.
        """
        if self.dimensions == 0:
            self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
                config.embedding.api_type, 1536
            )
            # Warn only when neither the config nor the lookup table supplied a size.
            if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
                logger.warning(
                    f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
                )

        return self
class ChromaRetrieverConfig(IndexRetrieverConfig):
    """Retriever settings for a Chroma-backed index."""

    persist_path: Union[str, Path] = Field(
        default="./chroma_db",
        description="The directory to save data.",
    )
    collection_name: str = Field(
        default="metagpt",
        description="The name of the collection.",
    )
    metadata: Optional[CollectionMetadata] = Field(
        default=None,
        description="Optional metadata to associate with the collection",
    )
class ElasticsearchStoreConfig(BaseModel):
index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
es_url: str = Field(default=None, description="Elasticsearch URL.")
es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.")
es_api_key: str = Field(default=None, description="Elasticsearch API key.")
es_user: str = Field(default=None, description="Elasticsearch username.")
es_password: str = Field(default=None, description="Elasticsearch password.")
batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
"""Config for Elasticsearch-based retrievers. Support both vector and text."""
store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
vector_store_query_mode: VectorStoreQueryMode = Field(
default=VectorStoreQueryMode.DEFAULT, description="default is vector query."
)
class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
"""Config for Elasticsearch-based retrievers. Support text only."""
_no_embedding: bool = PrivateAttr(default=True)
vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only."
)
class BaseRankerConfig(BaseModel):
    """Settings shared by every ranker config.

    When a new sub-config is introduced, a matching instance implementation must
    also be added in rag.factories.ranker.
    """

    # Fields of subclasses may hold non-pydantic objects, so arbitrary types are permitted.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    top_n: int = Field(
        default=5,
        description="The number of top results to return.",
    )
class LLMRankerConfig(BaseRankerConfig):
    """Ranker settings for reranking with an LLM."""

    # Typed as Any on purpose; the field description explains the reason.
    llm: Any = Field(
        default=None,
        description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
    )
class ColbertRerankConfig(BaseRankerConfig):
    """Ranker settings for a Colbert reranker."""

    model: str = Field(
        default="colbert-ir/colbertv2.0",
        description="Colbert model name.",
    )
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
class CohereRerankConfig(BaseRankerConfig):
    """Ranker settings for a Cohere reranker."""

    model: str = Field(default="rerank-english-v3.0")
    # NOTE(review): the default looks like a placeholder rather than a usable key --
    # callers presumably need to supply a real one; confirm.
    api_key: str = Field(default="YOUR_COHERE_API")
class BGERerankConfig(BaseRankerConfig):
    """Ranker settings for a BAAI (BGE) reranker."""

    model: str = Field(
        default="BAAI/bge-reranker-large",
        description="BAAI Reranker model name.",
    )
    use_fp16: bool = Field(
        default=True,
        description="Whether to use fp16 for inference.",
    )
class ObjectRankerConfig(BaseRankerConfig):
    """Ranker settings for ordering objects by one of their fields."""

    # Required: names the object field whose values are compared.
    field_name: str = Field(
        ...,
        description="field name of the object, field's value must can be compared.",
    )
    # Only "desc" and "asc" are accepted.
    order: Literal["desc", "asc"] = Field(
        default="desc",
        description="the direction of order.",
    )
class BaseIndexConfig(BaseModel):
    """Settings shared by every index config.

    When a new sub-config is introduced, a matching instance implementation must
    also be added in rag.factories.index.
    """

    # Fields of subclasses may hold non-pydantic objects, so arbitrary types are permitted.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # No default is given here, so this field is required unless a subclass overrides it.
    persist_path: Union[str, Path] = Field(
        description="The directory of saved data.",
    )
class VectorIndexConfig(BaseIndexConfig):
    """Config for vector-based index."""

    # The default is None, so the annotation must admit None; otherwise passing
    # ``embed_model=None`` explicitly is rejected by validation even though it is the default.
    embed_model: Optional[BaseEmbedding] = Field(default=None, description="Embed model.")
class FAISSIndexConfig(VectorIndexConfig):
    """Index settings for a FAISS-backed index; adds no fields of its own."""
class ChromaIndexConfig(VectorIndexConfig):
    """Index settings for a Chroma-backed index."""

    collection_name: str = Field(
        default="metagpt",
        description="The name of the collection.",
    )
    metadata: Optional[CollectionMetadata] = Field(
        default=None,
        description="Optional metadata to associate with the collection",
    )
class MilvusIndexConfig(VectorIndexConfig):
    """Index settings for a Milvus-backed index."""

    collection_name: str = Field(
        default="metagpt",
        description="The name of the collection.",
    )
    uri: str = Field(
        default="./milvus_local.db",
        description="The uri of the index.",
    )
    token: Optional[str] = Field(
        default=None,
        description="The token of the index.",
    )
    metadata: Optional[CollectionMetadata] = Field(
        default=None,
        description="Optional metadata to associate with the collection",
    )
class BM25IndexConfig(BaseIndexConfig):
    """Index settings for a BM25 index."""

    # Private attribute (not a model field). Presumably consulted elsewhere to skip
    # embedding setup -- TODO confirm against the consumers of this flag.
    _no_embedding: bool = PrivateAttr(default=True)
class ElasticsearchIndexConfig(VectorIndexConfig):
"""Config for es-based index."""
store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
persist_path: Union[str, Path] = ""
class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
"""Config for es-based index. no embedding."""
_no_embedding: bool = PrivateAttr(default=True)
class ObjectNodeMetadata(BaseModel):
    """Metadata attached to an ObjectNode.

    The three ``obj_*`` fields are required and have no defaults.
    """

    is_obj: bool = Field(default=True)
    obj: Any = Field(
        default=None,
        description="When rag retrieve, will reconstruct obj from obj_json",
    )
    obj_json: str = Field(
        ...,
        description="The json of object, e.g. obj.model_dump_json()",
    )
    obj_cls_name: str = Field(
        ...,
        description="The class name of object, e.g. obj.__class__.__name__",
    )
    obj_mod_name: str = Field(
        ...,
        description="The module name of class, e.g. obj.__class__.__module__",
    )
class ObjectNode(TextNode):
    """RAG add object.

    A TextNode whose metadata carries a serialized object. The object bookkeeping
    keys (the field names of ObjectNodeMetadata) are placed in both metadata
    exclusion lists.
    """

    def __init__(self, **kwargs):
        """Build the node and set both metadata exclusion lists.

        Args:
            **kwargs: Forwarded unchanged to ``TextNode.__init__``.
        """
        super().__init__(**kwargs)

        metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
        self.excluded_llm_metadata_keys = metadata_keys
        # Use an independent copy: sharing one list object between the two attributes
        # would make a later in-place change to one silently alter the other.
        self.excluded_embed_metadata_keys = list(metadata_keys)

    @staticmethod
    def get_obj_metadata(obj: RAGObject) -> dict:
        """Build the metadata dict describing ``obj``.

        Args:
            obj: Object to describe; its ``model_dump_json()`` output, class name and
                module name are recorded.

        Returns:
            The ``model_dump()`` of the resulting ObjectNodeMetadata.
        """
        metadata = ObjectNodeMetadata(
            obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
        )

        return metadata.model_dump()
class OmniParseType(str, Enum):
    """Parse types selectable for OmniParse.

    Members are also ``str`` instances, so they compare equal to their raw values.
    """

    PDF = "PDF"
    DOCUMENT = "DOCUMENT"
class ParseResultType(str, Enum):
    """Result formats a parser can produce.

    Members are also ``str`` instances, so they compare equal to their raw values.
    """

    TXT = "text"
    MD = "markdown"
    JSON = "json"
class OmniParseOptions(BaseModel):
    """Options controlling an OmniParse request."""

    result_type: ParseResultType = Field(
        default=ParseResultType.MD,
        description="OmniParse result_type",
    )
    parse_type: OmniParseType = Field(
        default=OmniParseType.DOCUMENT,
        description="OmniParse parse_type",
    )
    max_timeout: Optional[int] = Field(
        default=120,
        description="Maximum timeout for OmniParse service requests",
    )
    # Constrained to the open interval (0, 10), i.e. 1 through 9.
    num_workers: int = Field(
        default=5,
        gt=0,
        lt=10,
        description="Number of concurrent requests for multiple files",
    )
class OminParseImage(BaseModel):
    """One image entry of an OmniParse result."""

    image: str = Field(
        default="",
        description="image str bytes",
    )
    image_name: str = Field(
        default="",
        description="image name",
    )
    image_info: Optional[dict] = Field(
        default={},
        description="image info",
    )
class OmniParsedResult(BaseModel):
    """Parsed output of OmniParse: markdown, plain text, images and metadata.

    When ``markdown`` is missing or empty in the input, it is filled from ``text``.
    """

    markdown: str = Field(default="", description="markdown text")
    text: str = Field(default="", description="plain text")
    images: Optional[List[OminParseImage]] = Field(default=[], description="images")
    metadata: Optional[dict] = Field(default={}, description="metadata")

    @model_validator(mode="before")
    @classmethod
    def set_markdown(cls, values):
        """Fall back to ``text`` when ``markdown`` is missing or empty.

        Args:
            values: Raw input to the model before field validation.

        Returns:
            The input to validate: ``values`` itself when nothing needs filling,
            otherwise a new dict with ``markdown`` set from ``text``.
        """
        # Before-validators receive whatever was passed in; leave non-dict input
        # to the regular validation machinery instead of calling ``.get`` on it.
        if not isinstance(values, dict):
            return values

        if values.get("markdown"):
            return values

        text = values.get("text")
        # Without any text there is nothing to copy. Writing None here would make
        # the ``str`` field fail validation even though its default is "".
        if text is None:
            return values

        # Build a new dict so the caller's input mapping is not mutated.
        return {**values, "markdown": text}
|