import logging
import os
import sys
from typing import Any, Dict, List, Optional

from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from qdrant_client.http import models as rest
from sentence_transformers import SentenceTransformer

from .utils import getconfig
from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore

# NOTE: this module-level model is not used directly below; the vector store
# handles query embedding via the MODEL_NAME passed in search_kwargs.
model = SentenceTransformer('BAAI/bge-m3')
# Configure logging to be more verbose
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
# Load configuration
config = getconfig("params.cfg")

# Retriever settings from config
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))

# Reranker settings from config
RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))
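# For reference, the settings above correspond to params.cfg sections shaped
# roughly like the following (the values shown here are illustrative
# assumptions, not the project's actual configuration):
#
#   [retriever]
#   TOP_K = 10
#   SCORE_THRESHOLD = 0.5
#
#   [reranker]
#   ENABLED = True
#   MODEL_NAME = cross-encoder/ms-marco-MiniLM-L-6-v2
#   TOP_K = 5
#   TOP_K_SCALE_FACTOR = 2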
# Initialize reranker if enabled
reranker = None
if RERANKER_ENABLED:
    try:
        print(f"Starting reranker initialization with model: {RERANKER_MODEL}", flush=True)
        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")

        print("Loading HuggingFace cross encoder model", flush=True)
        # HuggingFaceCrossEncoder doesn't accept a cache_dir parameter;
        # the underlying model uses the default cache location.
        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
        print("Cross encoder model loaded successfully", flush=True)

        print("Creating CrossEncoderReranker...", flush=True)
        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
        print("Reranker initialized successfully", flush=True)
        logging.info("Reranker initialized successfully")
    except Exception as e:
        print(f"Failed to initialize reranker: {str(e)}", flush=True)
        logging.error(f"Failed to initialize reranker: {str(e)}")
        reranker = None
else:
    print("Reranker is disabled", flush=True)
def get_vectorstore() -> VectorStoreInterface:
    """
    Create and return a vector store connection.

    Returns:
        VectorStoreInterface instance
    """
    logging.info("Initializing vector store connection...")
    vectorstore = create_vectorstore(config)
    logging.info("Vector store connection initialized successfully")
    return vectorstore
def create_filter(
    filter_metadata: Optional[dict] = None,
) -> Optional[rest.Filter]:
    """
    Create a Qdrant filter from metadata criteria.

    Args:
        filter_metadata: Mapping of metadata field names to values. A string
            value requires an exact match; a list of values matches any of
            them. All conditions are combined with AND.

    Returns:
        Qdrant Filter object, or None if no criteria were given
    """
    if not filter_metadata:
        return None

    conditions = []
    logging.info(f"Defining filters for {filter_metadata}")
    for key, val in filter_metadata.items():
        if isinstance(val, str):
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchValue(value=val)
                )
            )
        else:
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchAny(any=val)
                )
            )
    return rest.Filter(must=conditions)
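# Example (with hypothetical metadata keys): a string value becomes an exact
# MatchValue condition, a list becomes a MatchAny condition, and all
# conditions are AND-ed together via `must`:
#
#   create_filter({"source": "EUDR", "year": [2022, 2023]})
#   -> Filter(must=[
#          FieldCondition(key="metadata.source", match=MatchValue(value="EUDR")),
#          FieldCondition(key="metadata.year", match=MatchAny(any=[2022, 2023])),
#      ])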
def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank documents using the cross-encoder configured in params.cfg.

    Args:
        query: The search query
        documents: List of documents to rerank

    Returns:
        Reranked list of documents in the original format
    """
    if not reranker or not documents:
        return documents

    try:
        logging.info(f"Starting reranking of {len(documents)} documents")

        # Convert to LangChain Document format using the keys produced by the
        # data storage module (TODO: review this later for portability)
        langchain_docs = []
        for doc in documents:
            content = doc.get('answer', '')
            metadata = doc.get('answer_metadata', {})
            if not content:
                logging.warning(f"Document missing content: {doc}")
                continue
            langchain_docs.append(Document(page_content=content, metadata=metadata))

        if not langchain_docs:
            logging.warning("No valid documents found for reranking")
            return documents

        # Rerank documents
        logging.info(f"Reranking {len(langchain_docs)} documents")
        reranked_docs = reranker.compress_documents(langchain_docs, query)

        # Convert back to the original format
        result = [
            {'answer': doc.page_content, 'answer_metadata': doc.metadata}
            for doc in reranked_docs
        ]
        logging.info(f"Successfully reranked {len(documents)} documents down to top {len(result)}")
        return result
    except Exception as e:
        logging.error(f"Error during reranking: {str(e)}")
        # Return original documents if reranking fails
        return documents
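# Example (with made-up documents): each (query, answer) pair is scored by the
# cross-encoder, and the top RERANKER_TOP_K documents come back in the same
# dict format the rest of the pipeline expects:
#
#   docs = [
#       {"answer": "Due diligence statements are required from 2025.",
#        "answer_metadata": {"source": "EUDR"}},
#       {"answer": "Unrelated boilerplate text.",
#        "answer_metadata": {"source": "misc"}},
#   ]
#   rerank_documents("When are due diligence statements required?", docs)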
def get_context(
    vectorstore: VectorStoreInterface,
    query: str,
    collection_name: Optional[str] = None,
    filter_metadata: Optional[dict] = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database,
    with optional reranking.

    Args:
        vectorstore: The vector store interface to search
        query: The search query
        collection_name: Name of the collection to search
        filter_metadata: Metadata criteria, passed through to create_filter()

    Returns:
        List of dictionaries with 'answer', 'answer_metadata', and 'score' keys
    """
    try:
        # Use a higher k for initial retrieval if reranking is enabled,
        # so the reranker has more candidate docs to choose from
        top_k = RETRIEVER_TOP_K
        if RERANKER_ENABLED and reranker:
            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
            logging.info(f"Reranking enabled, retrieving {top_k} candidates")

        search_kwargs = {
            "model_name": config.get("embeddings", "MODEL_NAME")
        }
        # Alternative: embed the query locally and search by vector
        # (superseded by passing model_name to vectorstore.search):
        # model = SentenceTransformer(config.get("embeddings", "MODEL_NAME"))
        # query_vector = model.encode(query).tolist()
        # retrieved_docs = vectorstore.search(
        #     collection_name="EUDR",
        #     query_vector=query_vector,
        #     limit=top_k,
        #     with_payload=True)

        # Filter support for QdrantVectorStore
        if isinstance(vectorstore, QdrantVectorStore):
            logging.info(f"Filter metadata: {filter_metadata}")
            filter_obj = create_filter(filter_metadata)
            if filter_obj:
                search_kwargs["filter"] = filter_obj

        # Perform initial retrieval
        logging.info(f"Search kwargs: {search_kwargs}")
        retrieved_docs = vectorstore.search(query, collection_name, top_k, **search_kwargs)
        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")

        # Apply reranking if enabled, then trim to the final desired k
        if RERANKER_ENABLED and reranker and retrieved_docs:
            logging.info("Applying reranking...")
            retrieved_docs = rerank_documents(query, retrieved_docs)
            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]

        logging.info(f"Returning {len(retrieved_docs)} final documents")
        logging.info(f"Retrieved results: {retrieved_docs}")
        return retrieved_docs
    except Exception as e:
        logging.error(f"Error during retrieval: {str(e)}")
        raise
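

if __name__ == "__main__":
    # Minimal smoke test, assuming a reachable vector store configured in
    # params.cfg and an existing collection; the collection name and filter
    # key below are illustrative assumptions.
    store = get_vectorstore()
    results = get_context(
        store,
        query="What are the due diligence obligations for importers?",
        collection_name="EUDR",
        filter_metadata={"source": "EUDR"},
    )
    for doc in results:
        print(doc.get("answer_metadata"), str(doc.get("answer", ""))[:80])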