Spaces:
Paused
Paused
| import multiprocessing | |
| import json | |
| import os | |
| import uvicorn | |
| from fastapi import FastAPI, Request, HTTPException, Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from utils import extract_and_cache_document, service, cache_file_popup_url, cache_root, cache_file, code_interpreter_ws, update_pop_url, change_checkbox_state | |
| from starlette.middleware.sessions import SessionMiddleware | |
# Core FastAPI application; the Gradio demos are mounted onto it further down.
app = FastAPI()

# CORS settings: credentials, every method and every header are allowed.
# NOTE(review): no ``allow_origins`` is passed, and Starlette's CORSMiddleware
# defaults it to an empty list, so no cross-origin origin is actually allowed
# — confirm this is intended.
_CORS_OPTIONS = {
    'allow_credentials': True,
    'allow_methods': ['*'],
    'allow_headers': ['*'],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Expose the code-interpreter workspace directory under /static.
app.mount(
    '/static',
    StaticFiles(directory=code_interpreter_ws),
    name='static',
)
async def access_token_auth(request: Request, call_next):
    """Reject requests whose token does not belong to an enabled account.

    The token is taken, in order of precedence, from the raw ``Authorization``
    header value (no ``Bearer`` prefix stripping), the ``access_token`` query
    parameter, or the session. The account record is read as ``info.json``
    for that token via ``service.get``.

    Returns:
        A 401 ``Response`` ("the token is not valid") when the token is
        missing, has no stored record, or the record is not enabled;
        otherwise the downstream response from ``call_next``.
    """
    access_token: str = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    if not access_token:
        return Response(status_code=401, content="the token is not valid")

    # service.get may return nothing for an unknown token; json.loads(None)
    # would raise TypeError and surface as a 500 instead of a 401.
    raw_info = service.get(access_token, "info.json", False)
    account_info = json.loads(raw_info) if raw_info else None
    if not (isinstance(account_info, dict) and account_info.get("enabled")):
        return Response(status_code=401, content="the token is not valid")

    # Remember the token that was just validated. setdefault() would keep an
    # older session token, so handlers reading the session could act on a
    # token that was never checked on this request.
    request.session["access_token"] = access_token
    return await call_next(request)
async def healthz(request: Request):
    """Report that the server is up.

    The request is ignored; the JSON payload is always ``{"healthz": true}``.
    """
    status = {"healthz": True}
    return JSONResponse(content=status)
async def add_token(request: Request):
    """Store the JSON request body as ``info.json`` (admin only).

    The caller's token is resolved like in the auth middleware (header, then
    ``access_token`` query parameter, then session). Only an enabled account
    whose ``role`` is ``'admin'`` may proceed; anyone else gets 401.

    Returns:
        ``{"success": true}`` after the body has been upserted, or a 401
        ``Response`` ("the token is not valid").
    """
    access_token: str = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    raw_info = service.get(access_token, "info.json", False) if access_token else None
    account_info = json.loads(raw_info) if raw_info else None
    is_admin = (
        isinstance(account_info, dict)
        and account_info.get("enabled")
        and account_info.get("role") == 'admin'
    )
    # The original check was inverted: it rejected enabled admins and let
    # every other caller overwrite info.json (including its role/enabled flags).
    if not is_admin:
        return Response(status_code=401, content="the token is not valid")

    data = await request.json()
    # NOTE(review): the record is written under the *caller's* token, i.e. it
    # overwrites the admin's own info.json. If the intent is to register a new
    # token, the key presumably has to come from ``data`` — confirm.
    service.upsert(access_token, "info.json", json.dumps(data, ensure_ascii=False), False)
    return JSONResponse({"success": True})
async def cache_data(request: Request, file_name: str):
    """Return the JSON content stored as *file_name* for the caller (admin only).

    The caller's token is resolved like in the auth middleware (header, then
    ``access_token`` query parameter, then session). Only an enabled account
    whose ``role`` is ``'admin'`` may read; anyone else gets 401.

    Returns:
        The parsed JSON content, ``""`` when nothing is stored under
        *file_name*, or a 401 ``Response`` ("the token is not valid").
    """
    access_token: str = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    raw_info = service.get(access_token, "info.json", False) if access_token else None
    account_info = json.loads(raw_info) if raw_info else None
    is_admin = (
        isinstance(account_info, dict)
        and account_info.get("enabled")
        and account_info.get("role") == 'admin'
    )
    # The original (copy-pasted) check was inverted and rejected enabled admins.
    if not is_admin:
        return Response(status_code=401, content="the token is not valid")

    # NOTE(review): file_name comes straight from the request; confirm the
    # storage behind service.get rejects path-like names such as "../x".
    data = service.get(access_token, file_name, False)
    content = json.loads(data) if data else ""
    return JSONResponse(content)
async def web_listening(request: Request):
    """Dispatch the task named by the ``task`` field of the JSON request body.

    Supported tasks:
      * ``change_checkbox`` - toggles checkbox ``ckid`` via change_checkbox_state.
      * ``cache`` - starts extract_and_cache_document in a child process and
        immediately answers ``"caching"``.
      * ``pop_url`` - records a URL via update_pop_url.

    Raises:
        HTTPException: 400 when ``task`` is missing or not one of the above
            (previously a KeyError / NotImplementedError, i.e. a 500).
    """
    data = await request.json()
    msg_type = data.get('task') if isinstance(data, dict) else None
    # Same precedence as the auth middleware, so the task runs for the token
    # that was actually validated on this request.
    access_token = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )

    if msg_type == 'change_checkbox':
        rsp = change_checkbox_state(data['ckid'], cache_file, access_token)
    elif msg_type == 'cache':
        # Run the extraction in a separate process; it is started but not
        # joined, so the request returns right away.
        cache_obj = multiprocessing.Process(
            target=extract_and_cache_document,
            args=(data, cache_root, access_token),
        )
        cache_obj.start()
        rsp = 'caching'
    elif msg_type == 'pop_url':
        # Misleading name: "pop_url" actually *adds* a URL; "pop" refers to
        # the pop-up UI that sends it.
        rsp = update_pop_url(data, cache_file_popup_url, access_token)
    else:
        raise HTTPException(status_code=400, detail=f"unsupported task: {msg_type!r}")
    return JSONResponse(content=rsp)
import secrets

import gradio as gr
from assistant_server import demo as assistant_app
from workstation_server import demo as workstation_app

# Mount both Gradio demos as sub-applications; each call returns the app to
# keep using.
app = gr.mount_gradio_app(app, assistant_app, path="/assistant")
app = gr.mount_gradio_app(app, workstation_app, path="/workstation")

# Session cookies are signed with SECRET_KEY. Starlette's SessionMiddleware
# stringifies the key, so an unset variable would sign every cookie with the
# public string "None". Fall back to a random per-process key instead; the
# trade-off is that sessions do not survive a restart.
# Cookies expire after 25200 s (7 hours).
_session_secret = os.getenv("SECRET_KEY") or secrets.token_urlsafe(32)
app.add_middleware(SessionMiddleware, secret_key=_session_secret, max_age=25200)
if __name__ == '__main__':
    # Pass the app object instead of the 'database_server:app' import string.
    # With reload=False and a single worker, uvicorn does not need an import
    # string. Using one re-imports this file as a second module, which builds
    # the FastAPI app and both Gradio demos twice. It also ties startup to the
    # file being named database_server.py.
    uvicorn.run(app, host='0.0.0.0', port=7860, reload=False, workers=1)