codebertBase / app.py
Forrest99's picture
Update app.py
338753c verified
raw
history blame
1.78 kB
import logging
import os
from pathlib import Path

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# === Logging setup ===
# Module-level logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# === Check cache directory permissions ===
def check_permissions():
    """Verify that the Hugging Face cache directory exists and is writable.

    The directory is read from ``HF_HOME``. When that variable is unset, the
    Hugging Face default location (``$XDG_CACHE_HOME/huggingface``, falling
    back to ``~/.cache/huggingface``) is probed rather than the current
    working directory. The check creates the directory if needed, then writes
    and deletes a small probe file inside it.

    Raises:
        RuntimeError: If the directory cannot be created or written to. The
            underlying exception is chained as ``__cause__``.
    """
    hf_home = os.getenv("HF_HOME")
    if hf_home is None:
        # Mirror the library's own fallback so the probed directory is the
        # one the model download will actually use.
        xdg_cache = os.getenv("XDG_CACHE_HOME", os.path.join("~", ".cache"))
        hf_home = os.path.join(xdg_cache, "huggingface")
    cache_path = Path(hf_home).expanduser()
    try:
        cache_path.mkdir(parents=True, exist_ok=True)
        test_file = cache_path / "permission_test.txt"
        test_file.write_text("test")
        test_file.unlink()
        logger.info(f"✅ 缓存目录权限正常: {cache_path}")
    except Exception as e:
        logger.error(f"❌ 缓存目录权限异常: {str(e)}")
        raise RuntimeError(f"Directory permission error: {str(e)}") from e
# Run the cache-directory check at import time, before anything else is set up.
check_permissions()

# === FastAPI setup ===
app = FastAPI()
# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# === Model loading ===
# Checkpoint shared by the classifier and its tokenizer.
_MODEL_ID = "mrm8488/codebert-base-finetuned-detect-insecure-code"
try:
    logger.info("🔄 加载模型中...")
    model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    logger.info("✅ 模型加载成功")
except Exception as load_error:
    # Log the failure, then re-raise so the module import aborts.
    logger.error(f"❌ 模型加载失败: {load_error!s}")
    raise
# === API endpoint ===
@app.post("/detect")
async def detect(code: str):
    """Classify a code snippet with the loaded sequence-classification model.

    Args:
        code: Source code to analyse. Only the first 2000 characters are
            passed to the tokenizer, which truncates to at most 512 tokens.

    Returns:
        A dict with ``label`` (index of the highest-scoring class) and
        ``score`` (softmax probability of that class).
    """
    inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(**inputs)
    # One snippet is encoded per request, so row 0 holds its class logits.
    logits = outputs.logits[0]
    label = int(logits.argmax())
    return {
        "label": label,
        "score": logits.softmax(dim=-1)[label].item(),
    }