Spaces:
Running
Running
import json | |
import os | |
from typing import Dict, List | |
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, File, UploadFile, Header | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from pydantic import BaseModel | |
from backend.server.websocket_manager import WebSocketManager | |
from backend.server.server_utils import ( | |
get_config_dict, | |
update_environment_variables, handle_file_upload, handle_file_deletion, | |
execute_multi_agents, handle_websocket_communication | |
) | |
from gpt_researcher.utils.logging_config import setup_research_logging | |
import logging | |
# Module-level logger; its records propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
logger.propagate = True  # kept explicit, although True is the logging default

# Console-only root configuration. logging.basicConfig does nothing if the
# root logger already has handlers, so an earlier configuration still wins.
_console_handler = logging.StreamHandler()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[_console_handler],
)
# Models
class ResearchRequest(BaseModel):
    """Request body describing a research job.

    All three fields are required strings. This model does not constrain
    which values are accepted.
    """

    # Declaration order is kept as-is: it defines the model's field/schema order.
    task: str  # presumably the free-form research task text — confirm
    report_type: str  # accepted values are defined elsewhere — confirm
    agent: str  # accepted values are defined elsewhere — confirm
class ConfigRequest(BaseModel):
    """Request body carrying provider API keys and related settings.

    Fields without a default are required, and a payload that omits one
    fails validation. The remaining provider keys default to an empty string.
    NOTE(review): presumably applied via ``update_environment_variables``;
    confirm at the call site.
    """

    ANTHROPIC_API_KEY: str
    TAVILY_API_KEY: str
    LANGCHAIN_TRACING_V2: str
    LANGCHAIN_API_KEY: str
    OPENAI_API_KEY: str
    DOC_PATH: str
    RETRIEVER: str
    GOOGLE_API_KEY: str = ''
    GOOGLE_CX_KEY: str = ''
    BING_API_KEY: str = ''
    SEARCHAPI_API_KEY: str = ''
    SERPAPI_API_KEY: str = ''
    SERPER_API_KEY: str = ''
    SEARX_URL: str = ''
    # These two were the only provider keys without a default, so every
    # payload had to send them. They now default to '' like the keys above.
    XAI_API_KEY: str = ''
    DEEPSEEK_API_KEY: str = ''
# App initialization
app = FastAPI()

# Static files and templates: ./frontend is served as static assets under
# /site and is also the Jinja2 template root.
_FRONTEND_DIR = "./frontend"
app.mount("/site", StaticFiles(directory=_FRONTEND_DIR), name="site")
app.mount("/static", StaticFiles(directory=f"{_FRONTEND_DIR}/static"), name="static")
templates = Jinja2Templates(directory=_FRONTEND_DIR)

# Single WebSocketManager instance shared by the handlers in this module.
manager = WebSocketManager()

# Middleware: CORS allows credentialed cross-origin requests only from
# http://localhost:3000, with any method and header.
_CORS_ORIGINS = ["http://localhost:3000"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Constants: document directory, overridable through the DOC_PATH env var.
DOC_PATH = os.environ.get("DOC_PATH", "./my-docs")
# Startup event
# NOTE(review): no decorator (e.g. `@app.on_event("startup")`) is visible on
# this function in this source, so nothing here registers it — confirm.
def startup_event():
    """Create ``outputs`` and the document directory, and serve ``outputs``.

    ``outputs`` is created before it is mounted at ``/outputs`` so that
    StaticFiles is given an existing directory. ``DOC_PATH`` is created last.
    """
    outputs_dir = "outputs"
    os.makedirs(outputs_dir, exist_ok=True)
    app.mount("/outputs", StaticFiles(directory=outputs_dir), name="outputs")
    os.makedirs(DOC_PATH, exist_ok=True)
# Routes
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm it is registered with `app`.
async def read_root(request: Request):
    """Render ``index.html`` from the frontend templates with no report set."""
    context = {"request": request, "report": None}
    return templates.TemplateResponse("index.html", context)
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm it is registered with `app`.
async def list_files():
    """Return the entry names in ``DOC_PATH`` as ``{"files": [...]}``.

    The listing goes through the module logger instead of ``print``, so it
    honours the module's logging level and format.

    Raises:
        FileNotFoundError: if ``DOC_PATH`` does not exist (from ``os.listdir``).
    """
    files = os.listdir(DOC_PATH)
    # Lazy %-args: the message is only formatted if INFO is enabled.
    logger.info("Files in %s: %s", DOC_PATH, files)
    return {"files": files}
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm it is registered with `app`.
async def run_multi_agents():
    """Delegate to ``execute_multi_agents`` with the shared manager and return its result."""
    result = await execute_multi_agents(manager)
    return result
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm it is registered with `app`.
async def upload_file(file: UploadFile = File(...)):
    """Pass the uploaded *file* and ``DOC_PATH`` to ``handle_file_upload``.

    Returns whatever ``handle_file_upload`` returns. Presumably the file is
    stored under ``DOC_PATH`` — confirm in server_utils.
    """
    target_dir = DOC_PATH
    return await handle_file_upload(file, target_dir)
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm it is registered with `app`.
async def delete_file(filename: str):
    """Pass *filename* and ``DOC_PATH`` to ``handle_file_deletion`` and return its result."""
    # NOTE(review): *filename* is caller-supplied. Confirm handle_file_deletion
    # rejects path separators and ".." so deletion cannot leave DOC_PATH.
    outcome = await handle_file_deletion(filename, DOC_PATH)
    return outcome
# NOTE(review): no route decorator is visible on this handler in this source;
# confirm it is registered with `app`.
async def websocket_endpoint(websocket: WebSocket):
    """Register *websocket* with the shared manager and run its message loop.

    The connection is handed to ``manager.disconnect`` when the client
    disconnects. On any other exception the error is logged, the connection
    is still handed to ``manager.disconnect``, and the exception is re-raised.
    """
    await manager.connect(websocket)
    try:
        await handle_websocket_communication(websocket, manager)
    except WebSocketDisconnect:
        await manager.disconnect(websocket)
    except Exception:
        # Previously only WebSocketDisconnect released the connection, so an
        # unexpected error left it registered in the manager.
        logger.exception("WebSocket handler failed; disconnecting client")
        await manager.disconnect(websocket)
        raise