"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, List, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.logs import logger
from metagpt.rag.interface import RAGObject


class BaseRetrieverConfig(BaseModel):
    """Common config for retrievers.

    If you add a new subconfig, add the corresponding instance implementation in rag.factories.retriever.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")


class IndexRetrieverConfig(BaseRetrieverConfig):
    """Config for index-based retrievers."""

    index: Optional[BaseIndex] = Field(default=None, description="Index for the retriever.")


class FAISSRetrieverConfig(IndexRetrieverConfig):
    """Config for FAISS-based retrievers."""

    dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")

    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        """Resolve dimensions from the embedding config when not set explicitly."""
        if self.dimensions == 0:
            self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
                config.embedding.api_type, 1536
            )
            if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
                logger.warning(
                    f"Embedding dimensions are not set in config for {config.embedding.api_type}, defaulting to 1536."
                )

        return self
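
# Illustrative usage sketch (not executed here): retriever configs are consumed by rag.factories.retriever,
# typically via an engine. SimpleEngine and from_docs below are assumptions based on MetaGPT's RAG examples:
#
#   from metagpt.rag.engines import SimpleEngine
#
#   engine = SimpleEngine.from_docs(
#       input_files=["doc.md"],
#       retriever_configs=[FAISSRetrieverConfig(dimensions=1536)],  # dimensions=0 auto-resolves from config.embedding
#   )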


class BM25RetrieverConfig(IndexRetrieverConfig):
    """Config for BM25-based retrievers."""

    _no_embedding: bool = PrivateAttr(default=True)
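
# Note: _no_embedding is a private hint presumably read by the factories rather than by this class itself;
# BM25 retrieval is purely lexical, so no embed model is needed.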


class MilvusRetrieverConfig(IndexRetrieverConfig):
    """Config for Milvus-based retrievers."""

    uri: str = Field(default="./milvus_local.db", description="The URI of the Milvus server, or the path of a local database file.")
    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    token: Optional[str] = Field(default=None, description="The token for Milvus.")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )
    dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")

    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        """Resolve dimensions from the embedding config when not set explicitly."""
        if self.dimensions == 0:
            self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
                config.embedding.api_type, 1536
            )
            if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
                logger.warning(
                    f"Embedding dimensions are not set in config for {config.embedding.api_type}, defaulting to 1536."
                )

        return self
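
# Illustrative usage sketch (values are placeholders): Milvus can run as a local database file or a remote server.
#
#   local_cfg = MilvusRetrieverConfig(uri="./milvus_local.db", collection_name="metagpt")
#   remote_cfg = MilvusRetrieverConfig(uri="http://localhost:19530", token="<token>", dimensions=1536)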


class ChromaRetrieverConfig(IndexRetrieverConfig):
    """Config for Chroma-based retrievers."""

    persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )


class ElasticsearchStoreConfig(BaseModel):
    """Connection config for an ElasticsearchStore."""

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    es_url: Optional[str] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Optional[str] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Optional[str] = Field(default=None, description="Elasticsearch API key.")
    es_user: Optional[str] = Field(default=None, description="Elasticsearch username.")
    es_password: Optional[str] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")


class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
    """Config for Elasticsearch-based retrievers. Supports both vector and text search."""

    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    vector_store_query_mode: VectorStoreQueryMode = Field(
        default=VectorStoreQueryMode.DEFAULT, description="Defaults to vector query."
    )


class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
    """Config for Elasticsearch-based retrievers. Supports text search only."""

    _no_embedding: bool = PrivateAttr(default=True)

    vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
        default=VectorStoreQueryMode.TEXT_SEARCH, description="Text query only."
    )
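
# Illustrative usage sketch (URL is a placeholder): vector search vs. keyword-only search over the same index.
#
#   store = ElasticsearchStoreConfig(index_name="metagpt", es_url="http://localhost:9200")
#   vector_cfg = ElasticsearchRetrieverConfig(store_config=store)
#   keyword_cfg = ElasticsearchKeywordRetrieverConfig(store_config=store)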


class BaseRankerConfig(BaseModel):
    """Common config for rankers.

    If you add a new subconfig, add the corresponding instance implementation in rag.factories.ranker.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    top_n: int = Field(default=5, description="The number of top results to return.")


class LLMRankerConfig(BaseRankerConfig):
    """Config for LLM-based rankers."""

    llm: Any = Field(
        default=None,
        description="The LLM to rerank with. Typed as Any instead of LLM because llama_index.core.llms.LLM is pydantic.v1.",
    )


class ColbertRerankConfig(BaseRankerConfig):
    """Config for ColBERT-based rerankers."""

    model: str = Field(default="colbert-ir/colbertv2.0", description="ColBERT model name.")
    device: str = Field(default="cpu", description="Device to use for the sentence transformer.")
    keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")


class CohereRerankConfig(BaseRankerConfig):
    """Config for Cohere-based rerankers."""

    model: str = Field(default="rerank-english-v3.0")
    api_key: str = Field(default="YOUR_COHERE_API")


class BGERerankConfig(BaseRankerConfig):
    """Config for BGE-based rerankers."""

    model: str = Field(default="BAAI/bge-reranker-large", description="BAAI reranker model name.")
    use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")


class ObjectRankerConfig(BaseRankerConfig):
    """Config for rankers that order retrieved objects by one of their fields."""

    field_name: str = Field(..., description="Field name of the object; the field's value must be comparable.")
    order: Literal["desc", "asc"] = Field(default="desc", description="The sort order, descending or ascending.")
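
# Illustrative usage sketch: ranker configs are consumed by rag.factories.ranker, usually alongside retriever configs.
#
#   ranker_configs = [LLMRankerConfig(top_n=3)]
#   # or, to sort retrieved objects by one of their fields ("score" is a placeholder field name):
#   ranker_configs = [ObjectRankerConfig(field_name="score", order="desc")]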


class BaseIndexConfig(BaseModel):
    """Common config for indexes.

    If you add a new subconfig, add the corresponding instance implementation in rag.factories.index.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    persist_path: Union[str, Path] = Field(description="The directory of saved data.")


class VectorIndexConfig(BaseIndexConfig):
    """Config for vector-based indexes."""

    embed_model: Optional[BaseEmbedding] = Field(default=None, description="Embed model.")


class FAISSIndexConfig(VectorIndexConfig):
    """Config for FAISS-based indexes."""


class ChromaIndexConfig(VectorIndexConfig):
    """Config for Chroma-based indexes."""

    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )


class MilvusIndexConfig(VectorIndexConfig):
    """Config for Milvus-based indexes."""

    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    uri: str = Field(default="./milvus_local.db", description="The URI of the index.")
    token: Optional[str] = Field(default=None, description="The token of the index.")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )


class BM25IndexConfig(BaseIndexConfig):
    """Config for BM25-based indexes."""

    _no_embedding: bool = PrivateAttr(default=True)


class ElasticsearchIndexConfig(VectorIndexConfig):
    """Config for Elasticsearch-based indexes."""

    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    persist_path: Union[str, Path] = ""


class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
    """Config for Elasticsearch-based indexes without embeddings."""

    _no_embedding: bool = PrivateAttr(default=True)
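
# Illustrative usage sketch (paths are placeholders): index configs are consumed by rag.factories.index
# to build or reload a persisted index.
#
#   chroma_cfg = ChromaIndexConfig(persist_path="./chroma_db", collection_name="metagpt")
#   faiss_cfg = FAISSIndexConfig(persist_path="./faiss_index")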


class ObjectNodeMetadata(BaseModel):
    """Metadata of ObjectNode."""

    is_obj: bool = Field(default=True)
    obj: Any = Field(default=None, description="Set at retrieval time, when the object is reconstructed from obj_json.")
    obj_json: str = Field(..., description="The JSON of the object, e.g. obj.model_dump_json().")
    obj_cls_name: str = Field(..., description="The class name of the object, e.g. obj.__class__.__name__.")
    obj_mod_name: str = Field(..., description="The module name of the class, e.g. obj.__class__.__module__.")


class ObjectNode(TextNode):
    """Node used when adding objects to RAG; stores the serialized object in its metadata."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
        self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys

    @staticmethod
    def get_obj_metadata(obj: RAGObject) -> dict:
        metadata = ObjectNodeMetadata(
            obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
        )

        return metadata.model_dump()
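
# Illustrative usage sketch (MyItem and rag_key are assumptions about the RAGObject interface): an object is
# stored as an ObjectNode whose metadata lets the original object be rebuilt after retrieval.
#
#   item = MyItem(name="foo")  # a pydantic model implementing RAGObject
#   node = ObjectNode(text=item.rag_key(), metadata=ObjectNode.get_obj_metadata(item))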


class OmniParseType(str, Enum):
    """OmniParse parse type."""

    PDF = "PDF"
    DOCUMENT = "DOCUMENT"


class ParseResultType(str, Enum):
    """The result type for the parser."""

    TXT = "text"
    MD = "markdown"
    JSON = "json"


class OmniParseOptions(BaseModel):
    """OmniParse options config."""

    result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type.")
    parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type.")
    max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests.")
    num_workers: int = Field(
        default=5,
        gt=0,
        lt=10,
        description="Number of concurrent requests for multiple files.",
    )
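
# Illustrative usage sketch: request markdown output when parsing a PDF through OmniParse.
#
#   options = OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD, max_timeout=300)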


class OminParseImage(BaseModel):
    """An image extracted by OmniParse."""

    image: str = Field(default="", description="Image bytes as a string.")
    image_name: str = Field(default="", description="Image name.")
    image_info: Optional[dict] = Field(default={}, description="Image info.")


class OmniParsedResult(BaseModel):
    """The parsed result returned by OmniParse."""

    markdown: str = Field(default="", description="Markdown text.")
    text: str = Field(default="", description="Plain text.")
    images: Optional[List[OminParseImage]] = Field(default=[], description="Images.")
    metadata: Optional[dict] = Field(default={}, description="Metadata.")

    @model_validator(mode="before")
    @classmethod
    def set_markdown(cls, values):
        """Fall back to the plain text when no markdown is provided."""
        if not values.get("markdown"):
            values["markdown"] = values.get("text", "")
        return values
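
# Illustrative usage sketch: the before-validator mirrors `text` into `markdown` when only plain text is given.
#
#   result = OmniParsedResult(text="hello")
#   assert result.markdown == "hello"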