brestok commited on
Commit
c6cc0f2
·
1 Parent(s): b695fbe
trauma/__init__.py CHANGED
@@ -12,12 +12,17 @@ from trauma.core.wrappers import TraumaResponseWrapper, ErrorTraumaResponse
12
  def create_app() -> FastAPI:
13
  app = FastAPI()
14
 
 
 
 
15
  from trauma.api.chat import chat_router
16
  app.include_router(chat_router, tags=['chat'])
17
 
18
  from trauma.api.message import message_router
19
  app.include_router(message_router, tags=['message'])
20
 
 
 
21
 
22
  app.add_middleware(
23
  CORSMiddleware,
 
12
  def create_app() -> FastAPI:
13
  app = FastAPI()
14
 
15
+ from trauma.api.account import account_router
16
+ app.include_router(account_router, tags=["account"])
17
+
18
  from trauma.api.chat import chat_router
19
  app.include_router(chat_router, tags=['chat'])
20
 
21
  from trauma.api.message import message_router
22
  app.include_router(message_router, tags=['message'])
23
 
24
+ from trauma.api.security import security_router
25
+ app.include_router(security_router, tags=['security'])
26
 
27
  app.add_middleware(
28
  CORSMiddleware,
trauma/api/account/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+
3
+ account_router = APIRouter(
4
+ prefix='/api/account'
5
+ )
6
+
7
+ from . import views
trauma/api/account/db_requests.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from trauma.api.account.model import AccountModel
4
+ from trauma.core.config import settings
5
+
6
+
7
+ async def get_all_model_obj(page_size: int, page_index: int) -> tuple[list[AccountModel], int]:
8
+ skip = page_size * page_index
9
+ objects, total_count = await asyncio.gather(
10
+ settings.DB_CLIENT.accounts
11
+ .find()
12
+ .skip(skip)
13
+ .limit(page_size)
14
+ .to_list(length=page_size),
15
+ settings.DB_CLIENT.accounts.count_documents({})
16
+ )
17
+ return objects, total_count
trauma/api/account/dto.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class AccessToken(BaseModel):
7
+ type: str = "Bearer"
8
+ value: str
9
+
10
+
11
+ class AccountType(Enum):
12
+ User = 1
13
+ Admin = 2
trauma/api/account/model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from passlib.context import CryptContext
4
+ from pydantic import field_validator, Field, EmailStr
5
+
6
+ from trauma.api.account.dto import AccountType
7
+ from trauma.core.database import MongoBaseModel
8
+
9
+
10
+ class AccountModel(MongoBaseModel):
11
+ email: EmailStr
12
+ password: str | None = Field(exclude=True, default=None)
13
+ accountType: AccountType = AccountType.User
14
+
15
+ datetimeInserted: datetime = Field(default_factory=datetime.now)
16
+ datetimeUpdated: datetime = Field(default_factory=datetime.now)
17
+
18
+ @field_validator("datetimeUpdated", mode="before", check_fields=False)
19
+ @classmethod
20
+ def validate_datetimeUpdated(cls, v):
21
+ return v or datetime.now()
22
+
23
+ @field_validator('password', mode='before', check_fields=False)
24
+ @classmethod
25
+ def set_password_hash(cls, v):
26
+ if not v.startswith("$2b$"):
27
+ return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v)
28
+ return v
29
+
30
+ class Config:
31
+ arbitrary_types_allowed = True
32
+ json_encoders = {
33
+ datetime: lambda dt: dt.isoformat()
34
+ }
trauma/api/account/schemas.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ from trauma.api.account.model import AccountModel
4
+ from trauma.api.common.dto import Paging
5
+ from trauma.core.wrappers import TraumaResponseWrapper
6
+
7
+
8
+ class AccountWrapper(TraumaResponseWrapper[AccountModel]):
9
+ pass
10
+
11
+
12
+ class AllAccountsResponse(BaseModel):
13
+ paging: Paging
14
+ data: list[AccountModel]
trauma/api/account/views.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from fastapi import Depends, Query
4
+
5
+ from trauma.api.account import account_router
6
+ from trauma.api.account.db_requests import get_all_model_obj
7
+ from trauma.api.account.dto import AccountType
8
+ from trauma.api.account.model import AccountModel
9
+ from trauma.api.account.schemas import AccountWrapper, AllAccountsResponse
10
+ from trauma.api.common.dto import Paging
11
+ from trauma.core.security import PermissionDependency
12
+ from trauma.core.wrappers import TraumaResponseWrapper
13
+
14
+
15
+ @account_router.get('/all')
16
+ async def get_all_accounts(
17
+ pageSize: Optional[int] = Query(10, description="Number of countries to return per page"),
18
+ pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
19
+ _: AccountModel = Depends(PermissionDependency([AccountType.Admin]))
20
+ ) -> TraumaResponseWrapper[AllAccountsResponse]:
21
+ countries, total_count = await get_all_model_obj(pageSize, pageIndex)
22
+ response = AllAccountsResponse(
23
+ paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
24
+ data=countries
25
+ )
26
+ return TraumaResponseWrapper(data=response)
27
+
28
+
29
+ @account_router.get('')
30
+ async def get_account(
31
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
32
+ ) -> AccountWrapper:
33
+ return AccountWrapper(data=account)
trauma/api/chat/db_requests.py CHANGED
@@ -2,6 +2,8 @@ import asyncio
2
 
3
  from fastapi import HTTPException
4
 
 
 
5
  from trauma.api.chat.model import ChatModel
6
  from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
7
  from trauma.api.message.dto import Author
@@ -9,37 +11,43 @@ from trauma.api.message.model import MessageModel
9
  from trauma.core.config import settings
10
 
11
 
12
- async def get_chat_obj(chat_id: str) -> ChatModel:
13
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
14
  if not chat:
15
  raise HTTPException(status_code=404, detail="Chat not found")
16
  chat = ChatModel.from_mongo(chat)
 
 
17
  return chat
18
 
19
 
20
- async def create_chat_obj(chat_request: CreateChatRequest) -> ChatModel:
21
- chat = ChatModel(model=chat_request.model)
22
  await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
23
  return chat
24
 
25
 
26
- async def delete_chat_obj(chat_id: str) -> None:
27
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
28
  if not chat:
29
  raise HTTPException(status_code=404, detail="Chat not found")
30
  chat = ChatModel.from_mongo(chat)
 
 
31
  await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
32
 
33
 
34
- async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest) -> ChatModel:
35
- chat = await settings.DB_CLIENT.chats.find_one({"id": chatId})
36
  if not chat:
