Forrest99 commited on
Commit
d37c72d
·
verified ·
1 Parent(s): 867bc1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
@@ -55,10 +55,16 @@ except Exception as e:
55
 
56
  # === 核心检测接口 ===
57
  @app.post("/detect")
58
- async def detect_vulnerability(code: str):
59
  """代码安全检测主接口"""
60
  try:
61
- # 输入处理
 
 
 
 
 
 
62
  code = code[:2000] # 截断超长输入
63
 
64
  # 模型推理
@@ -66,20 +72,28 @@ async def detect_vulnerability(code: str):
66
  code,
67
  return_tensors="pt",
68
  truncation=True,
 
69
  max_length=512
70
  )
 
71
  with torch.no_grad():
72
  outputs = model(**inputs)
73
-
74
  # 结果解析
75
- label_id = outputs.logits.argmax().item()
 
 
 
 
 
76
  return {
77
  "label": label_id, # 0:安全 1:不安全
78
- "confidence": outputs.logits.softmax(dim=-1)[0][label_id].item()
79
  }
80
 
81
  except Exception as e:
 
82
  return {
83
  "error": str(e),
84
- "tip": "请检查输入代码是否包含非ASCII字符"
85
- }
 
1
+ from fastapi import FastAPI, Body
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
 
55
 
56
  # === 核心检测接口 ===
57
  @app.post("/detect")
58
+ async def detect_vulnerability(payload: dict = Body(...)):
59
  """代码安全检测主接口"""
60
  try:
61
+ # 获取 JSON 输入数据
62
+ code = payload.get("code", "").strip()
63
+
64
+ if not code:
65
+ return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}
66
+
67
+ # 限制代码长度
68
  code = code[:2000] # 截断超长输入
69
 
70
  # 模型推理
 
72
  code,
73
  return_tensors="pt",
74
  truncation=True,
75
+ padding=True, # 自动选择填充策略
76
  max_length=512
77
  )
78
+
79
  with torch.no_grad():
80
  outputs = model(**inputs)
81
+
82
  # 结果解析
83
+ logits = outputs.logits
84
+ label_id = logits.argmax().item()
85
+ confidence = logits.softmax(dim=-1)[0][label_id].item()
86
+
87
+ logger.info(f"Code analyzed. Logits: {logits.tolist()}, Prediction: {label_id}, Confidence: {confidence:.4f}")
88
+
89
  return {
90
  "label": label_id, # 0:安全 1:不安全
91
+ "confidence": round(confidence, 4)
92
  }
93
 
94
  except Exception as e:
95
+ logger.error("Error during model inference: %s", str(e))
96
  return {
97
  "error": str(e),
98
+ "tip": "请检查输入代码是否包含非ASCII字符或格式错误"
99
+ }