"""RAG schemas.""" from enum import Enum from pathlib import Path from typing import Any, ClassVar, List, Literal, Optional, Union from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from metagpt.config2 import config from metagpt.configs.embedding_config import EmbeddingType from metagpt.logs import logger from metagpt.rag.interface import RAGObject class BaseRetrieverConfig(BaseModel): """Common config for retrievers. If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever. """ model_config = ConfigDict(arbitrary_types_allowed=True) similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.") class IndexRetrieverConfig(BaseRetrieverConfig): """Config for Index-basd retrievers.""" index: BaseIndex = Field(default=None, description="Index for retriver.") class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { EmbeddingType.GEMINI: 768, EmbeddingType.OLLAMA: 4096, } @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( config.embedding.api_type, 1536 ) if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: logger.warning( f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" ) return self class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" _no_embedding: bool = PrivateAttr(default=True) class MilvusRetrieverConfig(IndexRetrieverConfig): """Config for Milvus-based retrievers.""" uri: str = Field(default="./milvus_local.db", description="The directory to save data.") collection_name: str = Field(default="metagpt", description="The name of the collection.") token: str = Field(default=None, description="The token for Milvus") metadata: Optional[CollectionMetadata] = Field( default=None, description="Optional metadata to associate with the collection" ) dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.") _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { EmbeddingType.GEMINI: 768, EmbeddingType.OLLAMA: 4096, } @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( config.embedding.api_type, 1536 ) if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: logger.warning( f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" ) return self class ChromaRetrieverConfig(IndexRetrieverConfig): """Config for Chroma-based retrievers.""" persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") collection_name: str = Field(default="metagpt", description="The name of the collection.") metadata: Optional[CollectionMetadata] = Field( default=None, description="Optional metadata to associate with the collection" ) class ElasticsearchStoreConfig(BaseModel): index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") es_url: str = Field(default=None, description="Elasticsearch URL.") es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.") es_api_key: str = Field(default=None, description="Elasticsearch API key.") es_user: str = Field(default=None, description="Elasticsearch username.") es_password: str = Field(default=None, description="Elasticsearch password.") batch_size: int = Field(default=200, description="Batch size for bulk indexing.") distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.") class ElasticsearchRetrieverConfig(IndexRetrieverConfig): """Config for Elasticsearch-based retrievers. Support both vector and text.""" store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") vector_store_query_mode: VectorStoreQueryMode = Field( default=VectorStoreQueryMode.DEFAULT, description="default is vector query." ) class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig): """Config for Elasticsearch-based retrievers. Support text only.""" _no_embedding: bool = PrivateAttr(default=True) vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field( default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only." ) class BaseRankerConfig(BaseModel): """Common config for rankers. If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker. """ model_config = ConfigDict(arbitrary_types_allowed=True) top_n: int = Field(default=5, description="The number of top results to return.") class LLMRankerConfig(BaseRankerConfig): """Config for LLM-based rankers.""" llm: Any = Field( default=None, description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.", ) class ColbertRerankConfig(BaseRankerConfig): model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.") device: str = Field(default="cpu", description="Device to use for sentence transformer.") keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") class CohereRerankConfig(BaseRankerConfig): model: str = Field(default="rerank-english-v3.0") api_key: str = Field(default="YOUR_COHERE_API") class BGERerankConfig(BaseRankerConfig): model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.") use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.") class ObjectRankerConfig(BaseRankerConfig): field_name: str = Field(..., description="field name of the object, field's value must can be compared.") order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") class BaseIndexConfig(BaseModel): """Common config for index. If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. """ model_config = ConfigDict(arbitrary_types_allowed=True) persist_path: Union[str, Path] = Field(description="The directory of saved data.") class VectorIndexConfig(BaseIndexConfig): """Config for vector-based index.""" embed_model: BaseEmbedding = Field(default=None, description="Embed model.") class FAISSIndexConfig(VectorIndexConfig): """Config for faiss-based index.""" class ChromaIndexConfig(VectorIndexConfig): """Config for chroma-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") metadata: Optional[CollectionMetadata] = Field( default=None, description="Optional metadata to associate with the collection" ) class MilvusIndexConfig(VectorIndexConfig): """Config for milvus-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") uri: str = Field(default="./milvus_local.db", description="The uri of the index.") token: Optional[str] = Field(default=None, description="The token of the index.") metadata: Optional[CollectionMetadata] = Field( default=None, description="Optional metadata to associate with the collection" ) class BM25IndexConfig(BaseIndexConfig): """Config for bm25-based index.""" _no_embedding: bool = PrivateAttr(default=True) class ElasticsearchIndexConfig(VectorIndexConfig): """Config for es-based index.""" store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") persist_path: Union[str, Path] = "" class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig): """Config for es-based index. no embedding.""" _no_embedding: bool = PrivateAttr(default=True) class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" is_obj: bool = Field(default=True) obj: Any = Field(default=None, description="When rag retrieve, will reconstruct obj from obj_json") obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()") obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__") obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__") class ObjectNode(TextNode): """RAG add object.""" def __init__(self, **kwargs): super().__init__(**kwargs) self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys()) self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys @staticmethod def get_obj_metadata(obj: RAGObject) -> dict: metadata = ObjectNodeMetadata( obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__ ) return metadata.model_dump() class OmniParseType(str, Enum): """OmniParseType""" PDF = "PDF" DOCUMENT = "DOCUMENT" class ParseResultType(str, Enum): """The result type for the parser.""" TXT = "text" MD = "markdown" JSON = "json" class OmniParseOptions(BaseModel): """OmniParse Options config""" result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type") parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type") max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests") num_workers: int = Field( default=5, gt=0, lt=10, description="Number of concurrent requests for multiple files", ) class OminParseImage(BaseModel): image: str = Field(default="", description="image str bytes") image_name: str = Field(default="", description="image name") image_info: Optional[dict] = Field(default={}, description="image info") class OmniParsedResult(BaseModel): markdown: str = Field(default="", description="markdown text") text: str = Field(default="", description="plain text") images: Optional[List[OminParseImage]] = Field(default=[], description="images") metadata: Optional[dict] = Field(default={}, description="metadata") @model_validator(mode="before") def set_markdown(cls, values): if not values.get("markdown"): values["markdown"] = values.get("text") return values