dzenzzz commited on
Commit
6152e61
·
1 Parent(s): df02cd1

adds middlewares

Browse files
Files changed (2) hide show
  1. app.py +34 -3
  2. config.py +16 -12
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import asyncio
2
  import time
3
- from fastapi import FastAPI, Request
4
  from fastapi.responses import JSONResponse
5
  from starlette.status import HTTP_504_GATEWAY_TIMEOUT
6
  from neural_searcher import NeuralSearcher
 
7
  from huggingface_hub import login
8
- from config import HUGGING_FACE_API_KEY,COLLECTION_NAME
9
- import os
10
 
11
  login(HUGGING_FACE_API_KEY)
12
 
@@ -16,12 +16,43 @@ neural_searcher = NeuralSearcher(collection_name=COLLECTION_NAME)
16
 
17
  REQUEST_TIMEOUT_ERROR = 30
18
 
 
 
 
 
19
  @app.get("/api/search")
20
  async def search(q: str):
21
  data = await neural_searcher.search(text=q)
22
  return data
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @app.middleware("http")
26
  async def timeout_middleware(request: Request, call_next):
27
  try:
 
1
  import asyncio
2
  import time
3
+ from fastapi import FastAPI, Request, HTTPException
4
  from fastapi.responses import JSONResponse
5
  from starlette.status import HTTP_504_GATEWAY_TIMEOUT
6
  from neural_searcher import NeuralSearcher
7
+ from fastapi.middleware.cors import CORSMiddleware
8
  from huggingface_hub import login
9
+ from config import HUGGING_FACE_API_KEY,COLLECTION_NAME, ALLOWED_ORIGINS, API_KEY
 
10
 
11
  login(HUGGING_FACE_API_KEY)
12
 
 
16
 
17
  REQUEST_TIMEOUT_ERROR = 30
18
 
19
+ ALLOWED_ORIGINS = [ALLOWED_ORIGINS]
20
+ ALLOWED_API_KEY = API_KEY
21
+
22
+
23
  @app.get("/api/search")
24
  async def search(q: str):
25
  data = await neural_searcher.search(text=q)
26
  return data
27
 
28
 
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=ALLOWED_ORIGINS,
32
+ allow_credentials=True,
33
+ allow_methods=["GET"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ @app.middleware("http")
38
+ async def security_middleware(request: Request, call_next):
39
+ referer = request.headers.get("referer", "")
40
+ origin = request.headers.get("origin", "")
41
+ user_agent = request.headers.get("user-agent", "")
42
+ api_key = request.headers.get("X-API-KEY", "")
43
+
44
+
45
+ if not (referer.startswith(ALLOWED_ORIGINS[0]) or origin.startswith(ALLOWED_ORIGINS[0])):
46
+ raise HTTPException(status_code=403, detail="Access denied: Invalid source")
47
+
48
+ if not user_agent or "Mozilla" not in user_agent:
49
+ raise HTTPException(status_code=403, detail="Access denied: Suspicious client")
50
+
51
+ if api_key != ALLOWED_API_KEY:
52
+ raise HTTPException(status_code=403, detail="Access denied: Invalid API Key")
53
+
54
+ return await call_next(request)
55
+
56
  @app.middleware("http")
57
  async def timeout_middleware(request: Request, call_next):
58
  try:
config.py CHANGED
@@ -1,18 +1,22 @@
1
  import os
2
 
3
- from dotenv import find_dotenv, load_dotenv
4
 
5
- dotenv_path = find_dotenv()
6
- load_dotenv(dotenv_path)
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- API_KEY = os.getenv("API_KEY")
9
- HOST = os.getenv("HOST")
10
 
11
- COLLECTION_NAME = os.getenv("COLLECTION_NAME")
12
- TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME")
13
 
14
- DENSE_MODEL = os.getenv("DENSE_MODEL")
15
- SPARSE_MODEL = os.getenv("SPARSE_MODEL")
16
-
17
- DENSE_MODEL_SHORT = os.getenv("DENSE_MODEL_SHORT")
18
- SPARSE_MODEL_SHORT = os.getenv("SPARSE_MODEL_SHORT")
 
1
  import os
2
 
 
3
 
4
+ HUGGING_FACE_API_KEY = os.getenv('HUGGING_FACE_API_KEY')
5
+
6
+ COLLECTION_NAME = os.getenv('COLLECTION_NAME')
7
+
8
+ QDRANT_URL = os.getenv('QDRANT_URL')
9
+
10
+ QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
11
+
12
+ DENSE_MODEL = os.getenv('DENSE_MODEL')
13
+
14
+ SPARSE_MODEL = os.getenv('SPARSE_MODEL')
15
+
16
+ LATE_INTERACTION_MODEL = os.getenv('LATE_INTERACTION_MODEL')
17
 
18
+ NER_MODEL = os.getenv('NER_MODEL')
 
19
 
20
+ ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS')
 
21
 
22
+ API_KEY = os.getenv('API_KEY')