Spaces:
Running
Running
import asyncio | |
import datetime | |
from typing import Dict, List | |
from fastapi import WebSocket | |
from backend.report_type import BasicReport, DetailedReport | |
from backend.chat import ChatAgentWithMemory | |
from gpt_researcher.utils.enum import ReportType, Tone | |
from multi_agents.main import run_research_task | |
from gpt_researcher.actions import stream_output # Import stream_output | |
from backend.server.server_utils import CustomLogsHandler | |
class WebSocketManager:
    """Manage websocket connections, their outgoing queues, and the chat agent.

    Each accepted websocket gets its own ``asyncio.Queue`` plus a background
    sender task (``start_sender``) that drains that queue onto the socket.
    """

    def __init__(self):
        """Initialize the WebSocketManager class."""
        # Websockets that have been accepted and not yet disconnected.
        self.active_connections: List[WebSocket] = []
        # Background sender task per websocket (created in ``connect``).
        self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
        # Outgoing message queue per websocket, drained by its sender task.
        self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
        # Chat agent built from the most recent report. NOTE: this is a single
        # attribute shared by every connection, so each new report replaces
        # the agent for all clients.
        self.chat_agent = None

    async def start_sender(self, websocket: WebSocket):
        """Forward queued messages to ``websocket`` until it goes away.

        A ``"ping"`` message is answered with ``"pong"``; any other message is
        sent verbatim via ``send_text``. The loop stops when the ``None``
        sentinel (enqueued by ``disconnect``) is received, when the websocket
        is no longer in ``active_connections``, or when a send fails.
        """
        queue = self.message_queues.get(websocket)
        if queue is None:
            return

        while True:
            message = await queue.get()
            # ``disconnect`` enqueues None as an explicit stop signal.
            if message is None or websocket not in self.active_connections:
                break
            try:
                await websocket.send_text("pong" if message == "ping" else message)
            except Exception:
                # Best effort: a failed send means the client is gone, so stop
                # forwarding. Catching ``Exception`` rather than using a bare
                # ``except`` lets ``asyncio.CancelledError`` (a BaseException
                # since Python 3.8) raised by ``disconnect`` propagate, so the
                # task is actually cancelled instead of silently returning.
                break

    async def connect(self, websocket: WebSocket):
        """Accept ``websocket``, register it, and start its sender task."""
        await websocket.accept()
        self.active_connections.append(websocket)
        # The queue must exist before the sender task runs, because
        # ``start_sender`` returns immediately if it finds no queue.
        self.message_queues[websocket] = asyncio.Queue()
        self.sender_tasks[websocket] = asyncio.create_task(
            self.start_sender(websocket))

    async def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket`` and stop its sender task.

        Does nothing if the websocket is not currently active. Missing task or
        queue entries are tolerated instead of raising ``KeyError``.
        """
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)

        task = self.sender_tasks.pop(websocket, None)
        if task is not None:
            task.cancel()

        queue = self.message_queues.pop(websocket, None)
        if queue is not None:
            # Sentinel that also stops the sender loop if it is still
            # draining when the cancellation lands.
            await queue.put(None)

    async def start_streaming(self, task, report_type, report_source, source_urls, document_urls, tone, websocket, headers=None):
        """Run the research agent for ``task`` and return its report.

        ``tone`` is a member *name* looked up via ``Tone[tone]`` (presumably an
        Enum; an unknown name raises ``KeyError``). After the report is
        written, a new ``ChatAgentWithMemory`` is created from it, replacing
        the previous chat agent for all connections.
        """
        tone = Tone[tone]
        # add customized JSON config file path here
        config_path = "default"
        report = await run_agent(
            task, report_type, report_source, source_urls, document_urls,
            tone, websocket, headers=headers, config_path=config_path,
        )
        # Create a new chat agent whenever a new report is written.
        self.chat_agent = ChatAgentWithMemory(report, config_path, headers)
        return report

    async def chat(self, message, websocket):
        """Send ``message`` to the chat agent, or tell the client no research exists yet."""
        if self.chat_agent:
            await self.chat_agent.chat(message, websocket)
        else:
            await websocket.send_json({"type": "chat", "content": "Knowledge empty, please run the research first to obtain knowledge"})
async def run_agent(task, report_type, report_source, source_urls, document_urls, tone: Tone, websocket, headers=None, config_path=""):
    """Run a research agent for ``task`` and return the generated report.

    Args:
        task: The research query.
        report_type: ``"multi_agents"`` selects the multi-agent workflow
            (``run_research_task``); ``ReportType.DetailedReport.value``
            selects ``DetailedReport``; any other value uses ``BasicReport``.
        report_source: Forwarded to the report class (not passed to the
            multi-agent workflow).
        source_urls: Forwarded to the report class (not passed to the
            multi-agent workflow).
        document_urls: Forwarded to the report class (not passed to the
            multi-agent workflow).
        tone: ``Tone`` value forwarded to the researcher.
        websocket: Client connection, wrapped in a ``CustomLogsHandler``
            before being handed to the researcher.
        headers: Optional headers forwarded to the researcher.
        config_path: Config path forwarded to the report class (not passed
            to the multi-agent workflow).

    Returns:
        The researcher's report. For the multi-agent workflow, this is the
        ``"report"`` entry of its result, or ``""`` when that key is absent.
    """
    # Researchers log through this handler instead of the raw websocket.
    logs_handler = CustomLogsHandler(websocket, task)

    if report_type == "multi_agents":
        result = await run_research_task(
            query=task,
            websocket=logs_handler,
            stream_output=stream_output,
            tone=tone,
            headers=headers,
        )
        return result.get("report", "")

    # DetailedReport and BasicReport accept the same constructor arguments,
    # so only the class differs between the two paths.
    report_cls = (
        DetailedReport
        if report_type == ReportType.DetailedReport.value
        else BasicReport
    )
    researcher = report_cls(
        query=task,
        report_type=report_type,
        report_source=report_source,
        source_urls=source_urls,
        document_urls=document_urls,
        tone=tone,
        config_path=config_path,
        websocket=logs_handler,
        headers=headers,
    )
    return await researcher.run()