import json
import os
import re
import time
import shutil
from typing import Dict, List, Any
from fastapi.responses import JSONResponse, FileResponse
from gpt_researcher import GPTResearcher
from gpt_researcher.document.document import DocumentLoader
from backend.utils import write_md_to_pdf, write_md_to_word, write_text_to_md
from pathlib import Path
from datetime import datetime
from fastapi import HTTPException
import logging

# NOTE: assumed import locations for the multi-agent entry points used by
# execute_multi_agents() below; adjust to match your project layout.
from multi_agents.main import run_research_task
from gpt_researcher.actions import stream_output

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class CustomLogsHandler:
    """Custom handler to capture streaming logs from the research process"""

    def __init__(self, websocket, task: str):
        self.logs = []
        self.websocket = websocket
        sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}")
        self.log_file = os.path.join("outputs", f"{sanitized_filename}.json")
        self.timestamp = datetime.now().isoformat()
        # Initialize log file with metadata
        os.makedirs("outputs", exist_ok=True)
        with open(self.log_file, 'w') as f:
            json.dump({
                "timestamp": self.timestamp,
                "events": [],
                "content": {
                    "query": "",
                    "sources": [],
                    "context": [],
                    "report": "",
                    "costs": 0.0
                }
            }, f, indent=2)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Store log data and send to websocket"""
        # Send to websocket for real-time display
        if self.websocket:
            await self.websocket.send_json(data)

        # Read current log file
        with open(self.log_file, 'r') as f:
            log_data = json.load(f)

        # Update the appropriate section based on data type
        if data.get('type') == 'logs':
            log_data['events'].append({
                "timestamp": datetime.now().isoformat(),
                "type": "event",
                "data": data
            })
        else:
            # Update content section for other types of data
            log_data['content'].update(data)

        # Save updated log file
        with open(self.log_file, 'w') as f:
            json.dump(log_data, f, indent=2)
        logger.debug(f"Log entry written to: {self.log_file}")


class Researcher:
    def __init__(self, query: str, report_type: str = "research_report"):
        self.query = query
        self.report_type = report_type
        # Generate unique ID for this research task
        self.research_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{hash(query)}"
        # Initialize the logs handler; there is no live websocket in this code
        # path, and the query doubles as the task name for the log file.
        self.logs_handler = CustomLogsHandler(None, query)
        self.researcher = GPTResearcher(
            query=query,
            report_type=report_type,
            websocket=self.logs_handler
        )

    async def research(self) -> dict:
        """Conduct research and return paths to generated files"""
        await self.researcher.conduct_research()
        report = await self.researcher.write_report()

        # Generate the files
        sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{self.query}")
        file_paths = await generate_report_files(report, sanitized_filename)

        # Get the JSON log path that was created by CustomLogsHandler
        json_relative_path = os.path.relpath(self.logs_handler.log_file)

        return {
            "output": {
                **file_paths,  # Include PDF, DOCX, and MD paths
                "json": json_relative_path
            }
        }


def sanitize_filename(filename: str) -> str:
    # Split into components
    prefix, timestamp, *task_parts = filename.split('_')
    task = '_'.join(task_parts)

    # Calculate max length for the task portion:
    # 255 - len("outputs/") - len("task_") - len(timestamp) - len("_.json") - safety margin
    max_task_length = 255 - 8 - 5 - 10 - 6 - 10  # ~216 chars for task

    # Truncate task if needed
    truncated_task = task[:max_task_length] if len(task) > max_task_length else task

    # Reassemble and clean the filename
    sanitized = f"{prefix}_{timestamp}_{truncated_task}"
    return re.sub(r"[^\w\s-]", "", sanitized).strip()
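
# For example (illustrative values): non-word characters such as "?" are
# stripped, while underscores, spaces, and hyphens survive the regex:
#
#   sanitize_filename("task_1700000000_What is AI?")
#   -> "task_1700000000_What is AI"
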
async def handle_start_command(websocket, data: str, manager):
    json_data = json.loads(data[6:])
    (
        task,
        report_type,
        source_urls,
        document_urls,
        tone,
        headers,
        report_source,
    ) = extract_command_data(json_data)

    if not task or not report_type:
        print("Error: Missing task or report_type")
        return

    # Create logs handler with websocket and task
    logs_handler = CustomLogsHandler(websocket, task)
    # Initialize log content with query
    await logs_handler.send_json({
        "query": task,
        "sources": [],
        "context": [],
        "report": ""
    })

    sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}")

    report = await manager.start_streaming(
        task,
        report_type,
        report_source,
        source_urls,
        document_urls,
        tone,
        websocket,
        headers
    )
    report = str(report)
    file_paths = await generate_report_files(report, sanitized_filename)
    # Add JSON log path to file_paths
    file_paths["json"] = os.path.relpath(logs_handler.log_file)
    await send_file_paths(websocket, file_paths)


async def handle_human_feedback(data: str):
    feedback_data = json.loads(data[14:])  # Remove "human_feedback" prefix
    print(f"Received human feedback: {feedback_data}")
    # TODO: Add logic to forward the feedback to the appropriate agent or update the research state


async def handle_chat(websocket, data: str, manager):
    json_data = json.loads(data[4:])
    print(f"Received chat message: {json_data.get('message')}")
    await manager.chat(json_data.get("message"), websocket)


async def generate_report_files(report: str, filename: str) -> Dict[str, str]:
    pdf_path = await write_md_to_pdf(report, filename)
    docx_path = await write_md_to_word(report, filename)
    md_path = await write_text_to_md(report, filename)
    return {"pdf": pdf_path, "docx": docx_path, "md": md_path}


async def send_file_paths(websocket, file_paths: Dict[str, str]):
    await websocket.send_json({"type": "path", "output": file_paths})


def get_config_dict(
    langchain_api_key: str, openai_api_key: str, tavily_api_key: str,
    google_api_key: str, google_cx_key: str, bing_api_key: str,
    searchapi_api_key: str, serpapi_api_key: str, serper_api_key: str,
    searx_url: str
) -> Dict[str, str]:
    return {
        "LANGCHAIN_API_KEY": langchain_api_key or os.getenv("LANGCHAIN_API_KEY", ""),
        "OPENAI_API_KEY": openai_api_key or os.getenv("OPENAI_API_KEY", ""),
        "TAVILY_API_KEY": tavily_api_key or os.getenv("TAVILY_API_KEY", ""),
        "GOOGLE_API_KEY": google_api_key or os.getenv("GOOGLE_API_KEY", ""),
        "GOOGLE_CX_KEY": google_cx_key or os.getenv("GOOGLE_CX_KEY", ""),
        "BING_API_KEY": bing_api_key or os.getenv("BING_API_KEY", ""),
        "SEARCHAPI_API_KEY": searchapi_api_key or os.getenv("SEARCHAPI_API_KEY", ""),
        "SERPAPI_API_KEY": serpapi_api_key or os.getenv("SERPAPI_API_KEY", ""),
        "SERPER_API_KEY": serper_api_key or os.getenv("SERPER_API_KEY", ""),
        "SEARX_URL": searx_url or os.getenv("SEARX_URL", ""),
        "LANGCHAIN_TRACING_V2": os.getenv("LANGCHAIN_TRACING_V2", "true"),
        "DOC_PATH": os.getenv("DOC_PATH", "./my-docs"),
        "RETRIEVER": os.getenv("RETRIEVER", ""),
        "EMBEDDING_MODEL": os.getenv("OPENAI_EMBEDDING_MODEL", "")
    }


def update_environment_variables(config: Dict[str, str]):
    for key, value in config.items():
        os.environ[key] = value
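
# A minimal usage sketch (hypothetical key values): build the config from
# explicit keys, falling back to the environment for anything left blank,
# then export it so downstream libraries read the same values.
#
#   config = get_config_dict(
#       langchain_api_key="", openai_api_key="sk-...", tavily_api_key="",
#       google_api_key="", google_cx_key="", bing_api_key="",
#       searchapi_api_key="", serpapi_api_key="", serper_api_key="",
#       searx_url="",
#   )
#   update_environment_variables(config)
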
async def handle_file_upload(file, DOC_PATH: str) -> Dict[str, str]:
    file_path = os.path.join(DOC_PATH, os.path.basename(file.filename))
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    print(f"File uploaded to {file_path}")

    document_loader = DocumentLoader(DOC_PATH)
    await document_loader.load()

    return {"filename": file.filename, "path": file_path}


async def handle_file_deletion(filename: str, DOC_PATH: str) -> JSONResponse:
    file_path = os.path.join(DOC_PATH, os.path.basename(filename))
    if os.path.exists(file_path):
        os.remove(file_path)
        print(f"File deleted: {file_path}")
        return JSONResponse(content={"message": "File deleted successfully"})
    else:
        print(f"File not found: {file_path}")
        return JSONResponse(status_code=404, content={"message": "File not found"})


async def execute_multi_agents(manager) -> Any:
    websocket = manager.active_connections[0] if manager.active_connections else None
    if websocket:
        report = await run_research_task("Is AI in a hype cycle?", websocket, stream_output)
        return {"report": report}
    else:
        return JSONResponse(status_code=400, content={"message": "No active WebSocket connection"})


async def handle_websocket_communication(websocket, manager):
    while True:
        data = await websocket.receive_text()
        if data.startswith("start"):
            await handle_start_command(websocket, data, manager)
        elif data.startswith("human_feedback"):
            await handle_human_feedback(data)
        elif data.startswith("chat"):
            await handle_chat(websocket, data, manager)
        else:
            print("Error: Unknown command or not enough parameters provided.")


def extract_command_data(json_data: Dict) -> tuple:
    return (
        json_data.get("task"),
        json_data.get("report_type"),
        json_data.get("source_urls"),
        json_data.get("document_urls"),
        json_data.get("tone"),
        json_data.get("headers", {}),
        json_data.get("report_source")
    )
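
# Wire format expected by handle_websocket_communication(), inferred from the
# prefix slicing in the handlers above (illustrative payloads):
#
#   "start {\"task\": \"...\", \"report_type\": \"research_report\", ...}"
#   "human_feedback{\"...\": ...}"
#   "chat{\"message\": \"...\"}"
#
# Note that "start" is followed by a space (json.loads(data[6:])), while
# "chat" and "human_feedback" are not (data[4:] and data[14:]).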