# GPT-Researcher / backend/server/websocket_manager.py
# Uploaded by Shreyas094 ("Upload 528 files", commit 372531f, verified); 4.84 kB.
import asyncio
import datetime
from typing import Dict, List
from fastapi import WebSocket
from backend.report_type import BasicReport, DetailedReport
from backend.chat import ChatAgentWithMemory
from gpt_researcher.utils.enum import ReportType, Tone
from multi_agents.main import run_research_task
from gpt_researcher.actions import stream_output # Import stream_output
from backend.server.server_utils import CustomLogsHandler
class WebSocketManager:
    """Manage websocket connections and their outbound message queues.

    Each connected websocket gets its own ``asyncio.Queue`` plus a sender task
    (``start_sender``) that drains that queue onto the socket. The manager also
    holds one ``ChatAgentWithMemory`` built from the most recently produced
    report. Note that this single chat agent is shared by every connection
    handled by this manager instance.
    """

    def __init__(self):
        """Initialize empty connection, sender-task and queue registries."""
        self.active_connections: List[WebSocket] = []
        self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
        self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
        # Replaced by start_streaming after each report; None until then.
        self.chat_agent = None

    async def start_sender(self, websocket: WebSocket):
        """Forward queued messages to ``websocket`` until it goes away.

        ``"ping"`` is answered with ``"pong"``; any other message is sent as
        text. The loop ends when any of the following happens:

        - the socket is no longer in ``active_connections``;
        - ``None`` (the shutdown sentinel enqueued by ``disconnect``) is
          dequeued;
        - a send fails.
        """
        queue = self.message_queues.get(websocket)
        if queue is None:
            return
        while True:
            message = await queue.get()
            # None is the shutdown sentinel; never hand it to send_text().
            if message is None or websocket not in self.active_connections:
                break
            try:
                await websocket.send_text("pong" if message == "ping" else message)
            except Exception:
                # Best effort: treat a failed send as a dead connection and
                # stop this sender. Catching Exception (not a bare except)
                # lets asyncio.CancelledError propagate so cancellation from
                # disconnect() is honored.
                break

    async def connect(self, websocket: WebSocket):
        """Accept ``websocket``, register it, and start its sender task."""
        await websocket.accept()
        self.active_connections.append(websocket)
        # The queue must exist before the sender task starts reading it.
        self.message_queues[websocket] = asyncio.Queue()
        self.sender_tasks[websocket] = asyncio.create_task(self.start_sender(websocket))

    async def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket`` and stop its sender task.

        Calling this for a websocket that is not currently active is a no-op.
        """
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)
        # Drop registry entries before awaiting so no other coroutine sees a
        # half-removed connection; pop() tolerates a partial registration.
        task = self.sender_tasks.pop(websocket, None)
        queue = self.message_queues.pop(websocket, None)
        if task is not None:
            task.cancel()
        if queue is not None:
            # Wake a sender blocked on queue.get() with the shutdown sentinel.
            await queue.put(None)

    async def start_streaming(self, task, report_type, report_source, source_urls, document_urls, tone, websocket, headers=None):
        """Run a research task for ``websocket`` and return the resulting report.

        ``tone`` is the *name* of a ``Tone`` member and is resolved with
        ``Tone[tone]``; for an Enum, an unknown name raises ``KeyError``.
        After the report is produced, ``self.chat_agent`` is replaced with a
        new ``ChatAgentWithMemory`` built from it.
        """
        tone = Tone[tone]
        # Add a customized JSON config file path here.
        config_path = "default"
        report = await run_agent(
            task,
            report_type,
            report_source,
            source_urls,
            document_urls,
            tone,
            websocket,
            headers=headers,
            config_path=config_path,
        )
        # Create a new chat agent whenever a new report is written.
        self.chat_agent = ChatAgentWithMemory(report, config_path, headers)
        return report

    async def chat(self, message, websocket):
        """Forward ``message`` to the chat agent.

        If no report has been produced yet, tell the client so instead.
        """
        if self.chat_agent:
            await self.chat_agent.chat(message, websocket)
        else:
            await websocket.send_json({"type": "chat", "content": "Knowledge empty, please run the research first to obtain knowledge"})
async def run_agent(task, report_type, report_source, source_urls, document_urls, tone: Tone, websocket, headers=None, config_path=""):
    """Run the research agent selected by ``report_type`` and return its report.

    Args:
        task: The research query.
        report_type: Selects the agent:
            - ``"multi_agents"`` dispatches to ``run_research_task``;
            - ``ReportType.DetailedReport.value`` uses ``DetailedReport``;
            - any other value uses ``BasicReport``.
        report_source: Forwarded to ``DetailedReport``/``BasicReport``.
            Not used by the multi-agent path.
        source_urls: Forwarded to ``DetailedReport``/``BasicReport`` only.
        document_urls: Forwarded to ``DetailedReport``/``BasicReport`` only.
        tone: ``Tone`` member forwarded to the agent.
        websocket: Client connection. It is wrapped in a ``CustomLogsHandler``,
            and the agent receives that handler instead of the raw socket.
        headers: Optional headers forwarded to the agent.
        config_path: Forwarded to ``DetailedReport``/``BasicReport`` only.

    Returns:
        The report produced by the agent. For the multi-agent path this is
        the ``"report"`` entry of ``run_research_task``'s result, or ``""``
        if that entry is missing.
    """
    # One logs handler per research task; agents get it instead of the raw websocket.
    logs_handler = CustomLogsHandler(websocket, task)

    if report_type == "multi_agents":
        result = await run_research_task(
            query=task,
            websocket=logs_handler,
            stream_output=stream_output,
            tone=tone,
            headers=headers,
        )
        return result.get("report", "")

    # DetailedReport and BasicReport are constructed with identical keyword
    # arguments, so choose the class first and build it once.
    researcher_cls = (
        DetailedReport
        if report_type == ReportType.DetailedReport.value
        else BasicReport
    )
    researcher = researcher_cls(
        query=task,
        report_type=report_type,
        report_source=report_source,
        source_urls=source_urls,
        document_urls=document_urls,
        tone=tone,
        config_path=config_path,
        websocket=logs_handler,
        headers=headers,
    )
    return await researcher.run()