import logging
import os
import threading

import torch
from fastapi import Body, FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# === Application setup ===
app = FastAPI(title="Code Security API")

# CORS: accept cross-origin requests from any origin, using any HTTP method
# and any request header (no credentials option is configured).
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# === Hugging Face cache location (forced) ===
# Always point HF_HOME at /app/.cache/huggingface, overwriting any value the
# environment already had, and make sure that directory exists.
# NOTE(review): transformers is imported above this point, so libraries that
# read HF_HOME at import time may not see this value; the model loader passes
# cache_dir explicitly, which does not depend on that.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
cache_path = os.environ.get("HF_HOME")
os.makedirs(cache_path, exist_ok=True)

# === Logging ===
logger = logging.getLogger("CodeBERT-API")
logging.basicConfig(level=logging.INFO)

# === Root route (required) ===
@app.get("/")
async def read_root():
    """Health check: report that the service is up and list its endpoints."""
    endpoint_index = {
        "detect": "POST /detect - 代码安全检测",
        "specs": "GET /openapi.json - API文档",
    }
    return {"status": "running", "endpoints": endpoint_index}

# === Model loading ===
# Hugging Face Hub id of the CodeBERT checkpoint used for insecure-code detection.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    # from_pretrained already returns the model in eval mode; made explicit
    # because inference relies on dropout being disabled.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback; chaining with `from e` preserves
    # the root cause on the RuntimeError that aborts startup.
    logger.exception("Model load failed: %s", str(e))
    raise RuntimeError("模型初始化失败") from e

# === Core detection endpoint ===

# Serializes tokenizer/model use. Inference runs in worker threads (see
# detect_vulnerability), and a shared Hugging Face fast tokenizer is not
# guaranteed to be safe to call from several threads at once. This keeps the
# one-at-a-time inference the service had when it ran on the event loop.
_inference_lock = threading.Lock()


def _classify(code: str) -> tuple:
    """Classify one code snippet with the loaded model (blocking).

    Run this off the event loop; it tokenizes and runs a full forward pass.

    Returns:
        ``(label_id, confidence)``: the index of the highest-scoring class
        and its softmax probability.
    """
    with _inference_lock:
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,  # token cap; longer input is truncated
        )
        # Grad mode is thread-local, so no_grad must be entered here, in the
        # worker thread that actually runs the forward pass.
        with torch.no_grad():
            logits = model(**inputs).logits

    # One input -> one row of class scores; pick from that row instead of
    # taking the argmax of the flattened tensor.
    scores = logits[0]
    label_id = int(scores.argmax().item())
    confidence = scores.softmax(dim=-1)[label_id].item()

    logger.info(
        "Code analyzed. Logits: %s, Prediction: %s, Confidence: %.4f",
        logits.tolist(),
        label_id,
        confidence,
    )
    return label_id, confidence


@app.post("/detect")
async def detect_vulnerability(payload: dict = Body(...)):
    """Classify a submitted code snippet as secure or insecure.

    Expects a JSON object with a ``code`` string. A missing or null ``code``
    is treated as empty.

    Returns:
        ``{"label": int, "confidence": float}`` on success (label 0: secure,
        1: insecure). On invalid input or an inference failure it returns an
        ``{"error": ..., "tip": ...}`` body, still with HTTP 200, as before.
    """
    raw_code = payload.get("code")
    if raw_code is None:
        raw_code = ""
    if not isinstance(raw_code, str):
        # Previously surfaced as an AttributeError from .strip() with a
        # misleading "non-ASCII" tip.
        return {"error": "code 字段必须是字符串", "tip": "请提供有效的代码字符串"}

    code = raw_code.strip()
    if not code:
        return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}

    # Bound the raw text before tokenizing so oversized payloads don't cost
    # a full tokenization pass; the tokenizer truncates to 512 tokens anyway.
    code = code[:2000]

    try:
        # Run the blocking model call in a worker thread so the event loop
        # keeps serving other requests (e.g. GET /) meanwhile.
        label_id, confidence = await run_in_threadpool(_classify, code)
    except Exception as e:
        logger.exception("Error during model inference: %s", str(e))
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符或格式错误",
        }

    return {
        "label": label_id,  # 0: secure, 1: insecure
        "confidence": round(confidence, 4),
    }