Spaces:
Running
Running
import asyncio | |
from fastapi import HTTPException | |
from trauma.api.account.model import AccountModel | |
from trauma.api.chat.model import ChatModel | |
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest | |
from trauma.core.config import settings | |
async def get_chat_obj(chat_id: str, account: AccountModel | None) -> ChatModel:
    """Load a chat by its ``id`` field and check that ``account`` may access it.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        account: The requesting account, or ``None`` to skip the ownership check.

    Returns:
        The chat document parsed with ``ChatModel.from_mongo``.

    Raises:
        HTTPException: 404 if no chat has this id; 403 if ``account`` is given
            and the chat has no owner or is owned by a different account.
    """
    document = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
    if not document:
        raise HTTPException(status_code=404, detail="Chat not found")
    chat = ChatModel.from_mongo(document)
    # Compare owners by id instead of by whole-model equality. The account is
    # embedded into the chat document when the chat is created, so the stored
    # copy can drift from the current account. If AccountModel compares all
    # fields (TODO confirm), that drift would make an owner's own chat return
    # 403. Matching on id also agrees with get_all_chats_obj, which lists chats
    # by "account.id". Chats without an owner still return 403 for a logged-in
    # account, as before.
    if account and (chat.account is None or chat.account.id != account.id):
        raise HTTPException(status_code=403, detail="Chat account not match")
    return chat
async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel | None) -> ChatModel:
    """Create a new chat for ``account`` and insert it into the ``chats`` collection.

    Args:
        chat_request: Request payload; only its ``model`` attribute is used.
        account: Owner to store on the chat, or ``None``.

    Returns:
        The newly built ``ChatModel`` (the same instance that was serialized
        and inserted).
    """
    new_chat = ChatModel(model=chat_request.model, account=account)
    mongo_document = new_chat.to_mongo()
    await settings.DB_CLIENT.chats.insert_one(mongo_document)
    return new_chat
async def delete_chat_obj(chat_id: str, account: AccountModel | None) -> None:
    """Delete the chat whose ``id`` field equals ``chat_id`` after checking access.

    The lookup and ownership check are delegated to ``get_chat_obj`` so that
    every chat operation enforces the same access rule instead of a copy of it.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        account: The requesting account, or ``None`` to skip the ownership check.

    Raises:
        HTTPException: 404 if the chat does not exist; 403 if ``account`` is
            given and is not allowed to access the chat (both from ``get_chat_obj``).
    """
    # Only the side effect (raising on missing or foreign chats) is needed here.
    await get_chat_obj(chat_id, account)
    await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest, account: AccountModel | None) -> ChatModel:
    """Set a new title on a chat and persist it.

    The lookup and ownership check are delegated to ``get_chat_obj`` so that
    every chat operation enforces the same access rule instead of a copy of it.

    Args:
        chatId: Value of the chat document's ``id`` field. The camelCase name
            is kept for callers that pass it by keyword.
        chat_request: Request payload; only its ``title`` attribute is used.
        account: The requesting account, or ``None`` to skip the ownership check.

    Returns:
        The chat model with the updated title.

    Raises:
        HTTPException: 404 if the chat does not exist; 403 if ``account`` is
            given and is not allowed to access the chat (both from ``get_chat_obj``).
    """
    chat = await get_chat_obj(chatId, account)
    chat.title = chat_request.title
    # NOTE(review): this $set writes every field of the serialized model, not
    # just the title. Any field changed by another request between the read
    # above and this write is overwritten with the value read here.
    await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
    return chat
async def get_all_chats_obj(page_size: int, page_index: int, account: AccountModel) -> tuple[list[dict], int]:
    """Return one page of ``account``'s chat documents plus the total chat count.

    Args:
        page_size: Maximum number of chats on the page.
        page_index: Zero-based page number; the first ``page_size * page_index``
            matching chats are skipped.
        account: Owner whose chats are listed, matched on ``account.id``.

    Returns:
        ``(documents, total_count)``. ``documents`` are the raw Mongo documents
        for the page, not ``ChatModel`` instances; callers must convert them
        (e.g. with ``ChatModel.from_mongo``). They are sorted by ``_id``
        descending, which is newest first when ``_id`` is an auto-generated
        ObjectId. ``total_count`` counts all of the account's chats regardless
        of paging.
    """
    owner_filter = {"account.id": account.id}
    offset = page_size * page_index
    page_cursor = (
        settings.DB_CLIENT.chats
        .find(owner_filter)
        .sort("_id", -1)
        .skip(offset)
        .limit(page_size)
    )
    # Fetch the page and the total count concurrently.
    documents, total_count = await asyncio.gather(
        page_cursor.to_list(length=page_size),
        settings.DB_CLIENT.chats.count_documents(owner_filter),
    )
    return documents, total_count