import asyncio

from fastapi import HTTPException

from trauma.api.account.dto import AccountType
from trauma.api.account.model import AccountModel
from trauma.api.chat.model import ChatModel
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
from trauma.api.message.dto import Author
from trauma.api.message.model import MessageModel
from trauma.core.config import settings


async def get_chat_obj(chat_id: str, account: AccountModel) -> ChatModel:
    """Load a chat by its ``id`` field and check that ``account`` may access it.

    Access is granted to the account stored on the chat and to any account
    whose ``accountType`` is ``AccountType.Admin``.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        account: The account making the request.

    Returns:
        The chat converted via ``ChatModel.from_mongo``.

    Raises:
        HTTPException: 404 if no chat document has this ``id``;
            403 if ``account`` is neither the chat's account nor an admin.
    """
    raw_chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
    if not raw_chat:
        raise HTTPException(status_code=404, detail="Chat not found")
    chat = ChatModel.from_mongo(raw_chat)
    if chat.account.id != account.id and account.accountType != AccountType.Admin:
        raise HTTPException(status_code=403, detail="Not authorized")
    return chat


async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel) -> ChatModel:
    """Create a chat for ``account`` using the requested model and persist it.

    Args:
        chat_request: Request carrying the ``model`` for the new chat.
        account: The account the chat is attached to.

    Returns:
        The newly created (and inserted) ``ChatModel``.
    """
    chat = ChatModel(model=chat_request.model, account=account)
    await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
    return chat


async def delete_chat_obj(chat_id: str, account: AccountModel) -> None:
    """Delete a chat after verifying that ``account`` may access it.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        account: The account making the request.

    Raises:
        HTTPException: 404 / 403 exactly as raised by ``get_chat_obj``.
    """
    # Reuse the shared lookup so the existence and authorization rules are
    # defined in one place; the returned model itself is not needed here.
    await get_chat_obj(chat_id, account)
    await settings.DB_CLIENT.chats.delete_one({"id": chat_id})


async def update_chat_obj_title(
    chat_id: str, chat_request: ChatTitleRequest, account: AccountModel
) -> ChatModel:
    """Rename a chat after verifying that ``account`` may access it.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        chat_request: Request carrying the new ``title``.
        account: The account making the request.

    Returns:
        The updated ``ChatModel``.

    Raises:
        HTTPException: 404 / 403 exactly as raised by ``get_chat_obj``.
    """
    chat = await get_chat_obj(chat_id, account)
    chat.title = chat_request.title
    # NOTE(review): this $set writes back the whole serialized document, so a
    # concurrent write to any other field between the read above and this
    # update would be overwritten. Consider setting only the title field once
    # its stored key name is confirmed.
    await settings.DB_CLIENT.chats.update_one({"id": chat_id}, {"$set": chat.to_mongo()})
    return chat


async def get_all_chats_obj(page_size: int, page_index: int) -> tuple[list[dict], int]:
    """Return one page of chat documents together with the total chat count.

    Documents are sorted by ``_id`` descending (with default ObjectIds this is
    roughly newest-first — assumption, confirm against how ids are generated).

    Args:
        page_size: Maximum number of documents to return.
        page_index: Zero-based page number; ``page_size * page_index``
            documents are skipped.

    Returns:
        ``(documents, total_count)`` where ``documents`` are the raw Mongo
        documents as returned by the driver (not ``ChatModel`` instances) and
        ``total_count`` counts every document in the collection.
    """
    skip = page_size * page_index
    page_query = (
        settings.DB_CLIENT.chats
        .find({})
        .sort("_id", -1)
        .skip(skip)
        .limit(page_size)
        .to_list(length=page_size)
    )
    # Fetch the page and the total count concurrently.
    documents, total_count = await asyncio.gather(
        page_query,
        settings.DB_CLIENT.chats.count_documents({}),
    )
    return documents, total_count


async def save_intro_message(chat_id: str) -> None:
    """Insert ``settings.INTRO_MESSAGE`` as an assistant-authored message in a chat.

    Args:
        chat_id: Id of the chat the message belongs to.
    """
    intro = MessageModel(author=Author.Assistant, text=settings.INTRO_MESSAGE, chatId=chat_id)
    await settings.DB_CLIENT.messages.insert_one(intro.to_mongo())