Spaces:
Running
on
T4
Running
on
T4
| import gradio as gr | |
| import sys | |
| from utils.retriever import get_context, get_vectorstore | |
# Open the vector store connection once, at import time. The resulting
# module-level `vectorstore` handle is what `retrieve` passes to the retriever.
print("Initializing vector store connection...", flush=True)
try:
    vectorstore = get_vectorstore()
    print("Vector store connection initialized successfully", flush=True)
except Exception as init_error:
    # Report the cause, then re-raise so startup aborts rather than running
    # without a usable vector store.
    print(f"Failed to initialize vector store: {init_error}", flush=True)
    raise
def _split_csv(raw):
    """Split a comma-separated string into stripped, non-empty items."""
    return [item.strip() for item in (raw or "").split(",") if item.strip()]


# ---------------------------------------------------------------------
# MCP - returns raw dictionary format
# ---------------------------------------------------------------------
def retrieve(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
) -> list:
    """
    Retrieve semantically similar documents from the vector database for MCP clients.

    Blank or whitespace-only filters are treated as "no filter".

    Args:
        query (str): The search query text
        reports_filter (str): Comma-separated list of specific report filenames (optional)
        sources_filter (str): Filter by document source type (optional)
        subtype_filter (str): Filter by document subtype (optional)
        year_filter (str): Comma-separated list of years to filter by (optional)

    Returns:
        list: List of dictionaries containing document content, metadata, and scores
    """
    # Reports keep the original "no filter" representation: an empty list.
    reports = _split_csv(reports_filter)

    # A whitespace-only value must not reach the retriever as the filter "";
    # collapse it to None exactly like an empty string.
    sources = (sources_filter or "").strip() or None
    subtype = (subtype_filter or "").strip() or None

    # Input such as " , " contains no usable year; report it as None (no
    # filter) rather than as an empty list.
    year = _split_csv(year_filter) or None

    # Hand the raw retriever results back unchanged.
    return get_context(
        vectorstore=vectorstore,
        query=query,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=year,
    )
# Build the Gradio Blocks app: a form for the query and its optional filters on
# one side, the retrieved context on the other. The same `retrieve` function is
# exposed under the API name "retrieve".
with gr.Blocks() as ui:
    gr.Markdown("# ChatFed Retrieval/Reranker Module")
    gr.Markdown("Retrieves semantically similar documents from vector database and reranks. Intended for use in RAG pipelines as an MCP server with other ChatFed modules.")

    with gr.Row():
        # Input column: free-text query followed by the optional filters.
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                info="The query to search for in the vector database",
                placeholder="Enter your search query here",
                lines=2,
            )
            reports_input = gr.Textbox(
                label="Reports Filter (optional)",
                info="Comma-separated list of specific report filenames to search within (leave empty for all)",
                placeholder="report1.pdf, report2.pdf",
                lines=1,
            )
            sources_input = gr.Textbox(
                label="Sources Filter (optional)",
                info="Filter by document source type (leave empty for all)",
                placeholder="annual_report",
                lines=1,
            )
            subtype_input = gr.Textbox(
                label="Subtype Filter (optional)",
                info="Filter by document subtype (leave empty for all)",
                placeholder="financial",
                lines=1,
            )
            year_input = gr.Textbox(
                label="Year Filter (optional)",
                info="Comma-separated list of years to filter by (leave empty for all)",
                placeholder="2023, 2024",
                lines=1,
            )
            submit_btn = gr.Button("Submit", variant="primary")

        # Output column.
        # NOTE(review): the original comment says the output must be JSON to be
        # usable as a HuggingChat tool, yet the component is a plain text box
        # fed with the list returned by `retrieve` -- confirm this is intended.
        with gr.Column():
            output = gr.Text(
                label="Retrieved Context",
                show_copy_button=True,
                lines=10,
            )

    # Wire the button to the retriever; input order matches `retrieve`'s
    # positional parameters.
    submit_btn.click(
        fn=retrieve,
        api_name="retrieve",
        inputs=[query_input, reports_input, sources_input, subtype_input, year_input],
        outputs=output,
    )
# Start the Gradio server when this file is run as a script.
# NOTE: MCP server mode is currently switched off -- the `mcp_server=True`
# launch option is left commented out below.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        # "mcp_server": True,
        "show_error": True,
    }
    ui.launch(**launch_options)