Spaces:
Running
Running
add auth
Browse files- trauma/__init__.py +5 -0
- trauma/api/account/__init__.py +7 -0
- trauma/api/account/db_requests.py +17 -0
- trauma/api/account/dto.py +13 -0
- trauma/api/account/model.py +34 -0
- trauma/api/account/schemas.py +14 -0
- trauma/api/account/views.py +33 -0
- trauma/api/chat/db_requests.py +15 -7
- trauma/api/chat/model.py +2 -0
- trauma/api/chat/views.py +19 -9
- trauma/api/common/db_requests.py +17 -0
- trauma/api/message/db_requests.py +5 -1
- trauma/api/message/views.py +10 -3
- trauma/api/security/__init__.py +7 -0
- trauma/api/security/db_requests.py +32 -0
- trauma/api/security/schemas.py +28 -0
- trauma/api/security/views.py +26 -0
- trauma/core/security.py +12 -13
trauma/__init__.py
CHANGED
@@ -12,12 +12,17 @@ from trauma.core.wrappers import TraumaResponseWrapper, ErrorTraumaResponse
|
|
12 |
def create_app() -> FastAPI:
|
13 |
app = FastAPI()
|
14 |
|
|
|
|
|
|
|
15 |
from trauma.api.chat import chat_router
|
16 |
app.include_router(chat_router, tags=['chat'])
|
17 |
|
18 |
from trauma.api.message import message_router
|
19 |
app.include_router(message_router, tags=['message'])
|
20 |
|
|
|
|
|
21 |
|
22 |
app.add_middleware(
|
23 |
CORSMiddleware,
|
|
|
12 |
def create_app() -> FastAPI:
|
13 |
app = FastAPI()
|
14 |
|
15 |
+
from trauma.api.account import account_router
|
16 |
+
app.include_router(account_router, tags=["account"])
|
17 |
+
|
18 |
from trauma.api.chat import chat_router
|
19 |
app.include_router(chat_router, tags=['chat'])
|
20 |
|
21 |
from trauma.api.message import message_router
|
22 |
app.include_router(message_router, tags=['message'])
|
23 |
|
24 |
+
from trauma.api.security import security_router
|
25 |
+
app.include_router(security_router, tags=['security'])
|
26 |
|
27 |
app.add_middleware(
|
28 |
CORSMiddleware,
|
trauma/api/account/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter
|
2 |
+
|
3 |
+
account_router = APIRouter(
|
4 |
+
prefix='/api/account'
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
trauma/api/account/db_requests.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from trauma.api.account.model import AccountModel
|
4 |
+
from trauma.core.config import settings
|
5 |
+
|
6 |
+
|
7 |
+
async def get_all_model_obj(page_size: int, page_index: int) -> tuple[list[AccountModel], int]:
|
8 |
+
skip = page_size * page_index
|
9 |
+
objects, total_count = await asyncio.gather(
|
10 |
+
settings.DB_CLIENT.accounts
|
11 |
+
.find()
|
12 |
+
.skip(skip)
|
13 |
+
.limit(page_size)
|
14 |
+
.to_list(length=page_size),
|
15 |
+
settings.DB_CLIENT.accounts.count_documents({})
|
16 |
+
)
|
17 |
+
return objects, total_count
|
trauma/api/account/dto.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class AccessToken(BaseModel):
|
7 |
+
type: str = "Bearer"
|
8 |
+
value: str
|
9 |
+
|
10 |
+
|
11 |
+
class AccountType(Enum):
|
12 |
+
User = 1
|
13 |
+
Admin = 2
|
trauma/api/account/model.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from passlib.context import CryptContext
|
4 |
+
from pydantic import field_validator, Field, EmailStr
|
5 |
+
|
6 |
+
from trauma.api.account.dto import AccountType
|
7 |
+
from trauma.core.database import MongoBaseModel
|
8 |
+
|
9 |
+
|
10 |
+
class AccountModel(MongoBaseModel):
|
11 |
+
email: EmailStr
|
12 |
+
password: str | None = Field(exclude=True, default=None)
|
13 |
+
accountType: AccountType = AccountType.User
|
14 |
+
|
15 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
16 |
+
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
17 |
+
|
18 |
+
@field_validator("datetimeUpdated", mode="before", check_fields=False)
|
19 |
+
@classmethod
|
20 |
+
def validate_datetimeUpdated(cls, v):
|
21 |
+
return v or datetime.now()
|
22 |
+
|
23 |
+
@field_validator('password', mode='before', check_fields=False)
|
24 |
+
@classmethod
|
25 |
+
def set_password_hash(cls, v):
|
26 |
+
if not v.startswith("$2b$"):
|
27 |
+
return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v)
|
28 |
+
return v
|
29 |
+
|
30 |
+
class Config:
|
31 |
+
arbitrary_types_allowed = True
|
32 |
+
json_encoders = {
|
33 |
+
datetime: lambda dt: dt.isoformat()
|
34 |
+
}
|
trauma/api/account/schemas.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
from trauma.api.account.model import AccountModel
|
4 |
+
from trauma.api.common.dto import Paging
|
5 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class AccountWrapper(TraumaResponseWrapper[AccountModel]):
|
9 |
+
pass
|
10 |
+
|
11 |
+
|
12 |
+
class AllAccountsResponse(BaseModel):
|
13 |
+
paging: Paging
|
14 |
+
data: list[AccountModel]
|
trauma/api/account/views.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from fastapi import Depends, Query
|
4 |
+
|
5 |
+
from trauma.api.account import account_router
|
6 |
+
from trauma.api.account.db_requests import get_all_model_obj
|
7 |
+
from trauma.api.account.dto import AccountType
|
8 |
+
from trauma.api.account.model import AccountModel
|
9 |
+
from trauma.api.account.schemas import AccountWrapper, AllAccountsResponse
|
10 |
+
from trauma.api.common.dto import Paging
|
11 |
+
from trauma.core.security import PermissionDependency
|
12 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
13 |
+
|
14 |
+
|
15 |
+
@account_router.get('/all')
|
16 |
+
async def get_all_accounts(
|
17 |
+
pageSize: Optional[int] = Query(10, description="Number of countries to return per page"),
|
18 |
+
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
19 |
+
_: AccountModel = Depends(PermissionDependency([AccountType.Admin]))
|
20 |
+
) -> TraumaResponseWrapper[AllAccountsResponse]:
|
21 |
+
countries, total_count = await get_all_model_obj(pageSize, pageIndex)
|
22 |
+
response = AllAccountsResponse(
|
23 |
+
paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
|
24 |
+
data=countries
|
25 |
+
)
|
26 |
+
return TraumaResponseWrapper(data=response)
|
27 |
+
|
28 |
+
|
29 |
+
@account_router.get('')
|
30 |
+
async def get_account(
|
31 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
|
32 |
+
) -> AccountWrapper:
|
33 |
+
return AccountWrapper(data=account)
|
trauma/api/chat/db_requests.py
CHANGED
@@ -2,6 +2,8 @@ import asyncio
|
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
|
|
|
|
5 |
from trauma.api.chat.model import ChatModel
|
6 |
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
|
7 |
from trauma.api.message.dto import Author
|
@@ -9,37 +11,43 @@ from trauma.api.message.model import MessageModel
|
|
9 |
from trauma.core.config import settings
|
10 |
|
11 |
|
12 |
-
async def get_chat_obj(chat_id: str) -> ChatModel:
|
13 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
14 |
if not chat:
|
15 |
raise HTTPException(status_code=404, detail="Chat not found")
|
16 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
17 |
return chat
|
18 |
|
19 |
|
20 |
-
async def create_chat_obj(chat_request: CreateChatRequest) -> ChatModel:
|
21 |
-
chat = ChatModel(model=chat_request.model)
|
22 |
await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
|
23 |
return chat
|
24 |
|
25 |
|
26 |
-
async def delete_chat_obj(chat_id: str) -> None:
|
27 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
28 |
if not chat:
|
29 |
raise HTTPException(status_code=404, detail="Chat not found")
|
30 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
31 |
await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
|
32 |
|
33 |
|
34 |
-
async def update_chat_obj_title(
|
35 |
-
chat = await settings.DB_CLIENT.chats.find_one({"id":
|
36 |
if not chat:
|
37 |
raise HTTPException(status_code=404, detail="Chat not found")
|
38 |
|
39 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
40 |
|
41 |
chat.title = chat_request.title
|
42 |
-
await settings.DB_CLIENT.chats.update_one({"id":
|
43 |
return chat
|
44 |
|
45 |
|
|
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
+
from trauma.api.account.dto import AccountType
|
6 |
+
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat.model import ChatModel
|
8 |
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
|
9 |
from trauma.api.message.dto import Author
|
|
|
11 |
from trauma.core.config import settings
|
12 |
|
13 |
|
14 |
+
async def get_chat_obj(chat_id: str, account: AccountModel) -> ChatModel:
|
15 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
16 |
if not chat:
|
17 |
raise HTTPException(status_code=404, detail="Chat not found")
|
18 |
chat = ChatModel.from_mongo(chat)
|
19 |
+
if chat.account.id != account.id and account.accountType != AccountType.Admin:
|
20 |
+
raise HTTPException(status_code=403, detail="Not authorized")
|
21 |
return chat
|
22 |
|
23 |
|
24 |
+
async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel) -> ChatModel:
|
25 |
+
chat = ChatModel(model=chat_request.model, account=account)
|
26 |
await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
|
27 |
return chat
|
28 |
|
29 |
|
30 |
+
async def delete_chat_obj(chat_id: str, account: AccountModel) -> None:
|
31 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
32 |
if not chat:
|
33 |
raise HTTPException(status_code=404, detail="Chat not found")
|
34 |
chat = ChatModel.from_mongo(chat)
|
35 |
+
if chat.account.id != account.id and account.accountType != AccountType.Admin:
|
36 |
+
raise HTTPException(status_code=403, detail="Not authorized")
|
37 |
await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
|
38 |
|
39 |
|
40 |
+
async def update_chat_obj_title(chat_id: str, chat_request: ChatTitleRequest, account: AccountModel) -> ChatModel:
|
41 |
+
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
42 |
if not chat:
|
43 |
raise HTTPException(status_code=404, detail="Chat not found")
|
44 |
|
45 |
chat = ChatModel.from_mongo(chat)
|
46 |
+
if chat.account.id != account.id and account.accountType != AccountType.Admin:
|
47 |
+
raise HTTPException(status_code=403, detail="Not authorized")
|
48 |
|
49 |
chat.title = chat_request.title
|
50 |
+
await settings.DB_CLIENT.chats.update_one({"id": chat_id}, {"$set": chat.to_mongo()})
|
51 |
return chat
|
52 |
|
53 |
|
trauma/api/chat/model.py
CHANGED
@@ -2,6 +2,7 @@ from datetime import datetime
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
|
|
5 |
from trauma.api.chat.dto import ModelType, EntityData
|
6 |
from trauma.core.database import MongoBaseModel
|
7 |
|
@@ -10,5 +11,6 @@ class ChatModel(MongoBaseModel):
|
|
10 |
title: str = 'New Chat'
|
11 |
model: ModelType
|
12 |
entityData: EntityData = EntityData()
|
|
|
13 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
14 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
5 |
+
from trauma.api.account.model import AccountModel
|
6 |
from trauma.api.chat.dto import ModelType, EntityData
|
7 |
from trauma.core.database import MongoBaseModel
|
8 |
|
|
|
11 |
title: str = 'New Chat'
|
12 |
model: ModelType
|
13 |
entityData: EntityData = EntityData()
|
14 |
+
account: AccountModel
|
15 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
16 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
trauma/api/chat/views.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from typing import Optional
|
2 |
|
3 |
-
from fastapi import Query
|
4 |
|
|
|
|
|
5 |
from trauma.api.chat import chat_router
|
6 |
from trauma.api.chat.db_requests import (get_chat_obj,
|
7 |
create_chat_obj,
|
@@ -10,6 +12,7 @@ from trauma.api.chat.db_requests import (get_chat_obj,
|
|
10 |
get_all_chats_obj, save_intro_message)
|
11 |
from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
|
12 |
from trauma.api.common.dto import Paging
|
|
|
13 |
from trauma.core.wrappers import TraumaResponseWrapper
|
14 |
|
15 |
|
@@ -17,6 +20,7 @@ from trauma.core.wrappers import TraumaResponseWrapper
|
|
17 |
async def get_all_chats(
|
18 |
pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
|
19 |
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
|
|
20 |
) -> AllChatWrapper:
|
21 |
chats, total_count = await get_all_chats_obj(pageSize, pageIndex)
|
22 |
response = AllChatResponse(
|
@@ -28,30 +32,36 @@ async def get_all_chats(
|
|
28 |
|
29 |
@chat_router.get('/{chatId}')
|
30 |
async def get_chat(
|
31 |
-
chatId: str
|
|
|
32 |
) -> ChatWrapper:
|
33 |
-
chat = await get_chat_obj(chatId)
|
34 |
return ChatWrapper(data=chat)
|
35 |
|
36 |
|
37 |
@chat_router.delete('/{chatId}')
|
38 |
-
async def delete_chat(
|
39 |
-
|
|
|
|
|
|
|
40 |
return TraumaResponseWrapper()
|
41 |
|
42 |
|
43 |
@chat_router.patch('/{chatId}/title')
|
44 |
async def update_chat_title(
|
45 |
-
chatId: str, chat: ChatTitleRequest
|
|
|
46 |
) -> ChatWrapper:
|
47 |
-
chat = await update_chat_obj_title(chatId, chat)
|
48 |
return ChatWrapper(data=chat)
|
49 |
|
50 |
|
51 |
@chat_router.post('')
|
52 |
async def create_chat(
|
53 |
-
chat_data: CreateChatRequest
|
|
|
54 |
) -> ChatWrapper:
|
55 |
-
chat = await create_chat_obj(chat_data)
|
56 |
await save_intro_message(chat.id)
|
57 |
return ChatWrapper(data=chat)
|
|
|
1 |
from typing import Optional
|
2 |
|
3 |
+
from fastapi import Query, Depends
|
4 |
|
5 |
+
from trauma.api.account.dto import AccountType
|
6 |
+
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat import chat_router
|
8 |
from trauma.api.chat.db_requests import (get_chat_obj,
|
9 |
create_chat_obj,
|
|
|
12 |
get_all_chats_obj, save_intro_message)
|
13 |
from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
|
14 |
from trauma.api.common.dto import Paging
|
15 |
+
from trauma.core.security import PermissionDependency
|
16 |
from trauma.core.wrappers import TraumaResponseWrapper
|
17 |
|
18 |
|
|
|
20 |
async def get_all_chats(
|
21 |
pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
|
22 |
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
23 |
+
_: AccountModel = Depends(PermissionDependency([AccountType.Admin]))
|
24 |
) -> AllChatWrapper:
|
25 |
chats, total_count = await get_all_chats_obj(pageSize, pageIndex)
|
26 |
response = AllChatResponse(
|
|
|
32 |
|
33 |
@chat_router.get('/{chatId}')
|
34 |
async def get_chat(
|
35 |
+
chatId: str,
|
36 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
|
37 |
) -> ChatWrapper:
|
38 |
+
chat = await get_chat_obj(chatId, account)
|
39 |
return ChatWrapper(data=chat)
|
40 |
|
41 |
|
42 |
@chat_router.delete('/{chatId}')
|
43 |
+
async def delete_chat(
|
44 |
+
chatId: str,
|
45 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
|
46 |
+
) -> TraumaResponseWrapper:
|
47 |
+
await delete_chat_obj(chatId, account)
|
48 |
return TraumaResponseWrapper()
|
49 |
|
50 |
|
51 |
@chat_router.patch('/{chatId}/title')
|
52 |
async def update_chat_title(
|
53 |
+
chatId: str, chat: ChatTitleRequest,
|
54 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
|
55 |
) -> ChatWrapper:
|
56 |
+
chat = await update_chat_obj_title(chatId, chat, account)
|
57 |
return ChatWrapper(data=chat)
|
58 |
|
59 |
|
60 |
@chat_router.post('')
|
61 |
async def create_chat(
|
62 |
+
chat_data: CreateChatRequest,
|
63 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
|
64 |
) -> ChatWrapper:
|
65 |
+
chat = await create_chat_obj(chat_data, account)
|
66 |
await save_intro_message(chat.id)
|
67 |
return ChatWrapper(data=chat)
|
trauma/api/common/db_requests.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import HTTPException
|
2 |
+
from pydantic import EmailStr
|
3 |
+
|
4 |
+
from trauma.core.config import settings
|
5 |
+
|
6 |
+
|
7 |
+
async def check_unique_fields_existence(name: str,
|
8 |
+
new_value: EmailStr | str,
|
9 |
+
current_value: str | None = None) -> None:
|
10 |
+
if new_value == current_value or not new_value:
|
11 |
+
return
|
12 |
+
account = await settings.DB_CLIENT.accounts.find_one(
|
13 |
+
{name: str(new_value)},
|
14 |
+
collation={"locale": "en", "strength": 2}
|
15 |
+
)
|
16 |
+
if account:
|
17 |
+
raise HTTPException(status_code=400, detail=f'Account with specified {name} already exists.')
|
trauma/api/message/db_requests.py
CHANGED
@@ -2,6 +2,8 @@ import asyncio
|
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
|
|
|
|
5 |
from trauma.api.chat.model import ChatModel
|
6 |
from trauma.api.data.model import EntityModel
|
7 |
from trauma.api.message.dto import Author
|
@@ -24,7 +26,7 @@ async def create_message_obj(
|
|
24 |
return message, chat
|
25 |
|
26 |
|
27 |
-
async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], ChatModel]:
|
28 |
messages, chat = await asyncio.gather(
|
29 |
settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
|
30 |
settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
@@ -35,6 +37,8 @@ async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], C
|
|
35 |
raise HTTPException(status_code=404, detail="Chat not found")
|
36 |
|
37 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
38 |
return messages, chat
|
39 |
|
40 |
|
|
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
+
from trauma.api.account.dto import AccountType
|
6 |
+
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat.model import ChatModel
|
8 |
from trauma.api.data.model import EntityModel
|
9 |
from trauma.api.message.dto import Author
|
|
|
26 |
return message, chat
|
27 |
|
28 |
|
29 |
+
async def get_all_chat_messages_obj(chat_id: str, account: AccountModel) -> tuple[list[MessageModel], ChatModel]:
|
30 |
messages, chat = await asyncio.gather(
|
31 |
settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
|
32 |
settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
|
|
37 |
raise HTTPException(status_code=404, detail="Chat not found")
|
38 |
|
39 |
chat = ChatModel.from_mongo(chat)
|
40 |
+
if chat.account.id != account.id and account.accountType != AccountType.Admin:
|
41 |
+
raise HTTPException(status_code=404, detail="Not Authorized.")
|
42 |
return messages, chat
|
43 |
|
44 |
|
trauma/api/message/views.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
from trauma.api.common.dto import Paging
|
2 |
from trauma.api.message import message_router
|
3 |
from trauma.api.message.ai.engine import search_entities
|
@@ -7,14 +11,16 @@ from trauma.api.message.schemas import (AllMessageWrapper,
|
|
7 |
CreateMessageRequest,
|
8 |
CreateMessageResponse)
|
9 |
from trauma.api.message.utils import transform_messages_to_openai
|
|
|
10 |
from trauma.core.wrappers import TraumaResponseWrapper
|
11 |
|
12 |
|
13 |
@message_router.get('/{chatId}/all')
|
14 |
async def get_all_chat_messages(
|
15 |
-
chatId: str
|
|
|
16 |
) -> AllMessageWrapper:
|
17 |
-
messages, _ = await get_all_chat_messages_obj(chatId)
|
18 |
response = AllMessageResponse(
|
19 |
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
20 |
data=messages
|
@@ -26,8 +32,9 @@ async def get_all_chat_messages(
|
|
26 |
async def create_message(
|
27 |
chatId: str,
|
28 |
message_data: CreateMessageRequest,
|
|
|
29 |
) -> TraumaResponseWrapper[CreateMessageResponse]:
|
30 |
-
messages, chat = await get_all_chat_messages_obj(chatId)
|
31 |
message_history = transform_messages_to_openai(messages)
|
32 |
response = await search_entities(message_data.text, message_history, chat)
|
33 |
return TraumaResponseWrapper(data=response)
|
|
|
1 |
+
from fastapi import Depends
|
2 |
+
|
3 |
+
from trauma.api.account.dto import AccountType
|
4 |
+
from trauma.api.account.model import AccountModel
|
5 |
from trauma.api.common.dto import Paging
|
6 |
from trauma.api.message import message_router
|
7 |
from trauma.api.message.ai.engine import search_entities
|
|
|
11 |
CreateMessageRequest,
|
12 |
CreateMessageResponse)
|
13 |
from trauma.api.message.utils import transform_messages_to_openai
|
14 |
+
from trauma.core.security import PermissionDependency
|
15 |
from trauma.core.wrappers import TraumaResponseWrapper
|
16 |
|
17 |
|
18 |
@message_router.get('/{chatId}/all')
|
19 |
async def get_all_chat_messages(
|
20 |
+
chatId: str,
|
21 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin]))
|
22 |
) -> AllMessageWrapper:
|
23 |
+
messages, _ = await get_all_chat_messages_obj(chatId, account)
|
24 |
response = AllMessageResponse(
|
25 |
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
26 |
data=messages
|
|
|
32 |
async def create_message(
|
33 |
chatId: str,
|
34 |
message_data: CreateMessageRequest,
|
35 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
|
36 |
) -> TraumaResponseWrapper[CreateMessageResponse]:
|
37 |
+
messages, chat = await get_all_chat_messages_obj(chatId, account)
|
38 |
message_history = transform_messages_to_openai(messages)
|
39 |
response = await search_entities(message_data.text, message_history, chat)
|
40 |
return TraumaResponseWrapper(data=response)
|
trauma/api/security/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter
|
2 |
+
|
3 |
+
security_router = APIRouter(
|
4 |
+
prefix='/api/security'
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
trauma/api/security/db_requests.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from fastapi import HTTPException
|
4 |
+
|
5 |
+
from trauma.api.account.model import AccountModel
|
6 |
+
from trauma.api.common.db_requests import check_unique_fields_existence
|
7 |
+
from trauma.api.security.schemas import RegisterAccountRequest, LoginAccountRequest
|
8 |
+
from trauma.core.config import settings
|
9 |
+
from trauma.core.security import verify_password
|
10 |
+
|
11 |
+
|
12 |
+
async def save_account(data: RegisterAccountRequest) -> AccountModel:
|
13 |
+
await check_unique_fields_existence("email", data.email)
|
14 |
+
account = AccountModel(
|
15 |
+
**data.model_dump()
|
16 |
+
)
|
17 |
+
await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
|
18 |
+
return account
|
19 |
+
|
20 |
+
|
21 |
+
async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
|
22 |
+
account = await settings.DB_CLIENT.accounts.find_one(
|
23 |
+
{"email": data.email},
|
24 |
+
collation={"locale": "en", "strength": 2})
|
25 |
+
if account is None:
|
26 |
+
raise HTTPException(status_code=404, detail="Invalid email or password.")
|
27 |
+
|
28 |
+
account = AccountModel.from_mongo(account)
|
29 |
+
|
30 |
+
if not verify_password(data.password, account.password):
|
31 |
+
raise HTTPException(status_code=401, detail="Invalid email or password.")
|
32 |
+
return account
|
trauma/api/security/schemas.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, EmailStr
|
2 |
+
|
3 |
+
from trauma.api.account.dto import AccessToken
|
4 |
+
from trauma.api.account.model import AccountModel
|
5 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class RegisterAccountRequest(BaseModel):
|
9 |
+
email: EmailStr
|
10 |
+
password: str
|
11 |
+
|
12 |
+
|
13 |
+
class RegisterAccountWrapper(TraumaResponseWrapper[AccountModel]):
|
14 |
+
pass
|
15 |
+
|
16 |
+
|
17 |
+
class LoginAccountRequest(BaseModel):
|
18 |
+
email: EmailStr
|
19 |
+
password: str
|
20 |
+
|
21 |
+
|
22 |
+
class LoginAccountResponse(BaseModel):
|
23 |
+
accessToken: AccessToken
|
24 |
+
account: AccountModel
|
25 |
+
|
26 |
+
|
27 |
+
class LoginAccountWrapper(TraumaResponseWrapper[LoginAccountResponse]):
|
28 |
+
pass
|
trauma/api/security/views.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trauma.api.account.dto import AccessToken
|
2 |
+
from trauma.api.security import security_router
|
3 |
+
from trauma.api.security.db_requests import authenticate_account, save_account
|
4 |
+
from trauma.api.security.schemas import (RegisterAccountRequest,
|
5 |
+
RegisterAccountWrapper,
|
6 |
+
LoginAccountResponse,
|
7 |
+
LoginAccountWrapper,
|
8 |
+
LoginAccountRequest)
|
9 |
+
from trauma.core.security import create_access_token
|
10 |
+
|
11 |
+
|
12 |
+
@security_router.post('/register')
|
13 |
+
async def register_user(data: RegisterAccountRequest) -> RegisterAccountWrapper:
|
14 |
+
account = await save_account(data)
|
15 |
+
return RegisterAccountWrapper(data=account)
|
16 |
+
|
17 |
+
|
18 |
+
@security_router.post('/login')
|
19 |
+
async def login(data: LoginAccountRequest) -> LoginAccountWrapper:
|
20 |
+
account = await authenticate_account(data)
|
21 |
+
access_token = create_access_token(account.email, str(account.id))
|
22 |
+
response = LoginAccountResponse(
|
23 |
+
accessToken=AccessToken(value=access_token),
|
24 |
+
account=account,
|
25 |
+
)
|
26 |
+
return LoginAccountWrapper(data=response)
|
trauma/core/security.py
CHANGED
@@ -6,6 +6,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
6 |
from jose import jwt, JWTError
|
7 |
from passlib.context import CryptContext
|
8 |
|
|
|
9 |
from trauma.api.account.model import AccountModel
|
10 |
from trauma.core.config import settings
|
11 |
|
@@ -29,36 +30,34 @@ def create_access_token(email: str, account_id: str):
|
|
29 |
|
30 |
|
31 |
class PermissionDependency:
|
32 |
-
def __init__(self,
|
33 |
-
self.
|
34 |
|
35 |
def __call__(
|
36 |
self, credentials: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False))
|
37 |
) -> AccountModel | None:
|
38 |
-
if credentials is None:
|
39 |
-
if self.is_public:
|
40 |
-
return None
|
41 |
-
else:
|
42 |
-
raise HTTPException(status_code=403, detail="Permission denied")
|
43 |
try:
|
44 |
account_id = self.authenticate_jwt_token(credentials.credentials)
|
45 |
account_data = anyio.from_thread.run(self.get_account_by_id, account_id)
|
46 |
self.check_account_health(account_data)
|
47 |
-
return account_data
|
48 |
|
49 |
except JWTError:
|
50 |
raise HTTPException(status_code=403, detail="Permission denied")
|
51 |
|
52 |
-
|
|
|
53 |
account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
|
54 |
-
return
|
55 |
|
56 |
-
|
57 |
-
def check_account_health(account: AccountModel):
|
58 |
if not account:
|
59 |
raise HTTPException(status_code=403, detail="Permission denied")
|
|
|
|
|
60 |
|
61 |
-
|
|
|
62 |
payload = jwt.decode(token,
|
63 |
settings.SECRET_KEY,
|
64 |
algorithms="HS256",
|
|
|
6 |
from jose import jwt, JWTError
|
7 |
from passlib.context import CryptContext
|
8 |
|
9 |
+
from trauma.api.account.dto import AccountType
|
10 |
from trauma.api.account.model import AccountModel
|
11 |
from trauma.core.config import settings
|
12 |
|
|
|
30 |
|
31 |
|
32 |
class PermissionDependency:
|
33 |
+
def __init__(self, account_types: list[AccountType]):
|
34 |
+
self.account_types = account_types
|
35 |
|
36 |
def __call__(
|
37 |
self, credentials: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False))
|
38 |
) -> AccountModel | None:
|
|
|
|
|
|
|
|
|
|
|
39 |
try:
|
40 |
account_id = self.authenticate_jwt_token(credentials.credentials)
|
41 |
account_data = anyio.from_thread.run(self.get_account_by_id, account_id)
|
42 |
self.check_account_health(account_data)
|
43 |
+
return AccountModel.from_mongo(account_data)
|
44 |
|
45 |
except JWTError:
|
46 |
raise HTTPException(status_code=403, detail="Permission denied")
|
47 |
|
48 |
+
@staticmethod
|
49 |
+
async def get_account_by_id(account_id: str) -> dict:
|
50 |
account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
|
51 |
+
return account
|
52 |
|
53 |
+
def check_account_health(self, account: dict):
|
|
|
54 |
if not account:
|
55 |
raise HTTPException(status_code=403, detail="Permission denied")
|
56 |
+
if account['accountType'] not in [i.value for i in self.account_types]:
|
57 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
58 |
|
59 |
+
@staticmethod
|
60 |
+
def authenticate_jwt_token(token: str) -> str:
|
61 |
payload = jwt.decode(token,
|
62 |
settings.SECRET_KEY,
|
63 |
algorithms="HS256",
|