Forrest99 commited on
Commit
e28e6dd
·
verified ·
1 Parent(s): b6e0b3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -25
app.py CHANGED
@@ -1,19 +1,18 @@
1
  from fastapi import FastAPI
2
- from fastapi.middleware.cors import CORSMiddleware # 新增 CORS 支持
3
- import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
  import os
7
 
8
- # === FastAPI 初始化 ===
9
  app = FastAPI()
10
 
11
- # 添加 CORS 中间件(关键步骤)
12
  app.add_middleware(
13
  CORSMiddleware,
14
- allow_origins=["*"], # 允许所有来源
15
- allow_methods=["*"], # 允许所有 HTTP 方法
16
- allow_headers=["*"], # 允许所有请求头
17
  )
18
 
19
  # === 模型加载 ===
@@ -23,30 +22,15 @@ tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detec
23
 
24
  # === HTTP API 接口 ===
25
  @app.post("/detect")
26
- async def api_detect(code: str):
27
- """HTTP API 接口"""
28
  try:
29
  inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
  label_id = outputs.logits.argmax().item()
33
  return {
34
- "label": int(label_id), # 强制返回 0/1 数字
35
  "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
36
  }
37
  except Exception as e:
38
- return {"error": str(e)}
39
-
40
- # === Gradio 界面(可选)===
41
- def gradio_predict(code: str):
42
- result = api_detect(code)
43
- return f"Prediction: {result['label']} (Confidence: {result['score']:.2f})"
44
-
45
- gr_interface = gr.Interface(
46
- fn=gradio_predict,
47
- inputs=gr.Textbox(lines=10, placeholder="Paste code here..."),
48
- outputs="text",
49
- title="Code Security Detector"
50
- )
51
-
52
- app = gr.mount_gradio_app(app, gr_interface, path="/")
 
1
  from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
  import os
6
 
7
+ # === FastAPI 配置 ===
8
  app = FastAPI()
9
 
10
+ # 解决 CSP 限制的关键配置
11
  app.add_middleware(
12
  CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
  )
17
 
18
  # === 模型加载 ===
 
22
 
23
  # === HTTP API 接口 ===
24
  @app.post("/detect")
25
+ async def detect(code: str):
 
26
  try:
27
  inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
  label_id = outputs.logits.argmax().item()
31
  return {
32
+ "label": int(label_id), # 严格返回 0/1
33
  "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
34
  }
35
  except Exception as e:
36
+ return {"error": str(e)}