brestok's picture
finish ai
3ec35ef
raw
history blame
3.54 kB
import asyncio
import numpy as np
from fastapi import HTTPException
from trauma.api.chat.model import ChatModel
from trauma.api.data.model import EntityModel
from trauma.api.message.ai.openai_request import convert_value_to_embeddings
from trauma.api.message.dto import Author
from trauma.api.message.model import MessageModel
from trauma.api.message.schemas import CreateMessageRequest
from trauma.core.config import settings
from trauma.core.wrappers import background_task
async def create_message_obj(
    chat_id: str, message_data: CreateMessageRequest
) -> tuple[MessageModel, ChatModel]:
    """Store a new user-authored message in an existing chat.

    The chat is looked up first; if it does not exist an HTTP 404 is raised
    and nothing is inserted. Otherwise a ``MessageModel`` is built from the
    request payload, tagged with ``chat_id`` and ``Author.User``, and inserted
    into the ``messages`` collection.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        message_data: Validated request body; its ``model_dump()`` supplies the
            remaining message fields.

    Returns:
        ``(message, chat)``. NOTE(review): ``chat`` is the raw document returned
        by ``find_one``, not passed through ``ChatModel.from_mongo`` despite the
        annotation — confirm what callers expect.

    Raises:
        HTTPException: 404 when no chat with ``chat_id`` exists.
    """
    existing_chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
    if not existing_chat:
        raise HTTPException(status_code=404, detail="Chat not found.")

    new_message = MessageModel(
        **message_data.model_dump(),
        chatId=chat_id,
        author=Author.User,
    )
    await settings.DB_CLIENT.messages.insert_one(new_message.to_mongo())
    return new_message, existing_chat
async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], ChatModel]:
    """Load a chat together with every message that belongs to it.

    The message query and the chat lookup are issued concurrently via
    ``asyncio.gather``.

    Args:
        chat_id: Value of the chat document's ``id`` field (and of each
            message's ``chatId`` field).

    Returns:
        ``(messages, chat)`` with both converted through their model's
        ``from_mongo``. Messages are returned in whatever order the database
        yields them; no sort is applied here.

    Raises:
        HTTPException: 404 when no chat with ``chat_id`` exists.
    """
    raw_messages, raw_chat = await asyncio.gather(
        settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
        settings.DB_CLIENT.chats.find_one({"id": chat_id}),
    )
    message_models = list(map(MessageModel.from_mongo, raw_messages))

    if not raw_chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    return message_models, ChatModel.from_mongo(raw_chat)
@background_task()
async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
    """Overwrite the ``entityData`` field of the chat whose ``id`` is *chat_id*.

    Uses a single ``$set`` update, so other fields of the chat document are
    left as they are. Wrapped by the project's ``background_task`` decorator.
    """
    selector = {"id": chat_id}
    changes = {"$set": {"entityData": entity_data}}
    await settings.DB_CLIENT.chats.update_one(selector, changes)
@background_task()
async def save_assistant_user_message(user_message: str, assistant_message: str, chat_id: str) -> None:
    """Persist one user/assistant exchange as two messages in *chat_id*.

    Both ``MessageModel`` objects are built up front; they are then inserted
    one at a time, the user message awaited before the assistant message is
    written. Wrapped by the project's ``background_task`` decorator.

    Args:
        user_message: Text of the user's turn.
        assistant_message: Text of the assistant's reply.
        chat_id: Chat the two messages belong to.
    """
    turns = (
        MessageModel(chatId=chat_id, author=Author.User, text=user_message),
        MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message),
    )
    for turn in turns:
        await settings.DB_CLIENT.messages.insert_one(turn.to_mongo())
async def filter_entities_by_age(entity: dict) -> list[int]:
    """Return the ``index`` of every stored entity whose age groups overlap *entity*'s range.

    An entity matches when at least one element of its ``ageGroups`` array
    satisfies ``ageMin <= entity["ageMax"]`` and ``ageMax >= entity["ageMin"]``,
    i.e. the two inclusive ranges intersect.

    Args:
        entity: Mapping with ``ageMin`` and ``ageMax`` keys describing the
            requested age range.

    Returns:
        The ``index`` values of all matching entity documents.

    Raises:
        KeyError: If *entity* lacks ``ageMin`` or ``ageMax``.
    """
    # $elemMatch forces both bounds to hold for the *same* age group rather than
    # letting one group satisfy the lower bound and another the upper bound.
    query = {
        "ageGroups": {
            "$elemMatch": {
                "ageMin": {"$lte": entity["ageMax"]},
                "ageMax": {"$gte": entity["ageMin"]},
            }
        }
    }
    # Only ``index`` is consumed below, so project the rest of each document away
    # instead of transferring whole entities from the database.
    projection = {"index": 1, "_id": 0}
    documents = await settings.DB_CLIENT.entities.find(query, projection).to_list(length=None)
    # Distinct loop name so the ``entity`` parameter is not shadowed.
    return [document["index"] for document in documents]
async def get_entity_by_index(index: int) -> EntityModel:
    """Fetch the entity document whose ``index`` field equals *index*.

    The document is converted with ``EntityModel.from_mongo``.
    NOTE(review): ``find_one`` yields ``None`` when nothing matches, and that
    ``None`` is passed straight to ``from_mongo``. Its handling of ``None`` is
    not visible here — confirm.
    """
    return EntityModel.from_mongo(
        await settings.DB_CLIENT.entities.find_one({"index": index})
    )
async def search_semantic_entities(search_request: str, entities_indexes: list[int]) -> list[EntityModel]:
embedding = await convert_value_to_embeddings(search_request)
query_embedding = np.array([embedding], dtype=np.float32)
distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)
distances = distances[0]
indices = indices[0]
filtered_results = [
{"index": int(idx), "distance": float(dist)}
for idx, dist in zip(indices, distances)
if idx in entities_indexes and dist <= 1.3
]
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
return final_entities