Forrest99 committed on
Commit
0a27391
·
verified ·
1 Parent(s): 338753c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -36
app.py CHANGED
@@ -1,31 +1,14 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
  import os
5
  import logging
6
- from pathlib import Path
7
 
8
- # === 初始化日志 ===
9
- logging.basicConfig(level=logging.INFO)
10
- logger = logging.getLogger(__name__)
11
-
12
- # === 检查缓存目录权限 ===
13
- def check_permissions():
14
- cache_path = Path(os.getenv("HF_HOME", ""))
15
- try:
16
- cache_path.mkdir(parents=True, exist_ok=True)
17
- test_file = cache_path / "permission_test.txt"
18
- test_file.write_text("test")
19
- test_file.unlink()
20
- logger.info(f"✅ 缓存目录权限正常: {cache_path}")
21
- except Exception as e:
22
- logger.error(f"❌ 缓存目录权限异常: {str(e)}")
23
- raise RuntimeError(f"Directory permission error: {str(e)}")
24
 
25
- check_permissions()
26
-
27
- # === FastAPI 配置 ===
28
- app = FastAPI()
29
  app.add_middleware(
30
  CORSMiddleware,
31
  allow_origins=["*"],
@@ -33,23 +16,70 @@ app.add_middleware(
33
  allow_headers=["*"],
34
  )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # === 模型加载 ===
37
  try:
38
- logger.info("🔄 加载模型中...")
39
- model = AutoModelForSequenceClassification.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
40
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
41
- logger.info("✅ 模型加载成功")
 
 
 
 
 
 
42
  except Exception as e:
43
- logger.error(f" 模型加载失败: {str(e)}")
44
- raise
45
 
46
- # === API 接口 ===
47
  @app.post("/detect")
48
- async def detect(code: str):
49
- inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
50
- with torch.no_grad():
51
- outputs = model(**inputs)
52
- return {
53
- "label": int(outputs.logits.argmax()),
54
- "score": outputs.logits.softmax(dim=-1).max().item()
55
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
  import os
6
  import logging
 
7
 
8
+ # === 初始化配置 ===
9
+ app = FastAPI(title="Code Security API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # 解决跨域问题
 
 
 
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
 
16
  allow_headers=["*"],
17
  )
18
 
19
+ # === 强制设置缓存路径 ===
20
+ os.environ["HF_HOME"] = "/app/.cache/huggingface"
21
+ cache_path = os.getenv("HF_HOME")
22
+ os.makedirs(cache_path, exist_ok=True)
23
+
24
+ # === 日志配置 ===
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger("CodeBERT-API")
27
+
28
+ # === 根路径路由(必须定义)===
29
+ @app.get("/")
30
+ async def read_root():
31
+ """健康检查端点"""
32
+ return {
33
+ "status": "running",
34
+ "endpoints": {
35
+ "detect": "POST /detect - 代码安全检测",
36
+ "specs": "GET /openapi.json - API文档"
37
+ }
38
+ }
39
+
40
  # === 模型加载 ===
41
  try:
42
+ logger.info("Loading model from: %s", cache_path)
43
+ model = AutoModelForSequenceClassification.from_pretrained(
44
+ "mrm8488/codebert-base-finetuned-detect-insecure-code",
45
+ cache_dir=cache_path
46
+ )
47
+ tokenizer = AutoTokenizer.from_pretrained(
48
+ "mrm8488/codebert-base-finetuned-detect-insecure-code",
49
+ cache_dir=cache_path
50
+ )
51
+ logger.info("Model loaded successfully")
52
  except Exception as e:
53
+ logger.error("Model load failed: %s", str(e))
54
+ raise RuntimeError("模型初始化失败")
55
 
56
+ # === 核心检测接口 ===
57
  @app.post("/detect")
58
+ async def detect_vulnerability(code: str):
59
+ """代码安全检测主接口"""
60
+ try:
61
+ # 输入处理
62
+ code = code[:2000] # 截断超长输入
63
+
64
+ # 模型推理
65
+ inputs = tokenizer(
66
+ code,
67
+ return_tensors="pt",
68
+ truncation=True,
69
+ max_length=512
70
+ )
71
+ with torch.no_grad():
72
+ outputs = model(**inputs)
73
+
74
+ # 结果解析
75
+ label_id = outputs.logits.argmax().item()
76
+ return {
77
+ "label": label_id, # 0:安全 1:不安全
78
+ "confidence": outputs.logits.softmax(dim=-1)[0][label_id].item()
79
+ }
80
+
81
+ except Exception as e:
82
+ return {
83
+ "error": str(e),
84
+ "tip": "请检查输入代码是否包含非ASCII字符"
85
+ }