Spaces:
Running
Running
Commit
·
50553ea
0
Parent(s):
init
Browse files- .gitignore +11 -0
- Dockerfile +13 -0
- README.md +9 -0
- main.py +3 -0
- requirements.txt +36 -0
- trauma/__init__.py +55 -0
- trauma/api/account/__init__.py +7 -0
- trauma/api/account/dto.py +6 -0
- trauma/api/account/model.py +32 -0
- trauma/api/account/schemas.py +6 -0
- trauma/api/account/views.py +13 -0
- trauma/api/chat/__init__.py +7 -0
- trauma/api/chat/db_requests.py +63 -0
- trauma/api/chat/dto.py +8 -0
- trauma/api/chat/model.py +15 -0
- trauma/api/chat/schemas.py +27 -0
- trauma/api/chat/views.py +60 -0
- trauma/api/common/db_requests.py +21 -0
- trauma/api/common/dto.py +9 -0
- trauma/api/message/__init__.py +7 -0
- trauma/api/message/ai/openai_request.py +47 -0
- trauma/api/message/ai/prompts.py +216 -0
- trauma/api/message/ai/utils.py +30 -0
- trauma/api/message/db_requests.py +45 -0
- trauma/api/message/dto.py +13 -0
- trauma/api/message/model.py +15 -0
- trauma/api/message/schemas.py +23 -0
- trauma/api/message/views.py +34 -0
- trauma/api/security/__init__.py +7 -0
- trauma/api/security/db_requests.py +34 -0
- trauma/api/security/schemas.py +28 -0
- trauma/api/security/views.py +27 -0
- trauma/core/config.py +40 -0
- trauma/core/database.py +101 -0
- trauma/core/security.py +72 -0
- trauma/core/wrappers.py +44 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
env/
|
3 |
+
venv/
|
4 |
+
.venv/
|
5 |
+
.idea/
|
6 |
+
*.log
|
7 |
+
*.egg-info/
|
8 |
+
pip-wheel-metadata/
|
9 |
+
.env
|
10 |
+
.DS_Store
|
11 |
+
static/
|
Dockerfile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.12.7
|
2 |
+
|
3 |
+
RUN useradd -m -u 1000 user
|
4 |
+
USER user
|
5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
6 |
+
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
11 |
+
|
12 |
+
COPY --chown=user . /app
|
13 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: TraumaBackend
|
3 |
+
emoji: 🦀
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: pink
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
---
|
9 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
main.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from trauma import create_app
|
2 |
+
|
3 |
+
app = create_app()
|
requirements.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
annotated-types==0.7.0
|
2 |
+
anyio==4.6.2.post1
|
3 |
+
bcrypt==4.2.1
|
4 |
+
certifi==2024.8.30
|
5 |
+
click==8.1.7
|
6 |
+
distro==1.9.0
|
7 |
+
dnspython==2.7.0
|
8 |
+
ecdsa==0.19.0
|
9 |
+
email_validator==2.2.0
|
10 |
+
fastapi==0.115.5
|
11 |
+
h11==0.14.0
|
12 |
+
httpcore==1.0.7
|
13 |
+
httptools==0.6.4
|
14 |
+
httpx==0.27.2
|
15 |
+
idna==3.10
|
16 |
+
jiter==0.7.1
|
17 |
+
motor==3.6.0
|
18 |
+
openai==1.54.4
|
19 |
+
passlib==1.7.4
|
20 |
+
pyasn1==0.6.1
|
21 |
+
pydantic==2.10.2
|
22 |
+
pydantic_core==2.27.1
|
23 |
+
pymongo==4.9.2
|
24 |
+
python-dotenv==1.0.1
|
25 |
+
python-jose==3.3.0
|
26 |
+
PyYAML==6.0.2
|
27 |
+
rsa==4.9
|
28 |
+
six==1.16.0
|
29 |
+
sniffio==1.3.1
|
30 |
+
starlette==0.41.3
|
31 |
+
tqdm==4.67.0
|
32 |
+
typing_extensions==4.12.2
|
33 |
+
uvicorn==0.32.1
|
34 |
+
uvloop==0.21.0
|
35 |
+
watchfiles==1.0.0
|
36 |
+
websockets==14.1
|
trauma/__init__.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from fastapi import FastAPI
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
6 |
+
from starlette.staticfiles import StaticFiles
|
7 |
+
|
8 |
+
from trauma.core.config import settings
|
9 |
+
from trauma.core.wrappers import TraumaResponseWrapper, ErrorTraumaResponse
|
10 |
+
|
11 |
+
|
12 |
+
def create_app() -> FastAPI:
|
13 |
+
app = FastAPI()
|
14 |
+
|
15 |
+
from trauma.api.account import account_router
|
16 |
+
app.include_router(account_router, tags=['account'])
|
17 |
+
|
18 |
+
from trauma.api.chat import chat_router
|
19 |
+
app.include_router(chat_router, tags=['chat'])
|
20 |
+
|
21 |
+
from trauma.api.message import message_router
|
22 |
+
app.include_router(message_router, tags=['message'])
|
23 |
+
|
24 |
+
from trauma.api.security import security_router
|
25 |
+
app.include_router(security_router, tags=['security'])
|
26 |
+
|
27 |
+
app.add_middleware(
|
28 |
+
CORSMiddleware,
|
29 |
+
allow_origins=["*"],
|
30 |
+
allow_methods=["*"],
|
31 |
+
allow_headers=["*"],
|
32 |
+
)
|
33 |
+
|
34 |
+
static_directory = os.path.join(settings.BASE_DIR, 'static')
|
35 |
+
if not os.path.exists(static_directory):
|
36 |
+
os.makedirs(static_directory)
|
37 |
+
|
38 |
+
app.mount(
|
39 |
+
'/static',
|
40 |
+
StaticFiles(directory='static'),
|
41 |
+
)
|
42 |
+
|
43 |
+
@app.exception_handler(StarletteHTTPException)
|
44 |
+
async def http_exception_handler(_, exc):
|
45 |
+
return TraumaResponseWrapper(
|
46 |
+
data=None,
|
47 |
+
successful=False,
|
48 |
+
error=ErrorTraumaResponse(message=str(exc.detail))
|
49 |
+
).response(exc.status_code)
|
50 |
+
|
51 |
+
@app.get("/")
|
52 |
+
async def read_root():
|
53 |
+
return {"message": "Hello world!"}
|
54 |
+
|
55 |
+
return app
|
trauma/api/account/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter
|
2 |
+
|
3 |
+
account_router = APIRouter(
|
4 |
+
prefix='/api/account'
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
trauma/api/account/dto.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class AccessToken(BaseModel):
|
5 |
+
type: str = "Bearer"
|
6 |
+
value: str
|
trauma/api/account/model.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from pydantic import field_validator, Field, EmailStr
|
4 |
+
from passlib.context import CryptContext
|
5 |
+
|
6 |
+
from trauma.core.database import MongoBaseModel
|
7 |
+
|
8 |
+
|
9 |
+
class AccountModel(MongoBaseModel):
|
10 |
+
email: EmailStr
|
11 |
+
password: str | None = Field(exclude=True, default=None)
|
12 |
+
|
13 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
14 |
+
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
15 |
+
|
16 |
+
@field_validator("datetimeUpdated", mode="before", check_fields=False)
|
17 |
+
@classmethod
|
18 |
+
def validate_datetimeUpdated(cls, v):
|
19 |
+
return v or datetime.now()
|
20 |
+
|
21 |
+
@field_validator('password', mode='before', check_fields=False)
|
22 |
+
@classmethod
|
23 |
+
def set_password_hash(cls, v):
|
24 |
+
if not v.startswith("$2b$"):
|
25 |
+
return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v)
|
26 |
+
return v
|
27 |
+
|
28 |
+
class Config:
|
29 |
+
arbitrary_types_allowed = True
|
30 |
+
json_encoders = {
|
31 |
+
datetime: lambda dt: dt.isoformat()
|
32 |
+
}
|
trauma/api/account/schemas.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trauma.api.account.model import AccountModel
|
2 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
3 |
+
|
4 |
+
|
5 |
+
class AccountWrapper(TraumaResponseWrapper[AccountModel]):
|
6 |
+
pass
|
trauma/api/account/views.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import Depends
|
2 |
+
|
3 |
+
from trauma.api.account import account_router
|
4 |
+
from trauma.api.account.model import AccountModel
|
5 |
+
from trauma.api.account.schemas import AccountWrapper
|
6 |
+
from trauma.core.security import PermissionDependency
|
7 |
+
|
8 |
+
|
9 |
+
@account_router.get('')
|
10 |
+
async def get_account(
|
11 |
+
account: AccountModel = Depends(PermissionDependency())
|
12 |
+
) -> AccountWrapper:
|
13 |
+
return AccountWrapper(data=account)
|
trauma/api/chat/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.routing import APIRouter
|
2 |
+
|
3 |
+
chat_router = APIRouter(
|
4 |
+
prefix="/api/chat", tags=["chat"]
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
trauma/api/chat/db_requests.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from fastapi import HTTPException
|
4 |
+
|
5 |
+
from trauma.api.account.model import AccountModel
|
6 |
+
from trauma.api.chat.model import ChatModel
|
7 |
+
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
|
8 |
+
from trauma.core.config import settings
|
9 |
+
|
10 |
+
|
11 |
+
async def get_chat_obj(chat_id: str, account: AccountModel | None) -> ChatModel:
|
12 |
+
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
13 |
+
if not chat:
|
14 |
+
raise HTTPException(status_code=404, detail="Chat not found")
|
15 |
+
chat = ChatModel.from_mongo(chat)
|
16 |
+
if account and chat.account != account:
|
17 |
+
raise HTTPException(status_code=403, detail="Chat account not match")
|
18 |
+
return chat
|
19 |
+
|
20 |
+
|
21 |
+
async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel | None) -> ChatModel:
|
22 |
+
chat = ChatModel(model=chat_request.model, account=account)
|
23 |
+
await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
|
24 |
+
return chat
|
25 |
+
|
26 |
+
|
27 |
+
async def delete_chat_obj(chat_id: str, account: AccountModel | None) -> None:
|
28 |
+
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
29 |
+
if not chat:
|
30 |
+
raise HTTPException(status_code=404, detail="Chat not found")
|
31 |
+
chat = ChatModel.from_mongo(chat)
|
32 |
+
if account and chat.account != account:
|
33 |
+
raise HTTPException(status_code=403, detail="Chat account not match")
|
34 |
+
await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
|
35 |
+
|
36 |
+
|
37 |
+
async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest, account: AccountModel | None) -> ChatModel:
|
38 |
+
chat = await settings.DB_CLIENT.chats.find_one({"id": chatId})
|
39 |
+
if not chat:
|
40 |
+
raise HTTPException(status_code=404, detail="Chat not found")
|
41 |
+
|
42 |
+
chat = ChatModel.from_mongo(chat)
|
43 |
+
if account and chat.account != account:
|
44 |
+
raise HTTPException(status_code=403, detail="Chat account not match")
|
45 |
+
|
46 |
+
chat.title = chat_request.title
|
47 |
+
await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
|
48 |
+
return chat
|
49 |
+
|
50 |
+
|
51 |
+
async def get_all_chats_obj(page_size: int, page_index: int, account: AccountModel) -> tuple[list[ChatModel], int]:
|
52 |
+
query = {"account.id": account.id}
|
53 |
+
skip = page_size * page_index
|
54 |
+
objects, total_count = await asyncio.gather(
|
55 |
+
settings.DB_CLIENT.chats
|
56 |
+
.find(query)
|
57 |
+
.sort("_id", -1)
|
58 |
+
.skip(skip)
|
59 |
+
.limit(page_size)
|
60 |
+
.to_list(length=page_size),
|
61 |
+
settings.DB_CLIENT.chats.count_documents(query),
|
62 |
+
)
|
63 |
+
return objects, total_count
|
trauma/api/chat/dto.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class ModelType(Enum):
|
5 |
+
gpt_4o = "gpt-4o"
|
6 |
+
gpt_4o_mini = "gpt-4o-mini"
|
7 |
+
o1_mini = "o1-mini"
|
8 |
+
o1_preview = "o1-preview"
|
trauma/api/chat/model.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from pydantic import Field
|
4 |
+
|
5 |
+
from trauma.api.account.model import AccountModel
|
6 |
+
from trauma.api.chat.dto import ModelType
|
7 |
+
from trauma.core.database import MongoBaseModel
|
8 |
+
|
9 |
+
|
10 |
+
class ChatModel(MongoBaseModel):
|
11 |
+
title: str = 'New Chat'
|
12 |
+
model: ModelType
|
13 |
+
account: AccountModel | None = None
|
14 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
15 |
+
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
trauma/api/chat/schemas.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
from trauma.api.chat.dto import ModelType
|
4 |
+
from trauma.api.chat.model import ChatModel
|
5 |
+
from trauma.api.common.dto import Paging
|
6 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
7 |
+
|
8 |
+
|
9 |
+
class CreateChatRequest(BaseModel):
|
10 |
+
model: ModelType = ModelType.gpt_4o_mini
|
11 |
+
|
12 |
+
|
13 |
+
class ChatWrapper(TraumaResponseWrapper[ChatModel]):
|
14 |
+
pass
|
15 |
+
|
16 |
+
|
17 |
+
class AllChatResponse(BaseModel):
|
18 |
+
paging: Paging
|
19 |
+
data: list[ChatModel]
|
20 |
+
|
21 |
+
|
22 |
+
class AllChatWrapper(TraumaResponseWrapper[AllChatResponse]):
|
23 |
+
pass
|
24 |
+
|
25 |
+
|
26 |
+
class ChatTitleRequest(BaseModel):
|
27 |
+
title: str
|
trauma/api/chat/views.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from fastapi import Query
|
4 |
+
from fastapi.params import Depends
|
5 |
+
|
6 |
+
from trauma.api.account.model import AccountModel
|
7 |
+
from trauma.api.chat import chat_router
|
8 |
+
from trauma.api.chat.db_requests import (get_chat_obj,
|
9 |
+
create_chat_obj,
|
10 |
+
delete_chat_obj,
|
11 |
+
update_chat_obj_title,
|
12 |
+
get_all_chats_obj)
|
13 |
+
from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
|
14 |
+
from trauma.api.common.dto import Paging
|
15 |
+
from trauma.core.security import PermissionDependency
|
16 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
17 |
+
|
18 |
+
|
19 |
+
@chat_router.get('/all')
|
20 |
+
async def get_all_chats(
|
21 |
+
pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
|
22 |
+
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
23 |
+
account: AccountModel = Depends(PermissionDependency())
|
24 |
+
) -> AllChatWrapper:
|
25 |
+
chats, total_count = await get_all_chats_obj(pageSize, pageIndex, account)
|
26 |
+
response = AllChatResponse(
|
27 |
+
paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
|
28 |
+
data=chats
|
29 |
+
)
|
30 |
+
return AllChatWrapper(data=response)
|
31 |
+
|
32 |
+
|
33 |
+
@chat_router.get('/{chatId}')
|
34 |
+
async def get_chat(
|
35 |
+
chatId: str, account: AccountModel = Depends(PermissionDependency(is_public=True))
|
36 |
+
) -> ChatWrapper:
|
37 |
+
chat = await get_chat_obj(chatId, account)
|
38 |
+
return ChatWrapper(data=chat)
|
39 |
+
|
40 |
+
|
41 |
+
@chat_router.post('')
|
42 |
+
async def create_chat(
|
43 |
+
chat_data: CreateChatRequest, account: AccountModel = Depends(PermissionDependency(is_public=True))
|
44 |
+
) -> ChatWrapper:
|
45 |
+
chat = await create_chat_obj(chat_data, account)
|
46 |
+
return ChatWrapper(data=chat)
|
47 |
+
|
48 |
+
|
49 |
+
@chat_router.delete('/{chatId}')
|
50 |
+
async def delete_chat(chatId: str, account: AccountModel = Depends(PermissionDependency())) -> TraumaResponseWrapper:
|
51 |
+
await delete_chat_obj(chatId, account)
|
52 |
+
return TraumaResponseWrapper()
|
53 |
+
|
54 |
+
|
55 |
+
@chat_router.patch('/{chatId}/title')
|
56 |
+
async def update_chat_title(
|
57 |
+
chatId: str, chat: ChatTitleRequest, account: AccountModel = Depends(PermissionDependency())
|
58 |
+
) -> ChatWrapper:
|
59 |
+
chat = await update_chat_obj_title(chatId, chat, account)
|
60 |
+
return ChatWrapper(data=chat)
|
trauma/api/common/db_requests.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
|
3 |
+
from fastapi import HTTPException
|
4 |
+
from pydantic import EmailStr
|
5 |
+
|
6 |
+
from trauma.core.config import settings
|
7 |
+
|
8 |
+
|
9 |
+
async def check_unique_fields_existence(model: Literal['accounts', "strategies"],
|
10 |
+
name: str,
|
11 |
+
new_value: EmailStr | str,
|
12 |
+
current_value: str | None = None) -> None:
|
13 |
+
capitalized_name = model[:-1].capitalize()
|
14 |
+
if new_value == current_value or not new_value:
|
15 |
+
return
|
16 |
+
account = await settings.DB_CLIENT[model].find_one(
|
17 |
+
{name: str(new_value)},
|
18 |
+
collation={"locale": "en", "strength": 2}
|
19 |
+
)
|
20 |
+
if account:
|
21 |
+
raise HTTPException(status_code=400, detail=f'{capitalized_name} with specified {name} already exists.')
|
trauma/api/common/dto.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class Paging(BaseModel):
|
7 |
+
pageSize: int
|
8 |
+
pageIndex: int
|
9 |
+
totalCount: int
|
trauma/api/message/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.routing import APIRouter
|
2 |
+
|
3 |
+
message_router = APIRouter(
|
4 |
+
prefix="/api/message", tags=["message"]
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
trauma/api/message/ai/openai_request.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
|
4 |
+
from trauma.api.chat.model import ChatModel
|
5 |
+
from trauma.api.message.ai.prompts import Prompts
|
6 |
+
from trauma.api.message.dto import Author
|
7 |
+
from trauma.api.message.model import MessageModel
|
8 |
+
from trauma.core.config import settings
|
9 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
10 |
+
|
11 |
+
|
12 |
+
async def prepare_content(user_message: MessageModel) -> list | str:
|
13 |
+
if user_message.fileUrl is None:
|
14 |
+
return user_message.text
|
15 |
+
else:
|
16 |
+
path = str(settings.BASE_DIR) + user_message.fileUrl.replace(settings.Issuer, '')
|
17 |
+
file = await settings.OPENAI_CLIENT.files.create(
|
18 |
+
file=open(path, 'rb'),
|
19 |
+
purpose='vision'
|
20 |
+
)
|
21 |
+
return [{"type": "image_file", "image_file": {"file_id": file.id, "detail": "low"}}]
|
22 |
+
|
23 |
+
|
24 |
+
async def response_generator(chat: ChatModel, user_message: MessageModel):
|
25 |
+
content = await prepare_content(user_message)
|
26 |
+
await settings.OPENAI_CLIENT.beta.threads.messages.create(
|
27 |
+
thread_id=chat.threadId,
|
28 |
+
role=Author.User.value,
|
29 |
+
content=content
|
30 |
+
)
|
31 |
+
|
32 |
+
full_response = ''
|
33 |
+
|
34 |
+
async with settings.OPENAI_CLIENT.beta.threads.runs.create_and_stream(
|
35 |
+
thread_id=chat.threadId,
|
36 |
+
assistant_id=settings.ASSISTANT_ID,
|
37 |
+
instructions=Prompts.generate_response if not user_message.fileUrl else Prompts.generate_response_image,
|
38 |
+
model=chat.model.value
|
39 |
+
) as stream:
|
40 |
+
async for chunk in stream.text_deltas:
|
41 |
+
if chunk:
|
42 |
+
full_response += chunk
|
43 |
+
mini_data = {"text": chunk}
|
44 |
+
yield f"data: {TraumaResponseWrapper(data=mini_data).model_dump_json()}\n\n"
|
45 |
+
|
46 |
+
message_obj = MessageModel(chatId=chat.id, author=Author.Assistant, text=full_response)
|
47 |
+
await settings.DB_CLIENT.messages.insert_one(message_obj.to_mongo())
|
trauma/api/message/ai/prompts.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Prompts:
|
2 |
+
generate_response = """## Objective
|
3 |
+
|
4 |
+
You are Hector, the Mitutoyo virtual assistant, engineered for precision in the measurement industry. Your core function is two-fold: first, to match users with ideal products based on their specific requirements, and second, to intelligently recommend compatible accessories and complementary products for upselling opportunities.
|
5 |
+
|
6 |
+
## Context
|
7 |
+
|
8 |
+
Mitutoyo's B2B platform encompasses precision measurement tools and accessories. Your knowledge base consists of structured JSON product data specifically focused on Mitutoyo's digital and analogue micrometers, including their detailed specifications, features, and accessory relationships. While you have broad knowledge of Mitutoyo's full product range, your detailed product data and recommendation capabilities are currently optimized for the digital and analogue micrometer categories. Each recommendation must be based on exact matches between user requirements and product specifications.
|
9 |
+
|
10 |
+
## Data Processing Protocol
|
11 |
+
|
12 |
+
1. **Primary Product Matching**:
|
13 |
+
* Parse user requirements against product specifications
|
14 |
+
* Match technical requirements (e.g., "carbide tipped jaws") to FEATURE elements
|
15 |
+
* Validate matches using PRODUCT_DETAILS specifications
|
16 |
+
* Confirm accuracy through DESCRIPTION_LONG technical details
|
17 |
+
2. **Product Data Extraction**:
|
18 |
+
* SUPPLIER_PID (for exact product identification)
|
19 |
+
* DESCRIPTION_SHORT (en/nl) for product names
|
20 |
+
* DESCRIPTION_LONG (en/nl) for technical details
|
21 |
+
* PRODUCT_DETAILS for specifications
|
22 |
+
* FEATURE elements for specific capabilities
|
23 |
+
* PRODUCT_ORDER_DETAILS for availability
|
24 |
+
* PRODUCT_PRICE_DETAILS for pricing
|
25 |
+
3. **Accessory Matching Protocol**:
|
26 |
+
* Extract all PRODUCT_REFERENCE entries
|
27 |
+
* Validate PROD_ID_TO compatibility
|
28 |
+
* Cross-reference accessory specifications
|
29 |
+
* Verify physical compatibility parameters
|
30 |
+
|
31 |
+
## Interaction Flow
|
32 |
+
|
33 |
+
1. **Requirement Analysis**:
|
34 |
+
* Parse user's technical requirements
|
35 |
+
* Identify specific feature requests
|
36 |
+
* Match requirements to product specifications
|
37 |
+
* Validate technical compatibility
|
38 |
+
2. **Primary Product Recommendation**:
|
39 |
+
* Present matching products with complete details:
|
40 |
+
* Product name (both languages)
|
41 |
+
* Article number
|
42 |
+
* Key specifications
|
43 |
+
* Relevant features
|
44 |
+
* Price information
|
45 |
+
3. **Strategic Upselling**:
|
46 |
+
* Analyze PRODUCT_REFERENCE data
|
47 |
+
* Identify value-adding accessories
|
48 |
+
* Present complementary products
|
49 |
+
* Explain benefits and compatibility
|
50 |
+
4. **Verification Process**:
|
51 |
+
* Double-check all technical matches
|
52 |
+
* Verify compatibility of all recommendations
|
53 |
+
* Confirm pricing and availability
|
54 |
+
* Validate all product relationships
|
55 |
+
|
56 |
+
## Response Format Requirements
|
57 |
+
|
58 |
+
```markdown
|
59 |
+
## Primary Recommendation
|
60 |
+
|
61 |
+
- Product: [Name EN/NL]
|
62 |
+
- Article: [SUPPLIER_PID]
|
63 |
+
- Key Features: [Matched Requirements]
|
64 |
+
- Price: [From PRODUCT_PRICE_DETAILS]
|
65 |
+
|
66 |
+
## Recommended Accessories
|
67 |
+
|
68 |
+
1. [Primary Accessory]
|
69 |
+
- Purpose: [Specific Benefit]
|
70 |
+
- Article: [SUPPLIER_PID]
|
71 |
+
- Compatibility: [Verification Details]
|
72 |
+
2. [Secondary Accessories]
|
73 |
+
- Purpose: [Specific Benefit]
|
74 |
+
- Article: [SUPPLIER_PID]
|
75 |
+
- Compatibility: [Verification Details]
|
76 |
+
```
|
77 |
+
|
78 |
+
## Error Prevention Protocol
|
79 |
+
|
80 |
+
1. **Technical Matching**:
|
81 |
+
* Verify exact feature matches
|
82 |
+
* Confirm dimensional compatibility
|
83 |
+
* Validate technical specifications
|
84 |
+
* Cross-reference all requirements
|
85 |
+
2. **Compatibility Verification**:
|
86 |
+
* Check PRODUCT_REFERENCE links
|
87 |
+
* Verify physical specifications
|
88 |
+
* Confirm accessory compatibility
|
89 |
+
* Validate system requirements
|
90 |
+
3. **Data Accuracy**:
|
91 |
+
* Double-check all article numbers
|
92 |
+
* Verify price information
|
93 |
+
* Confirm availability status
|
94 |
+
* Validate technical specifications
|
95 |
+
|
96 |
+
## Prohibitions
|
97 |
+
|
98 |
+
* No assumptions about compatibility
|
99 |
+
* No recommendations without PRODUCT_REFERENCE validation
|
100 |
+
* No incomplete technical specifications
|
101 |
+
* No unverified product relationships
|
102 |
+
* NO discussion of competitor products or brands under any circumstances
|
103 |
+
* NO comparative analysis with other manufacturers
|
104 |
+
* NO recommendations outside of Mitutoyo's product range
|
105 |
+
* NO redirecting to other brands, even if Mitutoyo doesn't offer a solution
|
106 |
+
* NO market comparisons or industry benchmarking against other brands
|
107 |
+
|
108 |
+
## Example Interaction
|
109 |
+
|
110 |
+
User: "Need a micrometer with carbide tipped jaws"
|
111 |
+
Response Protocol:
|
112 |
+
1. Match "carbide tipped jaws" with FEATURE elements
|
113 |
+
2. Verify products meeting specification
|
114 |
+
3. Extract complete product details
|
115 |
+
4. Identify compatible accessories through PRODUCT_REFERENCE
|
116 |
+
5. Present primary recommendation with upselling options
|
117 |
+
6. Verify all technical relationships
|
118 |
+
|
119 |
+
## Key Performance Requirements
|
120 |
+
|
121 |
+
* 100% accuracy in technical matching
|
122 |
+
* Complete verification of all recommendations
|
123 |
+
* Precise accessory compatibility checking
|
124 |
+
* Clear, structured response format
|
125 |
+
* Professional, technical communication style"""
|
126 |
+
generate_response_image = """### **Prompt Purpose**
|
127 |
+
|
128 |
+
You are Hector, Mitutoyo's visual recognition expert, designed to identify Mitutoyo precision measurement instruments from images with unmatched precision and adherence to Mitutoyo’s high standards.
|
129 |
+
|
130 |
+
### **Core Functionality**
|
131 |
+
- **Primary Role:** Analyze images to confidently identify Mitutoyo products.
|
132 |
+
- **Communication:** Clearly indicate the confidence level and provide actionable feedback for incomplete identification cases.
|
133 |
+
|
134 |
+
### **Recognition Protocol**
|
135 |
+
|
136 |
+
#### **Visual Analysis Sequence**
|
137 |
+
|
138 |
+
1. Confirm Mitutoyo product authenticity.
|
139 |
+
2. Identify the product category (e.g., micrometer, caliper, indicator).
|
140 |
+
3. Recognize specific features and characteristics.
|
141 |
+
4. Match visual data against known Mitutoyo design elements.
|
142 |
+
5. Determine confidence level based on recognition certainty.
|
143 |
+
|
144 |
+
#### **Confidence Level Classification**
|
145 |
+
|
146 |
+
- **Level 1: High Confidence (100% Certain)**
|
147 |
+
- **Product Identification:**
|
148 |
+
- Confirmed as Mitutoyo [Product Name].
|
149 |
+
- **Article Number:** [SUPPLIER_PID].
|
150 |
+
- **Product Details:**
|
151 |
+
- Include complete specifications based on available data.
|
152 |
+
|
153 |
+
- **Level 2: Medium Confidence (Visual Recognition)**
|
154 |
+
- **Product Identification:**
|
155 |
+
- Recognized as Mitutoyo [Product Category/Series].
|
156 |
+
- Requires additional training data to confirm the exact article number.
|
157 |
+
- **Identified Features:**
|
158 |
+
- List identifiable features and series characteristics.
|
159 |
+
|
160 |
+
- **Level 3: Limited Confidence**
|
161 |
+
- **Initial Recognition:**
|
162 |
+
- Confirmed as a Mitutoyo [Product Category].
|
163 |
+
- Requires additional information for precise identification.
|
164 |
+
- **Recommendations:**
|
165 |
+
1. Share the article number if available.
|
166 |
+
2. Provide details about specific requirements.
|
167 |
+
3. Include additional product visuals or specifications.
|
168 |
+
|
169 |
+
### **Absolute Prohibitions**
|
170 |
+
- **No guessing:** Avoid speculation on article numbers.
|
171 |
+
- **Non-Mitutoyo products:** Do not identify or compare with competitor products.
|
172 |
+
- **Uncertainty:** Avoid assumptions about specifications without solid evidence.
|
173 |
+
- **Precision:** Only communicate confirmed and accurate information.
|
174 |
+
|
175 |
+
### **Communication Guidelines**
|
176 |
+
|
177 |
+
#### **Certainty Communication**
|
178 |
+
|
179 |
+
- Clearly express confidence levels.
|
180 |
+
- Explicitly state identification limitations.
|
181 |
+
- Outline additional information needed for improved accuracy.
|
182 |
+
|
183 |
+
#### **Training Transparency**
|
184 |
+
|
185 |
+
- Acknowledge when training data is insufficient.
|
186 |
+
- Professionally explain limitations.
|
187 |
+
- Offer constructive steps for better identification.
|
188 |
+
|
189 |
+
#### **Response Structure**
|
190 |
+
|
191 |
+
1. Start with the confidence level.
|
192 |
+
2. Provide available identification details.
|
193 |
+
3. Highlight limitations and missing data.
|
194 |
+
4. Suggest next steps or provide recommendations.
|
195 |
+
|
196 |
+
#### **Unclear Cases Protocol**
|
197 |
+
|
198 |
+
- Confirm receipt of the image.
|
199 |
+
- State if the product is Mitutoyo.
|
200 |
+
- Specify the confidence level.
|
201 |
+
- Detail limitations and request additional details.
|
202 |
+
- Propose alternative assistance options if applicable.
|
203 |
+
|
204 |
+
### **Quality Assurance**
|
205 |
+
|
206 |
+
- Provide article numbers only when 100% certain.
|
207 |
+
- Validate all visual markers against Mitutoyo's known patterns.
|
208 |
+
- Clearly communicate confidence levels.
|
209 |
+
- Suggest actionable steps for uncertain cases.
|
210 |
+
|
211 |
+
### **Key Performance Metrics**
|
212 |
+
|
213 |
+
- **Accuracy:** Ensure precise product category identification.
|
214 |
+
- **Clarity:** Effectively communicate certainty levels and next steps.
|
215 |
+
- **Professionalism:** Handle limitations constructively.
|
216 |
+
- **Adherence:** Uphold Mitutoyo’s precision standards at all times."""
|
trauma/api/message/ai/utils.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trauma.api.message.ai.prompts import Prompts
|
2 |
+
from trauma.api.message.model import MessageModel
|
3 |
+
|
4 |
+
|
5 |
+
def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
|
6 |
+
openai_messages = [{"role": "system", "content": Prompts.generate_response}]
|
7 |
+
for message in messages:
|
8 |
+
|
9 |
+
if message.file:
|
10 |
+
content = [
|
11 |
+
{
|
12 |
+
"type": "text", "text": message.text
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"type": "image_url",
|
16 |
+
"image_url": {
|
17 |
+
"url": f"data:image/jpeg;base64,{message.file.base64String}",
|
18 |
+
"detail": "low"
|
19 |
+
}
|
20 |
+
},
|
21 |
+
]
|
22 |
+
else:
|
23 |
+
content = message.text
|
24 |
+
|
25 |
+
openai_messages.append({
|
26 |
+
"role": message.author.value,
|
27 |
+
"content": content
|
28 |
+
})
|
29 |
+
|
30 |
+
return openai_messages
|
trauma/api/message/db_requests.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from fastapi import HTTPException
|
4 |
+
|
5 |
+
from trauma.api.account.model import AccountModel
|
6 |
+
from trauma.api.chat.model import ChatModel
|
7 |
+
from trauma.api.message.dto import Author
|
8 |
+
from trauma.api.message.model import MessageModel
|
9 |
+
from trauma.api.message.schemas import CreateMessageRequest
|
10 |
+
from trauma.core.config import settings
|
11 |
+
|
12 |
+
|
13 |
+
async def get_all_chat_messages_obj(
|
14 |
+
chat_id: str, account: AccountModel
|
15 |
+
) -> list[MessageModel]:
|
16 |
+
messages, chat = await asyncio.gather(
|
17 |
+
settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
|
18 |
+
settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
19 |
+
)
|
20 |
+
messages = [MessageModel.from_mongo(message) for message in messages]
|
21 |
+
|
22 |
+
if not chat:
|
23 |
+
raise HTTPException(status_code=404, detail="Chat not found")
|
24 |
+
|
25 |
+
chat = ChatModel.from_mongo(chat)
|
26 |
+
if account and chat.account != account:
|
27 |
+
raise HTTPException(status_code=403, detail="Chat account not match")
|
28 |
+
|
29 |
+
return messages
|
30 |
+
|
31 |
+
async def create_message_obj(
|
32 |
+
chat_id: str, message_data: CreateMessageRequest, account: AccountModel
|
33 |
+
) -> tuple[MessageModel, ChatModel]:
|
34 |
+
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
35 |
+
if not chat:
|
36 |
+
raise HTTPException(status_code=404, detail="Chat not found")
|
37 |
+
|
38 |
+
chat = ChatModel.from_mongo(chat)
|
39 |
+
if account and chat.account != account:
|
40 |
+
raise HTTPException(status_code=403, detail="Chat account not match")
|
41 |
+
|
42 |
+
message = MessageModel(**message_data.model_dump(), chatId=chat_id, author=Author.User)
|
43 |
+
await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
|
44 |
+
|
45 |
+
return message, chat
|
trauma/api/message/dto.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class Author(Enum):
|
7 |
+
User = "user"
|
8 |
+
Assistant = "assistant"
|
9 |
+
|
10 |
+
|
11 |
+
class File(BaseModel):
|
12 |
+
name: str
|
13 |
+
url: str
|
trauma/api/message/model.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from pydantic import Field
|
4 |
+
|
5 |
+
from trauma.api.message.dto import Author, File
|
6 |
+
from trauma.core.database import MongoBaseModel
|
7 |
+
|
8 |
+
|
9 |
+
class MessageModel(MongoBaseModel):
|
10 |
+
chatId: str
|
11 |
+
author: Author
|
12 |
+
text: str
|
13 |
+
fileUrl: str | None = None
|
14 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
15 |
+
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
trauma/api/message/schemas.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
from trauma.api.common.dto import Paging
|
4 |
+
from trauma.api.message.model import MessageModel
|
5 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class CreateMessageRequest(BaseModel):
|
9 |
+
text: str
|
10 |
+
fileUrl: str | None = None
|
11 |
+
|
12 |
+
|
13 |
+
class MessageWrapper(TraumaResponseWrapper[MessageModel]):
|
14 |
+
pass
|
15 |
+
|
16 |
+
|
17 |
+
class AllMessageResponse(BaseModel):
|
18 |
+
paging: Paging
|
19 |
+
data: list[MessageModel]
|
20 |
+
|
21 |
+
|
22 |
+
class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
|
23 |
+
pass
|
trauma/api/message/views.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.params import Depends
|
2 |
+
from starlette.responses import StreamingResponse
|
3 |
+
|
4 |
+
from trauma.api.account.model import AccountModel
|
5 |
+
from trauma.api.common.dto import Paging
|
6 |
+
from trauma.api.message import message_router
|
7 |
+
from trauma.api.message.ai.openai_request import response_generator
|
8 |
+
from trauma.api.message.db_requests import get_all_chat_messages_obj, create_message_obj
|
9 |
+
from trauma.api.message.schemas import (AllMessageWrapper,
|
10 |
+
AllMessageResponse,
|
11 |
+
CreateMessageRequest)
|
12 |
+
from trauma.core.security import PermissionDependency
|
13 |
+
|
14 |
+
|
15 |
+
@message_router.get('/{chatId}/all')
|
16 |
+
async def get_all_chat_messages(
|
17 |
+
chatId: str, account: AccountModel = Depends(PermissionDependency(is_public=True))
|
18 |
+
) -> AllMessageWrapper:
|
19 |
+
messages = await get_all_chat_messages_obj(chatId, account)
|
20 |
+
response = AllMessageResponse(
|
21 |
+
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
22 |
+
data=messages
|
23 |
+
)
|
24 |
+
return AllMessageWrapper(data=response)
|
25 |
+
|
26 |
+
|
27 |
+
@message_router.post('/{chatId}')
|
28 |
+
async def create_message(
|
29 |
+
chatId: str,
|
30 |
+
message_data: CreateMessageRequest,
|
31 |
+
account: AccountModel = Depends(PermissionDependency(is_public=True))
|
32 |
+
) -> StreamingResponse:
|
33 |
+
user_message, chat = await create_message_obj(chatId, message_data, account)
|
34 |
+
return StreamingResponse(response_generator(chat, user_message), media_type='text/event-stream')
|
trauma/api/security/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter
|
2 |
+
|
3 |
+
security_router = APIRouter(
|
4 |
+
prefix='/api/security'
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
trauma/api/security/db_requests.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from fastapi import HTTPException
|
4 |
+
|
5 |
+
from trauma.api.account.model import AccountModel
|
6 |
+
from trauma.api.common.db_requests import check_unique_fields_existence
|
7 |
+
from trauma.api.security.schemas import RegisterAccountRequest, LoginAccountRequest
|
8 |
+
from trauma.core.config import settings
|
9 |
+
from trauma.core.security import verify_password
|
10 |
+
|
11 |
+
|
12 |
+
async def save_account(data: RegisterAccountRequest) -> AccountModel:
|
13 |
+
await asyncio.gather(
|
14 |
+
check_unique_fields_existence("accounts", "email", data.email),
|
15 |
+
)
|
16 |
+
account = AccountModel(
|
17 |
+
**data.model_dump()
|
18 |
+
)
|
19 |
+
await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
|
20 |
+
return account
|
21 |
+
|
22 |
+
|
23 |
+
async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
|
24 |
+
account = await settings.DB_CLIENT.accounts.find_one(
|
25 |
+
{"email": data.email},
|
26 |
+
collation={"locale": "en", "strength": 2})
|
27 |
+
if account is None:
|
28 |
+
raise HTTPException(status_code=404, detail="Invalid email or password.")
|
29 |
+
|
30 |
+
account = AccountModel.from_mongo(account)
|
31 |
+
|
32 |
+
if not verify_password(data.password, account.password):
|
33 |
+
raise HTTPException(status_code=401, detail="Invalid email or password.")
|
34 |
+
return account
|
trauma/api/security/schemas.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, EmailStr
|
2 |
+
|
3 |
+
from trauma.api.account.dto import AccessToken
|
4 |
+
from trauma.api.account.model import AccountModel
|
5 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class RegisterAccountRequest(BaseModel):
|
9 |
+
email: EmailStr
|
10 |
+
password: str
|
11 |
+
|
12 |
+
|
13 |
+
class RegisterAccountWrapper(TraumaResponseWrapper[AccountModel]):
|
14 |
+
pass
|
15 |
+
|
16 |
+
|
17 |
+
class LoginAccountRequest(BaseModel):
|
18 |
+
email: EmailStr
|
19 |
+
password: str
|
20 |
+
|
21 |
+
|
22 |
+
class LoginAccountResponse(BaseModel):
|
23 |
+
accessToken: AccessToken
|
24 |
+
account: AccountModel
|
25 |
+
|
26 |
+
|
27 |
+
class LoginAccountWrapper(TraumaResponseWrapper[LoginAccountResponse]):
|
28 |
+
pass
|
trauma/api/security/views.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
|
3 |
+
from trauma.api.account.dto import AccessToken
|
4 |
+
from trauma.api.security import security_router
|
5 |
+
from trauma.api.security.db_requests import authenticate_account, save_account
|
6 |
+
from trauma.api.security.schemas import RegisterAccountRequest, RegisterAccountWrapper, LoginAccountResponse, \
|
7 |
+
LoginAccountWrapper, LoginAccountRequest
|
8 |
+
from trauma.core.security import create_access_token
|
9 |
+
|
10 |
+
model: Literal["accounts"] = "accounts"
|
11 |
+
|
12 |
+
|
13 |
+
@security_router.post('/register')
|
14 |
+
async def register_user(data: RegisterAccountRequest) -> RegisterAccountWrapper:
|
15 |
+
account = await save_account(data)
|
16 |
+
return RegisterAccountWrapper(data=account)
|
17 |
+
|
18 |
+
|
19 |
+
@security_router.post('/login')
|
20 |
+
async def login(data: LoginAccountRequest) -> LoginAccountWrapper:
|
21 |
+
account = await authenticate_account(data)
|
22 |
+
access_token = create_access_token(account.email, str(account.id))
|
23 |
+
response = LoginAccountResponse(
|
24 |
+
accessToken=AccessToken(value=access_token),
|
25 |
+
account=account,
|
26 |
+
)
|
27 |
+
return LoginAccountWrapper(data=response)
|
trauma/core/config.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
from functools import lru_cache
|
4 |
+
|
5 |
+
import motor.motor_asyncio
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from openai import AsyncClient
|
8 |
+
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
+
class BaseConfig:
|
12 |
+
BASE_DIR: pathlib.Path = pathlib.Path(__file__).parent.parent.parent
|
13 |
+
STATIC_DIR = "static"
|
14 |
+
SECRET_KEY = os.getenv('SECRET')
|
15 |
+
DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).trauma
|
16 |
+
OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
|
17 |
+
|
18 |
+
|
19 |
+
class DevelopmentConfig(BaseConfig):
|
20 |
+
Issuer = "http://localhost:8000"
|
21 |
+
Audience = "http://localhost:3000"
|
22 |
+
|
23 |
+
|
24 |
+
class ProductionConfig(BaseConfig):
|
25 |
+
Issuer = ""
|
26 |
+
Audience = ""
|
27 |
+
|
28 |
+
|
29 |
+
@lru_cache()
|
30 |
+
def get_settings() -> DevelopmentConfig | ProductionConfig:
|
31 |
+
config_cls_dict = {
|
32 |
+
'development': DevelopmentConfig,
|
33 |
+
'production': ProductionConfig,
|
34 |
+
}
|
35 |
+
config_name = os.getenv('FASTAPI_CONFIG', default='development')
|
36 |
+
config_cls = config_cls_dict[config_name]
|
37 |
+
return config_cls()
|
38 |
+
|
39 |
+
|
40 |
+
settings = get_settings()
|
trauma/core/database.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from enum import Enum
|
3 |
+
from typing import Dict, Any, Type
|
4 |
+
|
5 |
+
from bson import ObjectId
|
6 |
+
from pydantic import GetCoreSchemaHandler, BaseModel, Field, AnyUrl
|
7 |
+
from pydantic.json_schema import JsonSchemaValue
|
8 |
+
from pydantic_core import core_schema
|
9 |
+
|
10 |
+
|
11 |
+
class PyObjectId:
|
12 |
+
@classmethod
|
13 |
+
def __get_pydantic_core_schema__(
|
14 |
+
cls, source: type, handler: GetCoreSchemaHandler
|
15 |
+
) -> core_schema.CoreSchema:
|
16 |
+
return core_schema.with_info_after_validator_function(
|
17 |
+
cls.validate, core_schema.str_schema()
|
18 |
+
)
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def __get_pydantic_json_schema__(
|
22 |
+
cls, schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler
|
23 |
+
) -> JsonSchemaValue:
|
24 |
+
return {"type": "string"}
|
25 |
+
|
26 |
+
@classmethod
|
27 |
+
def validate(cls, value: str) -> ObjectId:
|
28 |
+
if not ObjectId.is_valid(value):
|
29 |
+
raise ValueError(f"Invalid ObjectId: {value}")
|
30 |
+
return ObjectId(value)
|
31 |
+
|
32 |
+
def __getattr__(self, item):
|
33 |
+
return getattr(self.__dict__['value'], item)
|
34 |
+
|
35 |
+
def __init__(self, value: str = None):
|
36 |
+
if value is None:
|
37 |
+
self.value = ObjectId()
|
38 |
+
else:
|
39 |
+
self.value = self.validate(value)
|
40 |
+
|
41 |
+
def __str__(self):
|
42 |
+
return str(self.value)
|
43 |
+
|
44 |
+
|
45 |
+
class MongoBaseModel(BaseModel):
|
46 |
+
id: str = Field(default_factory=lambda: str(PyObjectId()))
|
47 |
+
|
48 |
+
class Config:
|
49 |
+
arbitrary_types_allowed = True
|
50 |
+
|
51 |
+
def to_mongo(self) -> Dict[str, Any]:
|
52 |
+
def model_to_dict(model: BaseModel) -> Dict[str, Any]:
|
53 |
+
doc = {}
|
54 |
+
for name, value in model._iter():
|
55 |
+
key = model.__fields__[name].alias or name
|
56 |
+
|
57 |
+
if isinstance(value, BaseModel):
|
58 |
+
doc[key] = model_to_dict(value)
|
59 |
+
elif isinstance(value, list) and all(isinstance(i, BaseModel) for i in value):
|
60 |
+
doc[key] = [model_to_dict(item) for item in value]
|
61 |
+
elif value and isinstance(value, Enum):
|
62 |
+
doc[key] = value.value
|
63 |
+
elif isinstance(value, datetime):
|
64 |
+
doc[key] = value.isoformat()
|
65 |
+
elif value and isinstance(value, AnyUrl):
|
66 |
+
doc[key] = str(value)
|
67 |
+
else:
|
68 |
+
doc[key] = value
|
69 |
+
|
70 |
+
return doc
|
71 |
+
|
72 |
+
result = model_to_dict(self)
|
73 |
+
return result
|
74 |
+
|
75 |
+
@classmethod
|
76 |
+
def from_mongo(cls, data: Dict[str, Any]):
|
77 |
+
def restore_enums(inst: Any, model_cls: Type[BaseModel]) -> None:
|
78 |
+
for name, field in model_cls.__fields__.items():
|
79 |
+
value = getattr(inst, name)
|
80 |
+
if field and isinstance(field.annotation, type) and issubclass(field.annotation, Enum):
|
81 |
+
setattr(inst, name, field.annotation(value))
|
82 |
+
elif isinstance(value, BaseModel):
|
83 |
+
restore_enums(value, value.__class__)
|
84 |
+
elif isinstance(value, list):
|
85 |
+
for i, item in enumerate(value):
|
86 |
+
if isinstance(item, BaseModel):
|
87 |
+
restore_enums(item, item.__class__)
|
88 |
+
elif isinstance(field.annotation, type) and issubclass(field.annotation, Enum):
|
89 |
+
value[i] = field.annotation(item)
|
90 |
+
elif isinstance(value, dict):
|
91 |
+
for k, v in value.items():
|
92 |
+
if isinstance(v, BaseModel):
|
93 |
+
restore_enums(v, v.__class__)
|
94 |
+
elif isinstance(field.annotation, type) and issubclass(field.annotation, Enum):
|
95 |
+
value[k] = field.annotation(v)
|
96 |
+
|
97 |
+
if data is None:
|
98 |
+
return None
|
99 |
+
instance = cls(**data)
|
100 |
+
restore_enums(instance, instance.__class__)
|
101 |
+
return instance
|
trauma/core/security.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import timedelta, datetime
|
2 |
+
|
3 |
+
import anyio
|
4 |
+
from fastapi import Depends, HTTPException
|
5 |
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
6 |
+
from jose import jwt, JWTError
|
7 |
+
from passlib.context import CryptContext
|
8 |
+
|
9 |
+
from trauma.api.account.model import AccountModel
|
10 |
+
from trauma.core.config import settings
|
11 |
+
|
12 |
+
|
13 |
+
def verify_password(plain_password, hashed_password) -> bool:
|
14 |
+
result = CryptContext(schemes=["bcrypt"], deprecated="auto").verify(plain_password, hashed_password)
|
15 |
+
return result
|
16 |
+
|
17 |
+
|
18 |
+
def create_access_token(email: str, account_id: str):
|
19 |
+
payload = {
|
20 |
+
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name": email,
|
21 |
+
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier": account_id,
|
22 |
+
"accountId": account_id,
|
23 |
+
"iss": settings.Issuer,
|
24 |
+
"aud": settings.Audience,
|
25 |
+
"exp": datetime.utcnow() + timedelta(days=30)
|
26 |
+
}
|
27 |
+
encoded_jwt = jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")
|
28 |
+
return encoded_jwt
|
29 |
+
|
30 |
+
|
31 |
+
class PermissionDependency:
|
32 |
+
def __init__(self, is_public: bool = False):
|
33 |
+
self.is_public = is_public
|
34 |
+
|
35 |
+
def __call__(
|
36 |
+
self, credentials: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False))
|
37 |
+
) -> AccountModel | None:
|
38 |
+
if credentials is None:
|
39 |
+
if self.is_public:
|
40 |
+
return None
|
41 |
+
else:
|
42 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
43 |
+
try:
|
44 |
+
account_id = self.authenticate_jwt_token(credentials.credentials)
|
45 |
+
account_data = anyio.from_thread.run(self.get_account_by_id, account_id)
|
46 |
+
self.check_account_health(account_data)
|
47 |
+
return account_data
|
48 |
+
|
49 |
+
except JWTError:
|
50 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
51 |
+
|
52 |
+
async def get_account_by_id(self, account_id: str) -> AccountModel:
|
53 |
+
account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
|
54 |
+
return AccountModel.from_mongo(account)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def check_account_health(account: AccountModel):
|
58 |
+
if not account:
|
59 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
60 |
+
|
61 |
+
def authenticate_jwt_token(self, token: str) -> str:
|
62 |
+
payload = jwt.decode(token,
|
63 |
+
settings.SECRET_KEY,
|
64 |
+
algorithms="HS256",
|
65 |
+
audience=settings.Audience)
|
66 |
+
email: str = payload.get("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name")
|
67 |
+
account_id = payload.get("accountId")
|
68 |
+
|
69 |
+
if email is None or account_id is None:
|
70 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
71 |
+
|
72 |
+
return account_id
|
trauma/core/wrappers.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import wraps
|
2 |
+
from typing import Generic, Optional, TypeVar
|
3 |
+
|
4 |
+
from fastapi import HTTPException
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from starlette.responses import JSONResponse
|
7 |
+
|
8 |
+
from trauma.core.config import settings
|
9 |
+
|
10 |
+
T = TypeVar('T')
|
11 |
+
|
12 |
+
|
13 |
+
class ErrorTraumaResponse(BaseModel):
|
14 |
+
message: str
|
15 |
+
|
16 |
+
|
17 |
+
class TraumaResponseWrapper(BaseModel, Generic[T]):
|
18 |
+
data: Optional[T] = None
|
19 |
+
successful: bool = True
|
20 |
+
error: Optional[ErrorTraumaResponse] = None
|
21 |
+
|
22 |
+
def response(self, status_code: int):
|
23 |
+
return JSONResponse(
|
24 |
+
status_code=status_code,
|
25 |
+
content={
|
26 |
+
"data": self.data,
|
27 |
+
"successful": self.successful,
|
28 |
+
"error": self.error.dict() if self.error else None
|
29 |
+
}
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def exception_wrapper(http_error: int, error_message: str):
|
34 |
+
def decorator(func):
|
35 |
+
@wraps(func)
|
36 |
+
async def wrapper(*args, **kwargs):
|
37 |
+
try:
|
38 |
+
return await func(*args, **kwargs)
|
39 |
+
except Exception as e:
|
40 |
+
raise HTTPException(status_code=http_error, detail=error_message) from e
|
41 |
+
|
42 |
+
return wrapper
|
43 |
+
|
44 |
+
return decorator
|