Spaces:
Running
Running
import asyncio | |
import numpy as np | |
from fastapi import HTTPException | |
from trauma.api.chat.model import ChatModel | |
from trauma.api.data.model import EntityModel | |
from trauma.api.message.ai.openai_request import convert_value_to_embeddings | |
from trauma.api.message.dto import Author | |
from trauma.api.message.model import MessageModel | |
from trauma.api.message.schemas import CreateMessageRequest | |
from trauma.core.config import settings | |
from trauma.core.wrappers import background_task | |
async def create_message_obj(
    chat_id: str, message_data: CreateMessageRequest
) -> tuple[MessageModel, dict]:
    """Store a new user-authored message in the chat identified by *chat_id*.

    Args:
        chat_id: Value of the chat document's ``id`` field.
        message_data: Request payload. Its ``model_dump()`` fields are passed to
            ``MessageModel`` together with ``chatId=chat_id`` and
            ``author=Author.User``.

    Returns:
        The inserted ``MessageModel`` and the chat exactly as returned by
        ``find_one``. NOTE: the chat is the raw Mongo document (a dict). It is
        not converted with ``ChatModel.from_mongo``; the annotation reflects
        what callers actually receive.

    Raises:
        HTTPException: 404 if no chat has ``id == chat_id``. Nothing is
            inserted in that case, because the check happens first.
    """
    chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found.")
    message = MessageModel(**message_data.model_dump(), chatId=chat_id, author=Author.User)
    await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
    return message, chat
async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], ChatModel]:
    """Load a chat together with every message stored under it.

    The message query and the chat lookup are awaited concurrently via
    ``asyncio.gather``. Messages are returned in whatever order the ``find``
    cursor yields them; no sort is requested.

    Args:
        chat_id: Value of the chat's ``id`` field, also matched against each
            message's ``chatId``.

    Returns:
        The parsed messages and the parsed chat.

    Raises:
        HTTPException: 404 when no chat has ``id == chat_id``.
    """
    db = settings.DB_CLIENT
    message_docs, chat_doc = await asyncio.gather(
        db.messages.find({"chatId": chat_id}).to_list(length=None),
        db.chats.find_one({"id": chat_id}),
    )
    history = list(map(MessageModel.from_mongo, message_docs))
    if not chat_doc:
        raise HTTPException(status_code=404, detail="Chat not found")
    return history, ChatModel.from_mongo(chat_doc)
async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
    """Overwrite the ``entityData`` field of the chat whose ``id`` is *chat_id*.

    Uses ``update_one`` with ``$set`` and no upsert. If no chat matches,
    nothing is written and no error is raised.
    """
    selector = {"id": chat_id}
    changes = {"$set": {"entityData": entity_data}}
    await settings.DB_CLIENT.chats.update_one(selector, changes)
async def save_assistant_user_message(user_message: str, assistant_message: str, chat_id: str) -> None:
    """Persist a user turn followed by the assistant's reply for *chat_id*.

    Both messages are written in a single ``insert_many`` round-trip.
    ``ordered`` defaults to ``True``, so the user message is inserted before
    the assistant message, the same order as two sequential ``insert_one``
    calls.

    Args:
        user_message: Text of the user's message (stored with ``Author.User``).
        assistant_message: Text of the reply (stored with ``Author.Assistant``).
        chat_id: Chat the two messages belong to (stored as ``chatId``).

    NOTE(review): ``create_message_obj`` also inserts a ``Author.User`` message.
    Confirm that the two are not used for the same turn, or the user text will
    be stored twice.
    """
    # Distinct local names: the originals rebound the str parameters to models.
    user_msg = MessageModel(chatId=chat_id, author=Author.User, text=user_message)
    assistant_msg = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
    await settings.DB_CLIENT.messages.insert_many(
        [user_msg.to_mongo(), assistant_msg.to_mongo()]
    )
async def filter_entities_by_age(entity: dict) -> list[int]:
    """Return the ``index`` of every entity whose age groups overlap *entity*'s range.

    A stored entity matches when at least one element of its ``ageGroups``
    array satisfies ``ageMin <= entity['ageMax']`` and
    ``ageMax >= entity['ageMin']``. That is, the two closed ranges intersect.

    Args:
        entity: Mapping with ``ageMin`` and ``ageMax`` keys (presumably the age
            range extracted from the chat; confirm against callers).

    Returns:
        The ``index`` values of the matching entity documents.

    Raises:
        KeyError: if *entity* lacks ``ageMin``/``ageMax``, or if a matching
            document has no ``index`` field.
    """
    query = {
        "ageGroups": {
            "$elemMatch": {
                "ageMin": {"$lte": entity["ageMax"]},
                "ageMax": {"$gte": entity["ageMin"]},
            }
        }
    }
    # Only ``index`` is consumed below, so don't transfer whole documents.
    projection = {"index": 1, "_id": 0}
    matches = await settings.DB_CLIENT.entities.find(query, projection).to_list(length=None)
    return [match["index"] for match in matches]
async def get_entity_by_index(index: int) -> EntityModel:
    """Load the entity document whose ``index`` field equals *index*.

    NOTE(review): ``find_one`` yields ``None`` when no document matches. What
    ``EntityModel.from_mongo(None)`` does in that case is not visible here;
    confirm before relying on this for possibly-missing indexes.
    """
    collection = settings.DB_CLIENT.entities
    raw = await collection.find_one({"index": index})
    return EntityModel.from_mongo(raw)
async def search_semantic_entities(
    search_request: str,
    entities_indexes: list[int],
    max_distance: float = 1.3,
    limit: int = 5,
) -> list[EntityModel]:
    """Return up to *limit* allowed entities closest to *search_request*.

    The request is embedded with ``convert_value_to_embeddings``, then searched
    against ``settings.SEMANTIC_INDEX`` with ``k`` equal to the index's
    ``ntotal``, so every stored vector is ranked. A hit is kept only when its
    id is in *entities_indexes* and its distance is ``<= max_distance``. Kept
    hits are ordered by ascending distance, truncated to *limit*, and loaded
    concurrently with ``get_entity_by_index``.

    Args:
        search_request: Free-text query to embed.
        entities_indexes: Entity ``index`` values allowed in the result
            (e.g. the output of ``filter_entities_by_age``).
        max_distance: Inclusive distance cut-off. The default of 1.3 is the
            previous hard-coded threshold.
        limit: Maximum number of entities returned. The default of 5 is the
            previous hard-coded cap.

    Returns:
        Matching entities, closest first. The list is empty when nothing
        qualifies.
    """
    semantic_index = settings.SEMANTIC_INDEX
    # Nothing can qualify: skip the embedding request and the index search.
    if not entities_indexes or semantic_index.ntotal == 0 or limit <= 0:
        return []

    # Set lookup is O(1). Membership in the list cost O(len(entities_indexes))
    # for each of the ntotal hits.
    allowed = {int(i) for i in entities_indexes}

    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    # k = ntotal so the allow-list filter below sees every candidate.
    # NOTE(review): this search is synchronous and runs on the event-loop
    # thread; consider asyncio.to_thread if the index grows large.
    distances, indices = semantic_index.search(query_embedding, k=semantic_index.ntotal)

    candidates = [
        {"index": int(idx), "distance": float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if int(idx) in allowed and dist <= max_distance
    ]
    closest = sorted(candidates, key=lambda c: c["distance"])[:limit]
    return await asyncio.gather(*(get_entity_by_index(c["index"]) for c in closest))