Spaces:
Running
Running
import asyncio | |
from fastapi import HTTPException | |
from trauma.api.account.dto import AccountType | |
from trauma.api.account.model import AccountModel | |
from trauma.api.chat.model import ChatModel | |
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest | |
from trauma.api.message.dto import Author | |
from trauma.api.message.model import MessageModel | |
from trauma.core.config import settings | |
async def get_chat_obj(chat_id: str, account: AccountModel) -> ChatModel:
    """Load a chat by its ``id`` field and verify the caller may access it.

    Args:
        chat_id: Value matched against the ``id`` field of the ``chats`` collection.
        account: The requesting account.

    Returns:
        The chat converted with ``ChatModel.from_mongo``.

    Raises:
        HTTPException: 404 if no chat document matches ``chat_id``; 403 if the
            chat belongs to another account and ``account`` is not an admin.
    """
    document = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
    if not document:
        raise HTTPException(status_code=404, detail="Chat not found")

    chat = ChatModel.from_mongo(document)
    # Owners always pass; admins may access any chat.
    if not (chat.account.id == account.id or account.accountType == AccountType.Admin):
        raise HTTPException(status_code=403, detail="Not authorized")
    return chat
async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel) -> ChatModel:
    """Create a chat owned by ``account``, persist it, and return it.

    Args:
        chat_request: Supplies the ``model`` value for the new chat.
        account: The account that will own the chat.

    Returns:
        The newly built ``ChatModel`` (the same instance that was inserted).
    """
    new_chat = ChatModel(model=chat_request.model, account=account)
    document = new_chat.to_mongo()
    await settings.DB_CLIENT.chats.insert_one(document)
    return new_chat
async def delete_chat_obj(chat_id: str, account: AccountModel) -> None:
    """Delete a chat after confirming the caller may access it.

    Args:
        chat_id: Value matched against the ``id`` field of the ``chats`` collection.
        account: The requesting account.

    Raises:
        HTTPException: 404 if the chat does not exist; 403 if it belongs to
            another account and ``account`` is not an admin (raised by
            ``get_chat_obj``).
    """
    # Reuse the shared lookup + authorization check rather than duplicating it.
    await get_chat_obj(chat_id, account)
    # NOTE(review): only the chat document is removed. Messages stored with this
    # chatId (e.g. by save_intro_message) are not deleted here; confirm whether
    # cleanup happens elsewhere.
    await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
async def update_chat_obj_title(chat_id: str, chat_request: ChatTitleRequest, account: AccountModel) -> ChatModel:
    """Rename a chat the caller is allowed to access and return the updated chat.

    Args:
        chat_id: Value matched against the ``id`` field of the ``chats`` collection.
        chat_request: Supplies the new ``title``.
        account: The requesting account.

    Returns:
        The chat model with its ``title`` replaced.

    Raises:
        HTTPException: 404 if the chat does not exist; 403 if it belongs to
            another account and ``account`` is not an admin (raised by
            ``get_chat_obj``).
    """
    # Shared lookup + authorization check (same rules as get/delete).
    chat = await get_chat_obj(chat_id, account)
    chat.title = chat_request.title
    # NOTE(review): this $set writes every field produced by to_mongo(), not just
    # the title, so a concurrent change to other fields could be overwritten.
    await settings.DB_CLIENT.chats.update_one({"id": chat_id}, {"$set": chat.to_mongo()})
    return chat
async def get_all_chats_obj(page_size: int, page_index: int) -> tuple[list[dict], int]:
    """Return one page of chat documents plus the total number of chats.

    The return annotation previously claimed ``list[ChatModel]``, but the page is
    the raw documents from ``to_list``. They are not converted with
    ``ChatModel.from_mongo``, so the annotation now describes the real value.

    Args:
        page_size: Maximum number of documents in the page.
        page_index: Zero-based page number (page 0 skips nothing).

    Returns:
        ``(documents, total_count)``. ``documents`` are sorted by ``_id``
        descending. ``total_count`` counts every document in ``chats``.
    """
    # NOTE(review): no account filter is applied; presumably callers restrict
    # this to admins. Verify at the route level.
    skip = page_size * page_index
    # Fetch the page and the total count concurrently.
    objects, total_count = await asyncio.gather(
        settings.DB_CLIENT.chats
        .find({})
        .sort("_id", -1)
        .skip(skip)
        .limit(page_size)
        .to_list(length=page_size),
        settings.DB_CLIENT.chats.count_documents({}),
    )
    return objects, total_count
async def save_intro_message(chat_id: str) -> None:
    """Insert the configured intro text as an assistant message in ``chat_id``.

    Args:
        chat_id: Stored as the message's ``chatId``.
    """
    intro_text = settings.INTRO_MESSAGE
    intro = MessageModel(
        author=Author.Assistant,
        text=intro_text,
        chatId=chat_id,
    )
    await settings.DB_CLIENT.messages.insert_one(intro.to_mongo())