import os
import shutil
import asyncio

from dotenv import load_dotenv
import gradio as gr

from query_utils import (
    process_query_for_rewrite,
    get_non_autism_response,
    check_answer_autism_relevance,
    get_non_autism_answer_response,
)

# Secrets come from the environment (.env) rather than being hardcoded, which
# keeps real keys out of the repository.
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")

if not (DEEPINFRA_API_KEY and WEAVIATE_URL and WEAVIATE_API_KEY):
    raise ValueError("Please set all required keys in .env")

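# A minimal .env sketch for the variables above (placeholder values, not real
# keys):
#
#   OPENAI_API_KEY=sk-...
#   WEAVIATE_URL=your-cluster.weaviate.cloud
#   WEAVIATE_API_KEY=...
#   DEEPINFRA_API_KEY=...
#   DEEPINFRA_BASE_URL=https://api.deepinfra.com/v1/openai
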
# DeepInfra exposes an OpenAI-compatible API, so the stock OpenAI client works.
from openai import OpenAI

openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,
)

# Weaviate client
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager

@contextmanager  # required so this generator works in a `with` statement below
def weaviate_client():
    """Yield a Weaviate Cloud connection and close it when the caller is done."""
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
        skip_init_checks=True,  # disables the gRPC health check on startup
    )
    try:
        yield client
    finally:
        client.close()

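# Usage sketch: the context manager guarantees the connection is released even
# if retrieval raises, e.g.
#
#   with weaviate_client() as client:
#       print(client.collections.list_all())
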
# Tracks the most recently uploaded document path across requests.
last_uploaded_path = None

# Embedding helper
def embed_texts(texts: list[str], batch_size: int = 50) -> list[list[float]]:
    """Embed texts in batches via DeepInfra; failed batches yield empty lists."""
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        try:
            resp = openai.embeddings.create(
                model="Qwen/Qwen3-Embedding-8B",
                input=batch,
                encoding_format="float",
            )
            all_embeddings.extend(item.embedding for item in resp.data)
        except Exception as e:
            print(f"Embedding error: {e}")
            # Keep the output aligned with the input on failure.
            all_embeddings.extend([[] for _ in batch])
    return all_embeddings

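# Quick sanity check (hypothetical strings): all successful embeddings share
# one dimensionality, and failed batches come back as empty lists.
#
#   vecs = embed_texts(["first chunk", "second chunk"])
#   assert len(vecs) == 2
#   assert all(len(v) == len(vecs[0]) for v in vecs if v)
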
def encode_query(query: str) -> list[float] | None:
    """Return the embedding for a single query, or None if embedding failed."""
    embs = embed_texts([query], batch_size=1)
    if embs and embs[0]:
        return embs[0]
    return None

async def old_Document(query: str, top_k: int = 1) -> dict:
    """Embed the query and return the top_k nearest text chunks from Weaviate."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    try:
        with weaviate_client() as client:
            coll = client.collections.get("user")
            res = coll.query.near_vector(
                near_vector=qe,
                limit=top_k,
                return_properties=["text"],
            )
            if not getattr(res, "objects", None):
                return {"answer": []}
            return {
                "answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]
            }
    except Exception as e:
        print("RAG Error:", e)
        return {"answer": []}

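# Usage sketch: old_Document is declared async (so other coroutines can await
# it) but does no awaiting itself; from synchronous code it is driven with
# asyncio.run, e.g. (hypothetical query)
#
#   hits = asyncio.run(old_Document("sensory sensitivities", top_k=3))
#   print(hits["answer"])
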
# Functions backing the Gradio app
def ingest_file(path: str) -> str:
    """Record the uploaded file path (chunking and embedding into Weaviate are
    assumed to happen elsewhere)."""
    global last_uploaded_path
    last_uploaded_path = path
    return f"Old document ingested: {os.path.basename(path)}"

def answer_question(query: str) -> str:
    """Retrieve an answer from the ingested document, gated on autism relevance."""
    try:
        # Rewrite the query and check whether it is autism-related.
        corrected_query, is_autism_related, rewritten_query = process_query_for_rewrite(query)

        # Off-topic queries get the standard rejection message.
        if not is_autism_related:
            return get_non_autism_response()

        # Use the corrected query for retrieval.
        rag_resp = asyncio.run(old_Document(corrected_query))
        chunks = rag_resp.get("answer", [])
        if not chunks:
            return "Sorry, I couldn't find relevant content in the old document."

        # Combine chunks into a single answer for relevance checking.
        combined_answer = "\n".join(f"- {c}" for c in chunks)

        # Check whether the retrieved content itself is sufficiently related to
        # autism; answers scoring below 50 are refused (threshold updated for
        # the enhanced scoring).
        answer_relevance_score = check_answer_autism_relevance(combined_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()

        return combined_answer
    except Exception as e:
        return f"Error processing your request: {e}"

# Gradio interface for Old Documents
with gr.Blocks(title="Old Documents RAG") as demo:
    gr.Markdown("## Old Documents RAG")
    query = gr.Textbox(placeholder="Your question...", lines=2, label="Ask about Old Documents")
    doc_file = gr.File(label="Upload Old Document (PDF, DOCX, TXT)")
    btn = gr.Button("Submit")
    out = gr.Textbox(label="Answer from Old Documents", lines=8, interactive=False)

    def process_old_doc(query, doc_file):
        if doc_file:
            # Save and ingest the uploaded file. Depending on the Gradio
            # version, gr.File hands back either a temp-file path (str) or an
            # object with a .name attribute; handle both and copy the file
            # rather than calling .read() on a value that may just be a path.
            upload_dir = os.path.join(os.path.dirname(__file__), "uploaded_docs")
            os.makedirs(upload_dir, exist_ok=True)
            src_path = doc_file if isinstance(doc_file, str) else doc_file.name
            safe_filename = os.path.basename(src_path)
            save_path = os.path.join(upload_dir, safe_filename)
            shutil.copy(src_path, save_path)
            status = ingest_file(save_path)
            answer = answer_question(query)
            return f"{status}\n\n{answer}"
        # Otherwise fall back to the most recently uploaded document.
        if last_uploaded_path:
            answer = answer_question(query)
            return (
                f"[Using previously uploaded document: "
                f"{os.path.basename(last_uploaded_path)}]\n\n{answer}"
            )
        return "No document uploaded. Please upload an old document to proceed."

    btn.click(fn=process_old_doc, inputs=[query, doc_file], outputs=out)

if __name__ == "__main__":
    demo.launch(debug=True)
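
# To run locally (assuming this file is app.py and a .env with the keys above):
#   pip install gradio openai "weaviate-client>=4" python-dotenv
#   python app.py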