37
  raise HTTPException(status_code=404, detail="Chat not found")
38
 
39
  chat = ChatModel.from_mongo(chat)
 
 
40
 
41
  chat.title = chat_request.title
42
- await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
43
  return chat
44
 
45
 
 
2
 
3
  from fastapi import HTTPException
4
 
5
+ from trauma.api.account.dto import AccountType
6
+ from trauma.api.account.model import AccountModel
7
  from trauma.api.chat.model import ChatModel
8
  from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
9
  from trauma.api.message.dto import Author
 
11
  from trauma.core.config import settings
12
 
13
 
14
+ async def get_chat_obj(chat_id: str, account: AccountModel) -> ChatModel:
15
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
16
  if not chat:
17
  raise HTTPException(status_code=404, detail="Chat not found")
18
  chat = ChatModel.from_mongo(chat)
19
+ if chat.account.id != account.id and account.accountType != AccountType.Admin:
20
+ raise HTTPException(status_code=403, detail="Not authorized")
21
  return chat
22
 
23
 
24
+ async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel) -> ChatModel:
25
+ chat = ChatModel(model=chat_request.model, account=account)
26
  await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
27
  return chat
28
 
29
 
30
+ async def delete_chat_obj(chat_id: str, account: AccountModel) -> None:
31
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
32
  if not chat:
33
  raise HTTPException(status_code=404, detail="Chat not found")
34
  chat = ChatModel.from_mongo(chat)
35
+ if chat.account.id != account.id and account.accountType != AccountType.Admin:
36
+ raise HTTPException(status_code=403, detail="Not authorized")
37
  await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
38
 
39
 
40
+ async def update_chat_obj_title(chat_id: str, chat_request: ChatTitleRequest, account: AccountModel) -> ChatModel:
41
+ chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
42
  if not chat:
43
  raise HTTPException(status_code=404, detail="Chat not found")
44
 
45
  chat = ChatModel.from_mongo(chat)
46
+ if chat.account.id != account.id and account.accountType != AccountType.Admin:
47
+ raise HTTPException(status_code=403, detail="Not authorized")
48
 
