import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel

from RequestModel import PredictRequest

# Module-level state used by ``lifespan`` to run startup initialization once.
# Set to True only after initialize_application() has completed without raising.
is_initialized: bool = False
# Held while checking/setting ``is_initialized`` so the check-and-initialize
# step cannot interleave with another coroutine doing the same.
initialization_lock: asyncio.Lock = asyncio.Lock()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler passed to ``FastAPI(lifespan=...)``.

    Everything before ``yield`` runs at startup and everything after it at
    shutdown. At startup, ``initialize_application()`` is awaited while holding
    ``initialization_lock``, unless ``is_initialized`` is already True. The
    flag is set only after initialization finishes without raising.

    Args:
        app: The FastAPI application (not used by this handler).
    """
    global is_initialized
    async with initialization_lock:
        already_done = is_initialized
        if not already_done:
            await initialize_application()
            is_initialized = True
    # Startup complete; the app serves requests while suspended here.
    yield
    # Shutdown: no cleanup is performed at present.

async def initialize_application():
    """Perform one-time startup initialization.

    Currently this only awaits ``us_stock.fetch_symbols()``. Further startup
    steps can be added after it.
    """
    # Imported inside the function, so ``us_stock`` is loaded only when
    # startup initialization actually runs.
    import us_stock

    await us_stock.fetch_symbols()

# ASGI application; ``lifespan`` runs the one-time startup initialization.
app = FastAPI(lifespan=lifespan)

# CORS middleware. (Only CORS is configured here; there is no rate limiting
# set up in this block.)
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# site may issue credentialed cross-origin requests (Starlette reflects the
# request Origin in that case — verify against the installed version). Confirm
# this is intended, especially before adding cookie/session authentication.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Trusted-host middleware.
# NOTE(review): allowed_hosts=["*"] accepts every Host header, so this is
# effectively a no-op; list the real hostnames to enable Host-header checks.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["*"]
)

# Request model shared by the text-suffix endpoints.
class TextRequest(BaseModel):
    """Request payload carrying a single ``text`` string."""

    # Input string that the handlers append a suffix to.
    text: str

# POST /api/aaa: append the suffix 'aaa' to the submitted text.
@app.post("/api/aaa")
async def api_aaa_post(request: TextRequest):
    """Return ``{"result": request.text + "aaa"}``."""
    suffixed = f"{request.text}aaa"
    return dict(result=suffixed)

# POST /aaa: same transformation as POST /api/aaa, without the "/api" prefix.
@app.post("/aaa")
async def aaa(request: TextRequest):
    """Append ``'aaa'`` to ``request.text`` and wrap it as ``{"result": ...}``."""
    text_with_suffix = "".join((request.text, "aaa"))
    return {"result": text_with_suffix}


# GET /aaa: same transformation as POST /aaa, for clients that use GET.
# The input used to be accepted only as a JSON request body. GET bodies have
# no defined semantics and browser ``fetch`` refuses to send one, so the text
# may now also be passed as a ``?text=...`` query parameter.
@app.get("/aaa")
async def api_aaa_get(
    request: Optional[TextRequest] = None,
    text: Optional[str] = None,
):
    """Return ``{"result": <text> + "aaa"}`` for a GET request.

    Args:
        request: Optional JSON body. Still accepted for backward compatibility,
            and takes precedence when present.
        text: Optional ``text`` query parameter, used when no body is sent.

    Raises:
        HTTPException: 422 when neither a body nor ``text`` is supplied.
    """
    if request is not None:
        value = request.text
    elif text is not None:
        value = text
    else:
        raise HTTPException(
            status_code=422,
            detail="Provide 'text' as a query parameter or in the JSON body.",
        )
    return {"result": value + 'aaa'}

# POST /api/bbb: append the suffix 'bbb' to the submitted text.
@app.post("/api/bbb")
async def api_bbb(request: TextRequest):
    """Return ``{"result": request.text + "bbb"}``."""
    suffix = "bbb"
    combined = f"{request.text}{suffix}"
    return dict(result=combined)

# POST /api/predict: run the model prediction off the event loop.
@app.post("/api/predict")
async def predict(request: PredictRequest):
    """Run ``blkeras.predict(request.text, request.stock_codes)``.

    The call is made through ``asyncio.to_thread``, so it executes in a worker
    thread rather than on the event loop.

    Returns:
        Whatever ``blkeras.predict`` returns, or ``[]`` if it raises. The
        failure is logged with its traceback instead of being silently
        discarded.

    Raises:
        ImportError: propagated if ``blkeras`` cannot be imported. This is
            unchanged: the import is deliberately outside the ``try``.
    """
    # Alias the import so it does not shadow this route handler's own name.
    # NOTE(review): the first request pays the import cost of ``blkeras`` on
    # the event-loop thread; consider importing it during startup instead.
    from blkeras import predict as run_prediction

    try:
        result = await asyncio.to_thread(
            run_prediction, request.text, request.stock_codes
        )
    except Exception:
        # Best-effort endpoint: clients still receive an empty list, but the
        # failure is now recorded so it can be diagnosed.
        logging.getLogger(__name__).exception(
            "Prediction failed; returning empty result"
        )
        return []
    return result

@app.get("/")
async def root():
    """Landing endpoint that returns a short usage hint."""
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return dict(message=welcome)