49
  chat.title = chat_request.title
50
+ await settings.DB_CLIENT.chats.update_one({"id": chat_id}, {"$set": chat.to_mongo()})
51
  return chat
52
 
53
 
trauma/api/chat/model.py CHANGED
@@ -2,6 +2,7 @@ from datetime import datetime
2
 
3
  from pydantic import Field
4
 
 
5
  from trauma.api.chat.dto import ModelType, EntityData
6
  from trauma.core.database import MongoBaseModel
7
 
@@ -10,5 +11,6 @@ class ChatModel(MongoBaseModel):
10
  title: str = 'New Chat'
11
  model: ModelType
12
  entityData: EntityData = EntityData()
 
13
  datetimeInserted: datetime = Field(default_factory=datetime.now)
14
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
 
2
 
3
  from pydantic import Field
4
 
5
+ from trauma.api.account.model import AccountModel
6
  from trauma.api.chat.dto import ModelType, EntityData
7
  from trauma.core.database import MongoBaseModel
8
 
 
11
  title: str = 'New Chat'
12
  model: ModelType
13
  entityData: EntityData = EntityData()
14
+ account: AccountModel
15
  datetimeInserted: datetime = Field(default_factory=datetime.now)
16
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
trauma/api/chat/views.py CHANGED
@@ -1,7 +1,9 @@
1
  from typing import Optional
2
 
3
- from fastapi import Query
4
 
 
 
5
  from trauma.api.chat import chat_router
6
  from trauma.api.chat.db_requests import (get_chat_obj,
7
  create_chat_obj,
@@ -10,6 +12,7 @@ from trauma.api.chat.db_requests import (get_chat_obj,
10
  get_all_chats_obj, save_intro_message)
11
  from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
12
  from trauma.api.common.dto import Paging
 
13
  from trauma.core.wrappers import TraumaResponseWrapper
14
 
15
 
@@ -17,6 +20,7 @@ from trauma.core.wrappers import TraumaResponseWrapper
17
  async def get_all_chats(
18
  pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
19
  pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
 
20
  ) -> AllChatWrapper:
21
  chats, total_count = await get_all_chats_obj(pageSize, pageIndex)
22
  response = AllChatResponse(
@@ -28,30 +32,36 @@ async def get_all_chats(
28
 
29
  @chat_router.get('/{chatId}')
30
  async def get_chat(
31
- chatId: str
 
32
  ) -> ChatWrapper:
33
- chat = await get_chat_obj(chatId)
34
  return ChatWrapper(data=chat)
35
 
36
 
37
  @chat_router.delete('/{chatId}')
38
- async def delete_chat(chatId: str) -> TraumaResponseWrapper:
39
- await delete_chat_obj(chatId)
 
 
 
40
  return TraumaResponseWrapper()
41
 
42
 
43
  @chat_router.patch('/{chatId}/title')
44
  async def update_chat_title(
45
- chatId: str, chat: ChatTitleRequest
 
46
  ) -> ChatWrapper:
47
- chat = await update_chat_obj_title(chatId, chat)
48
  return ChatWrapper(data=chat)
49
 
50
 
51
  @chat_router.post('')
52
  async def create_chat(
53
- chat_data: CreateChatRequest
 
54
  ) -> ChatWrapper:
55
- chat = await create_chat_obj(chat_data)
56
  await save_intro_message(chat.id)
57
  return ChatWrapper(data=chat)
 
1
  from typing import Optional
2
 
3
+ from fastapi import Query, Depends
4
 
5
+ from trauma.api.account.dto import AccountType
6
+ from trauma.api.account.model import AccountModel
7
  from trauma.api.chat import chat_router
8
  from trauma.api.chat.db_requests import (get_chat_obj,
9
  create_chat_obj,
 
12
  get_all_chats_obj, save_intro_message)
13
  from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
14
  from trauma.api.common.dto import Paging
15
+ from trauma.core.security import PermissionDependency
16
  from trauma.core.wrappers import TraumaResponseWrapper
17
 
18
 
 
20
  async def get_all_chats(
21
  pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
22
  pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
23
+ _: AccountModel = Depends(PermissionDependency([AccountType.Admin]))
24
  ) -> AllChatWrapper:
25
  chats, total_count = await get_all_chats_obj(pageSize, pageIndex)
26
  response = AllChatResponse(
 
32
 
33
  @chat_router.get('/{chatId}')
34
  async def get_chat(
35
+ chatId: str,
36
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
37
  ) -> ChatWrapper:
38
+ chat = await get_chat_obj(chatId, account)
39
  return ChatWrapper(data=chat)
40
 
41
 
42
  @chat_router.delete('/{chatId}')
43
+ async def delete_chat(
44
+ chatId: str,
45
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
46
+ ) -> TraumaResponseWrapper:
47
+ await delete_chat_obj(chatId, account)
48
  return TraumaResponseWrapper()
49
 
50
 
51
  @chat_router.patch('/{chatId}/title')
52
  async def update_chat_title(
53
+ chatId: str, chat: ChatTitleRequest,
54
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
55
  ) -> ChatWrapper:
56
+ chat = await update_chat_obj_title(chatId, chat, account)
57
  return ChatWrapper(data=chat)
58
 
59
 
60
  @chat_router.post('')
61
  async def create_chat(
62
+ chat_data: CreateChatRequest,
63
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
64
  ) -> ChatWrapper:
65
+ chat = await create_chat_obj(chat_data, account)
66
  await save_intro_message(chat.id)
67
  return ChatWrapper(data=chat)
trauma/api/common/db_requests.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException
2
+ from pydantic import EmailStr
3
+
4
+ from trauma.core.config import settings
5
+
6
+
7
+ async def check_unique_fields_existence(name: str,
8
+ new_value: EmailStr | str,
9
+ current_value: str | None = None) -> None:
10
+ if new_value == current_value or not new_value:
11
+ return
12
+ account = await settings.DB_CLIENT.accounts.find_one(
13
+ {name: str(new_value)},
14
+ collation={"locale": "en", "strength": 2}
15
+ )
16
+ if account:
17
+ raise HTTPException(status_code=400, detail=f'Account with specified {name} already exists.')
trauma/api/message/db_requests.py CHANGED
@@ -2,6 +2,8 @@ import asyncio
2
 
3
  from fastapi import HTTPException
4
 
 
 
5
  from trauma.api.chat.model import ChatModel
6
  from trauma.api.data.model import EntityModel
7
  from trauma.api.message.dto import Author
@@ -24,7 +26,7 @@ async def create_message_obj(
24
  return message, chat
25
 
26
 
27
- async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], ChatModel]:
28
  messages, chat = await asyncio.gather(
29
  settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
30
  settings.DB_CLIENT.chats.find_one({"id": chat_id})
@@ -35,6 +37,8 @@ async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], C
35
  raise HTTPException(status_code=404, detail="Chat not found")
36
 
37
  chat = ChatModel.from_mongo(chat)
 
 
38
  return messages, chat
39
 
40
 
 
2
 
3
  from fastapi import HTTPException
4
 
5
+ from trauma.api.account.dto import AccountType
6
+ from trauma.api.account.model import AccountModel
7
  from trauma.api.chat.model import ChatModel
8
  from trauma.api.data.model import EntityModel
9
  from trauma.api.message.dto import Author
 
26
  return message, chat
27
 
28
 
29
+ async def get_all_chat_messages_obj(chat_id: str, account: AccountModel) -> tuple[list[MessageModel], ChatModel]:
30
  messages, chat = await asyncio.gather(
31
  settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
32
  settings.DB_CLIENT.chats.find_one({"id": chat_id})
 
37
  raise HTTPException(status_code=404, detail="Chat not found")
38
 
39
  chat = ChatModel.from_mongo(chat)
40
+ if chat.account.id != account.id and account.accountType != AccountType.Admin:
41
+ raise HTTPException(status_code=404, detail="Not Authorized.")
42
  return messages, chat
43
 
44
 
trauma/api/message/views.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  from trauma.api.common.dto import Paging
2
  from trauma.api.message import message_router
3
  from trauma.api.message.ai.engine import search_entities
@@ -7,14 +11,16 @@ from trauma.api.message.schemas import (AllMessageWrapper,
7
  CreateMessageRequest,
8
  CreateMessageResponse)
9
  from trauma.api.message.utils import transform_messages_to_openai
 
10
  from trauma.core.wrappers import TraumaResponseWrapper
11
 
12
 
13
  @message_router.get('/{chatId}/all')
14
  async def get_all_chat_messages(
15
- chatId: str
 
16
  ) -> AllMessageWrapper:
17
- messages, _ = await get_all_chat_messages_obj(chatId)
18
  response = AllMessageResponse(
19
  paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
20
  data=messages
@@ -26,8 +32,9 @@ async def get_all_chat_messages(
26
  async def create_message(
27
  chatId: str,
28
  message_data: CreateMessageRequest,
 
29
  ) -> TraumaResponseWrapper[CreateMessageResponse]:
30
- messages, chat = await get_all_chat_messages_obj(chatId)
31
  message_history = transform_messages_to_openai(messages)
32
  response = await search_entities(message_data.text, message_history, chat)
33
  return TraumaResponseWrapper(data=response)
 
1
+ from fastapi import Depends
2
+
3
+ from trauma.api.account.dto import AccountType
4
+ from trauma.api.account.model import AccountModel
5
  from trauma.api.common.dto import Paging
6
  from trauma.api.message import message_router
7
  from trauma.api.message.ai.engine import search_entities
 
11
  CreateMessageRequest,
12
  CreateMessageResponse)
13
  from trauma.api.message.utils import transform_messages_to_openai
14
+ from trauma.core.security import PermissionDependency
15
  from trauma.core.wrappers import TraumaResponseWrapper
16
 
17
 
18
  @message_router.get('/{chatId}/all')
19
  async def get_all_chat_messages(
20
+ chatId: str,
21
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin]))
22
  ) -> AllMessageWrapper:
23
+ messages, _ = await get_all_chat_messages_obj(chatId, account)
24
  response = AllMessageResponse(
25
  paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
26
  data=messages
 
32
  async def create_message(
33
  chatId: str,
34
  message_data: CreateMessageRequest,
35
+ account: AccountModel = Depends(PermissionDependency([AccountType.Admin, AccountType.User]))
36
  ) -> TraumaResponseWrapper[CreateMessageResponse]:
37
+ messages, chat = await get_all_chat_messages_obj(chatId, account)
38
  message_history = transform_messages_to_openai(messages)
39
  response = await search_entities(message_data.text, message_history, chat)
40
  return TraumaResponseWrapper(data=response)
trauma/api/security/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+
3
+ security_router = APIRouter(
4
+ prefix='/api/security'
5
+ )
6
+
7
+ from . import views
trauma/api/security/db_requests.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from fastapi import HTTPException
4
+
5
+ from trauma.api.account.model import AccountModel
6
+ from trauma.api.common.db_requests import check_unique_fields_existence
7
+ from trauma.api.security.schemas import RegisterAccountRequest, LoginAccountRequest
8
+ from trauma.core.config import settings
9
+ from trauma.core.security import verify_password
10
+
11
+
12
+ async def save_account(data: RegisterAccountRequest) -> AccountModel:
13
+ await check_unique_fields_existence("email", data.email)
14
+ account = AccountModel(
15
+ **data.model_dump()
16
+ )
17
+ await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
18
+ return account
19
+
20
+
21
+ async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
22
+ account = await settings.DB_CLIENT.accounts.find_one(
23
+ {"email": data.email},
24
+ collation={"locale": "en", "strength": 2})
25
+ if account is None:
26
+ raise HTTPException(status_code=404, detail="Invalid email or password.")
27
+
28
+ account = AccountModel.from_mongo(account)
29
+
30
+ if not verify_password(data.password, account.password):
31
+ raise HTTPException(status_code=401, detail="Invalid email or password.")
32
+ return account
trauma/api/security/schemas.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr
2
+
3
+ from trauma.api.account.dto import AccessToken
4
+ from trauma.api.account.model import AccountModel
5
+ from trauma.core.wrappers import TraumaResponseWrapper
6
+
7
+
8
+ class RegisterAccountRequest(BaseModel):
9
+ email: EmailStr
10
+ password: str
11
+
12
+
13
+ class RegisterAccountWrapper(TraumaResponseWrapper[AccountModel]):
14
+ pass
15
+
16
+
17
+ class LoginAccountRequest(BaseModel):
18
+ email: EmailStr
19
+ password: str
20
+
21
+
22
+ class LoginAccountResponse(BaseModel):
23
+ accessToken: AccessToken
24
+ account: AccountModel
25
+
26
+
27
+ class LoginAccountWrapper(TraumaResponseWrapper[LoginAccountResponse]):
28
+ pass
trauma/api/security/views.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from trauma.api.account.dto import AccessToken
2
+ from trauma.api.security import security_router
3
+ from trauma.api.security.db_requests import authenticate_account, save_account
4
+ from trauma.api.security.schemas import (RegisterAccountRequest,
5
+ RegisterAccountWrapper,
6
+ LoginAccountResponse,
7
+ LoginAccountWrapper,
8
+ LoginAccountRequest)
9
+ from trauma.core.security import create_access_token
10
+
11
+
12
+ @security_router.post('/register')
13
+ async def register_user(data: RegisterAccountRequest) -> RegisterAccountWrapper:
14
+ account = await save_account(data)
15
+ return RegisterAccountWrapper(data=account)
16
+
17
+
18
+ @security_router.post('/login')
19
+ async def login(data: LoginAccountRequest) -> LoginAccountWrapper:
20
+ account = await authenticate_account(data)
21
+ access_token = create_access_token(account.email, str(account.id))
22
+ response = LoginAccountResponse(
23
+ accessToken=AccessToken(value=access_token),
24
+ account=account,
25
+ )
26
+ return LoginAccountWrapper(data=response)
trauma/core/security.py CHANGED
@@ -6,6 +6,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
6
  from jose import jwt, JWTError
7
  from passlib.context import CryptContext
8
 
 
9
  from trauma.api.account.model import AccountModel
10
  from trauma.core.config import settings
11
 
@@ -29,36 +30,34 @@ def create_access_token(email: str, account_id: str):
29
 
30
 
31
  class PermissionDependency:
32
- def __init__(self, is_public: bool = False):
33
- self.is_public = is_public
34
 
35
  def __call__(
36
  self, credentials: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False))
37
  ) -> AccountModel | None:
38
- if credentials is None:
39
- if self.is_public:
40
- return None
41
- else:
42
- raise HTTPException(status_code=403, detail="Permission denied")
43
  try:
44
  account_id = self.authenticate_jwt_token(credentials.credentials)
45
  account_data = anyio.from_thread.run(self.get_account_by_id, account_id)
46
  self.check_account_health(account_data)
47
- return account_data
48
 
49
  except JWTError:
50
  raise HTTPException(status_code=403, detail="Permission denied")
51
 
52
- async def get_account_by_id(self, account_id: str) -> AccountModel:
 
53
  account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
54
- return AccountModel.from_mongo(account)
55
 
56
- @staticmethod
57
- def check_account_health(account: AccountModel):
58
  if not account:
59
  raise HTTPException(status_code=403, detail="Permission denied")
 
 
60
 
61
- def authenticate_jwt_token(self, token: str) -> str:
 
62
  payload = jwt.decode(token,
63
  settings.SECRET_KEY,
64
  algorithms="HS256",
 
6
  from jose import jwt, JWTError
7
  from passlib.context import CryptContext
8
 
9
+ from trauma.api.account.dto import AccountType
10
  from trauma.api.account.model import AccountModel
11
  from trauma.core.config import settings
12
 
 
30
 
31
 
32
  class PermissionDependency:
33
+ def __init__(self, account_types: list[AccountType]):
34
+ self.account_types = account_types
35
 
36
  def __call__(
37
  self, credentials: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False))
38
  ) -> AccountModel | None:
 
 
 
 
 
39
  try:
40
  account_id = self.authenticate_jwt_token(credentials.credentials)
41
  account_data = anyio.from_thread.run(self.get_account_by_id, account_id)
42
  self.check_account_health(account_data)
43
+ return AccountModel.from_mongo(account_data)
44
 
45
  except JWTError:
46
  raise HTTPException(status_code=403, detail="Permission denied")
47
 
48
+ @staticmethod
49
+ async def get_account_by_id(account_id: str) -> dict:
50
  account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
51
+ return account
52
 
53
+ def check_account_health(self, account: dict):
 
54
  if not account:
55
  raise HTTPException(status_code=403, detail="Permission denied")
56
+ if account['accountType'] not in [i.value for i in self.account_types]:
57
+ raise HTTPException(status_code=403, detail="Permission denied")
58
 
59
+ @staticmethod
60
+ def authenticate_jwt_token(token: str) -> str:
61
  payload = jwt.decode(token,
62
  settings.SECRET_KEY,
63
  algorithms="HS256",