Spaces:
Running
Running
finish backend
Browse files- .gitignore +1 -1
- requirements.txt +1 -1
- test.py +28 -0
- trauma/__init__.py +0 -5
- trauma/api/account/__init__.py +0 -7
- trauma/api/account/dto.py +0 -6
- trauma/api/account/model.py +0 -32
- trauma/api/account/schemas.py +0 -6
- trauma/api/account/views.py +0 -13
- trauma/api/chat/db_requests.py +19 -16
- trauma/api/chat/dto.py +8 -0
- trauma/api/chat/model.py +2 -3
- trauma/api/chat/views.py +9 -13
- trauma/api/data/dto.py +13 -0
- trauma/api/data/model.py +43 -0
- trauma/api/data/prepare_data.py +83 -0
- trauma/api/message/ai/engine.py +35 -0
- trauma/api/message/ai/openai_request.py +71 -45
- trauma/api/message/ai/prompts.py +5 -216
- trauma/api/message/ai/utils.py +0 -30
- trauma/api/message/db_requests.py +52 -19
- trauma/api/message/model.py +1 -2
- trauma/api/message/schemas.py +6 -0
- trauma/api/message/utils.py +39 -0
- trauma/api/message/views.py +13 -14
- trauma/api/security/__init__.py +0 -7
- trauma/api/security/db_requests.py +0 -34
- trauma/api/security/schemas.py +0 -28
- trauma/api/security/views.py +0 -27
- trauma/core/config.py +2 -1
- trauma/core/wrappers.py +41 -0
.gitignore
CHANGED
@@ -5,7 +5,7 @@ venv/
|
|
5 |
.idea/
|
6 |
*.log
|
7 |
*.egg-info/
|
8 |
-
pip-wheel-
|
9 |
.env
|
10 |
.DS_Store
|
11 |
static/
|
|
|
5 |
.idea/
|
6 |
*.log
|
7 |
*.egg-info/
|
8 |
+
pip-wheel-EntityData/
|
9 |
.env
|
10 |
.DS_Store
|
11 |
static/
|
requirements.txt
CHANGED
@@ -33,4 +33,4 @@ typing_extensions==4.12.2
|
|
33 |
uvicorn==0.32.1
|
34 |
uvloop==0.21.0
|
35 |
watchfiles==1.0.0
|
36 |
-
websockets==14.1
|
|
|
33 |
uvicorn==0.32.1
|
34 |
uvloop==0.21.0
|
35 |
watchfiles==1.0.0
|
36 |
+
websockets==14.1
|
test.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from translate import Translator
|
3 |
+
|
4 |
+
|
5 |
+
def convert_and_translate_headers(input_file: str, output_file: str, sheet_name: str = None):
|
6 |
+
try:
|
7 |
+
# Читаем файл Excel
|
8 |
+
data = pd.read_excel(input_file, sheet_name=sheet_name)
|
9 |
+
|
10 |
+
# Инициализируем переводчик
|
11 |
+
translator = Translator(from_lang='nl', to_lang='en')
|
12 |
+
|
13 |
+
# Переводим названия колонок
|
14 |
+
translated_columns = {col: translator.translate(col) for col in data.columns}
|
15 |
+
data.rename(columns=translated_columns, inplace=True)
|
16 |
+
|
17 |
+
# Сохраняем преобразованные данные в CSV
|
18 |
+
data.to_csv(output_file, index=False)
|
19 |
+
print(f"Файл успешно конвертирован и сохранен: {output_file}")
|
20 |
+
except Exception as e:
|
21 |
+
print(f"Произошла ошибка: {e}")
|
22 |
+
|
23 |
+
|
24 |
+
input_xlsx = "test.xlsx" # Путь к входному .xlsx файлу
|
25 |
+
output_csv = "translated_output.csv" # Путь к выходному .csv файлу
|
26 |
+
sheet = "Sheet1" # Укажите имя листа, если нужно
|
27 |
+
|
28 |
+
convert_and_translate_headers(input_xlsx, output_csv, sheet)
|
trauma/__init__.py
CHANGED
@@ -12,17 +12,12 @@ from trauma.core.wrappers import TraumaResponseWrapper, ErrorTraumaResponse
|
|
12 |
def create_app() -> FastAPI:
|
13 |
app = FastAPI()
|
14 |
|
15 |
-
from trauma.api.account import account_router
|
16 |
-
app.include_router(account_router, tags=['account'])
|
17 |
-
|
18 |
from trauma.api.chat import chat_router
|
19 |
app.include_router(chat_router, tags=['chat'])
|
20 |
|
21 |
from trauma.api.message import message_router
|
22 |
app.include_router(message_router, tags=['message'])
|
23 |
|
24 |
-
from trauma.api.security import security_router
|
25 |
-
app.include_router(security_router, tags=['security'])
|
26 |
|
27 |
app.add_middleware(
|
28 |
CORSMiddleware,
|
|
|
12 |
def create_app() -> FastAPI:
|
13 |
app = FastAPI()
|
14 |
|
|
|
|
|
|
|
15 |
from trauma.api.chat import chat_router
|
16 |
app.include_router(chat_router, tags=['chat'])
|
17 |
|
18 |
from trauma.api.message import message_router
|
19 |
app.include_router(message_router, tags=['message'])
|
20 |
|
|
|
|
|
21 |
|
22 |
app.add_middleware(
|
23 |
CORSMiddleware,
|
trauma/api/account/__init__.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
from fastapi import APIRouter
|
2 |
-
|
3 |
-
account_router = APIRouter(
|
4 |
-
prefix='/api/account'
|
5 |
-
)
|
6 |
-
|
7 |
-
from . import views
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/account/dto.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
from pydantic import BaseModel
|
2 |
-
|
3 |
-
|
4 |
-
class AccessToken(BaseModel):
|
5 |
-
type: str = "Bearer"
|
6 |
-
value: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/account/model.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
from datetime import datetime
|
2 |
-
|
3 |
-
from pydantic import field_validator, Field, EmailStr
|
4 |
-
from passlib.context import CryptContext
|
5 |
-
|
6 |
-
from trauma.core.database import MongoBaseModel
|
7 |
-
|
8 |
-
|
9 |
-
class AccountModel(MongoBaseModel):
|
10 |
-
email: EmailStr
|
11 |
-
password: str | None = Field(exclude=True, default=None)
|
12 |
-
|
13 |
-
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
14 |
-
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
15 |
-
|
16 |
-
@field_validator("datetimeUpdated", mode="before", check_fields=False)
|
17 |
-
@classmethod
|
18 |
-
def validate_datetimeUpdated(cls, v):
|
19 |
-
return v or datetime.now()
|
20 |
-
|
21 |
-
@field_validator('password', mode='before', check_fields=False)
|
22 |
-
@classmethod
|
23 |
-
def set_password_hash(cls, v):
|
24 |
-
if not v.startswith("$2b$"):
|
25 |
-
return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v)
|
26 |
-
return v
|
27 |
-
|
28 |
-
class Config:
|
29 |
-
arbitrary_types_allowed = True
|
30 |
-
json_encoders = {
|
31 |
-
datetime: lambda dt: dt.isoformat()
|
32 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/account/schemas.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
from trauma.api.account.model import AccountModel
|
2 |
-
from trauma.core.wrappers import TraumaResponseWrapper
|
3 |
-
|
4 |
-
|
5 |
-
class AccountWrapper(TraumaResponseWrapper[AccountModel]):
|
6 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/account/views.py
DELETED
@@ -1,13 +0,0 @@
|
|
1 |
-
from fastapi import Depends
|
2 |
-
|
3 |
-
from trauma.api.account import account_router
|
4 |
-
from trauma.api.account.model import AccountModel
|
5 |
-
from trauma.api.account.schemas import AccountWrapper
|
6 |
-
from trauma.core.security import PermissionDependency
|
7 |
-
|
8 |
-
|
9 |
-
@account_router.get('')
|
10 |
-
async def get_account(
|
11 |
-
account: AccountModel = Depends(PermissionDependency())
|
12 |
-
) -> AccountWrapper:
|
13 |
-
return AccountWrapper(data=account)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/chat/db_requests.py
CHANGED
@@ -2,62 +2,65 @@ import asyncio
|
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
-
from trauma.api.account.model import AccountModel
|
6 |
from trauma.api.chat.model import ChatModel
|
7 |
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
|
8 |
from trauma.core.config import settings
|
|
|
9 |
|
10 |
|
11 |
-
async def get_chat_obj(chat_id: str
|
12 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
13 |
if not chat:
|
14 |
raise HTTPException(status_code=404, detail="Chat not found")
|
15 |
chat = ChatModel.from_mongo(chat)
|
16 |
-
if account and chat.account != account:
|
17 |
-
raise HTTPException(status_code=403, detail="Chat account not match")
|
18 |
return chat
|
19 |
|
20 |
|
21 |
-
async def create_chat_obj(chat_request: CreateChatRequest
|
22 |
-
chat = ChatModel(model=chat_request.model
|
23 |
await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
|
24 |
return chat
|
25 |
|
26 |
|
27 |
-
async def delete_chat_obj(chat_id: str
|
28 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
29 |
if not chat:
|
30 |
raise HTTPException(status_code=404, detail="Chat not found")
|
31 |
chat = ChatModel.from_mongo(chat)
|
32 |
-
if account and chat.account != account:
|
33 |
-
raise HTTPException(status_code=403, detail="Chat account not match")
|
34 |
await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
|
35 |
|
36 |
|
37 |
-
async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest
|
38 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chatId})
|
39 |
if not chat:
|
40 |
raise HTTPException(status_code=404, detail="Chat not found")
|
41 |
|
42 |
chat = ChatModel.from_mongo(chat)
|
43 |
-
if account and chat.account != account:
|
44 |
-
raise HTTPException(status_code=403, detail="Chat account not match")
|
45 |
|
46 |
chat.title = chat_request.title
|
47 |
await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
|
48 |
return chat
|
49 |
|
50 |
|
51 |
-
async def get_all_chats_obj(page_size: int, page_index: int
|
52 |
-
query = {"account.id": account.id}
|
53 |
skip = page_size * page_index
|
54 |
objects, total_count = await asyncio.gather(
|
55 |
settings.DB_CLIENT.chats
|
56 |
-
.find(
|
57 |
.sort("_id", -1)
|
58 |
.skip(skip)
|
59 |
.limit(page_size)
|
60 |
.to_list(length=page_size),
|
61 |
-
settings.DB_CLIENT.chats.count_documents(
|
62 |
)
|
63 |
return objects, total_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
|
|
5 |
from trauma.api.chat.model import ChatModel
|
6 |
from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
|
7 |
from trauma.core.config import settings
|
8 |
+
from trauma.core.wrappers import background_task
|
9 |
|
10 |
|
11 |
+
async def get_chat_obj(chat_id: str) -> ChatModel:
|
12 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
13 |
if not chat:
|
14 |
raise HTTPException(status_code=404, detail="Chat not found")
|
15 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
16 |
return chat
|
17 |
|
18 |
|
19 |
+
async def create_chat_obj(chat_request: CreateChatRequest) -> ChatModel:
|
20 |
+
chat = ChatModel(model=chat_request.model)
|
21 |
await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
|
22 |
return chat
|
23 |
|
24 |
|
25 |
+
async def delete_chat_obj(chat_id: str) -> None:
|
26 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
27 |
if not chat:
|
28 |
raise HTTPException(status_code=404, detail="Chat not found")
|
29 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
30 |
await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
|
31 |
|
32 |
|
33 |
+
async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest) -> ChatModel:
|
34 |
chat = await settings.DB_CLIENT.chats.find_one({"id": chatId})
|
35 |
if not chat:
|
36 |
raise HTTPException(status_code=404, detail="Chat not found")
|
37 |
|
38 |
chat = ChatModel.from_mongo(chat)
|
|
|
|
|
39 |
|
40 |
chat.title = chat_request.title
|
41 |
await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
|
42 |
return chat
|
43 |
|
44 |
|
45 |
+
async def get_all_chats_obj(page_size: int, page_index: int) -> tuple[list[ChatModel], int]:
|
|
|
46 |
skip = page_size * page_index
|
47 |
objects, total_count = await asyncio.gather(
|
48 |
settings.DB_CLIENT.chats
|
49 |
+
.find({})
|
50 |
.sort("_id", -1)
|
51 |
.skip(skip)
|
52 |
.limit(page_size)
|
53 |
.to_list(length=page_size),
|
54 |
+
settings.DB_CLIENT.chats.count_documents({}),
|
55 |
)
|
56 |
return objects, total_count
|
57 |
+
|
58 |
+
|
59 |
+
@background_task()
|
60 |
+
async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
|
61 |
+
await settings.DB_CLIENT.chats.update_one(
|
62 |
+
{"id": chat_id},
|
63 |
+
{"$set": {
|
64 |
+
"entityData": entity_data,
|
65 |
+
}}
|
66 |
+
)
|
trauma/api/chat/dto.py
CHANGED
@@ -1,8 +1,16 @@
|
|
1 |
from enum import Enum
|
2 |
|
|
|
|
|
3 |
|
4 |
class ModelType(Enum):
|
5 |
gpt_4o = "gpt-4o"
|
6 |
gpt_4o_mini = "gpt-4o-mini"
|
7 |
o1_mini = "o1-mini"
|
8 |
o1_preview = "o1-preview"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from enum import Enum
|
2 |
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
|
6 |
class ModelType(Enum):
|
7 |
gpt_4o = "gpt-4o"
|
8 |
gpt_4o_mini = "gpt-4o-mini"
|
9 |
o1_mini = "o1-mini"
|
10 |
o1_preview = "o1-preview"
|
11 |
+
|
12 |
+
class EntityData(BaseModel):
|
13 |
+
ageMin: int | None = None
|
14 |
+
ageMax: int | None = None
|
15 |
+
treatmentAreas: str | None = None
|
16 |
+
treatmentMethods: str | None = None
|
trauma/api/chat/model.py
CHANGED
@@ -2,14 +2,13 @@ from datetime import datetime
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
5 |
-
from trauma.api.
|
6 |
-
from trauma.api.chat.dto import ModelType
|
7 |
from trauma.core.database import MongoBaseModel
|
8 |
|
9 |
|
10 |
class ChatModel(MongoBaseModel):
|
11 |
title: str = 'New Chat'
|
12 |
model: ModelType
|
13 |
-
|
14 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
15 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
5 |
+
from trauma.api.chat.dto import ModelType, EntityData
|
|
|
6 |
from trauma.core.database import MongoBaseModel
|
7 |
|
8 |
|
9 |
class ChatModel(MongoBaseModel):
|
10 |
title: str = 'New Chat'
|
11 |
model: ModelType
|
12 |
+
entityData: EntityData | None = None
|
13 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
14 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
trauma/api/chat/views.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
from typing import Optional
|
2 |
|
3 |
from fastapi import Query
|
4 |
-
from fastapi.params import Depends
|
5 |
|
6 |
-
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat import chat_router
|
8 |
from trauma.api.chat.db_requests import (get_chat_obj,
|
9 |
create_chat_obj,
|
@@ -12,7 +10,6 @@ from trauma.api.chat.db_requests import (get_chat_obj,
|
|
12 |
get_all_chats_obj)
|
13 |
from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
|
14 |
from trauma.api.common.dto import Paging
|
15 |
-
from trauma.core.security import PermissionDependency
|
16 |
from trauma.core.wrappers import TraumaResponseWrapper
|
17 |
|
18 |
|
@@ -20,9 +17,8 @@ from trauma.core.wrappers import TraumaResponseWrapper
|
|
20 |
async def get_all_chats(
|
21 |
pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
|
22 |
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
23 |
-
account: AccountModel = Depends(PermissionDependency())
|
24 |
) -> AllChatWrapper:
|
25 |
-
chats, total_count = await get_all_chats_obj(pageSize, pageIndex
|
26 |
response = AllChatResponse(
|
27 |
paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
|
28 |
data=chats
|
@@ -32,29 +28,29 @@ async def get_all_chats(
|
|
32 |
|
33 |
@chat_router.get('/{chatId}')
|
34 |
async def get_chat(
|
35 |
-
chatId: str
|
36 |
) -> ChatWrapper:
|
37 |
-
chat = await get_chat_obj(chatId
|
38 |
return ChatWrapper(data=chat)
|
39 |
|
40 |
|
41 |
@chat_router.post('')
|
42 |
async def create_chat(
|
43 |
-
chat_data: CreateChatRequest
|
44 |
) -> ChatWrapper:
|
45 |
-
chat = await create_chat_obj(chat_data
|
46 |
return ChatWrapper(data=chat)
|
47 |
|
48 |
|
49 |
@chat_router.delete('/{chatId}')
|
50 |
-
async def delete_chat(chatId: str
|
51 |
-
await delete_chat_obj(chatId
|
52 |
return TraumaResponseWrapper()
|
53 |
|
54 |
|
55 |
@chat_router.patch('/{chatId}/title')
|
56 |
async def update_chat_title(
|
57 |
-
chatId: str, chat: ChatTitleRequest
|
58 |
) -> ChatWrapper:
|
59 |
-
chat = await update_chat_obj_title(chatId, chat
|
60 |
return ChatWrapper(data=chat)
|
|
|
1 |
from typing import Optional
|
2 |
|
3 |
from fastapi import Query
|
|
|
4 |
|
|
|
5 |
from trauma.api.chat import chat_router
|
6 |
from trauma.api.chat.db_requests import (get_chat_obj,
|
7 |
create_chat_obj,
|
|
|
10 |
get_all_chats_obj)
|
11 |
from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
|
12 |
from trauma.api.common.dto import Paging
|
|
|
13 |
from trauma.core.wrappers import TraumaResponseWrapper
|
14 |
|
15 |
|
|
|
17 |
async def get_all_chats(
|
18 |
pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
|
19 |
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
|
|
20 |
) -> AllChatWrapper:
|
21 |
+
chats, total_count = await get_all_chats_obj(pageSize, pageIndex)
|
22 |
response = AllChatResponse(
|
23 |
paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
|
24 |
data=chats
|
|
|
28 |
|
29 |
@chat_router.get('/{chatId}')
|
30 |
async def get_chat(
|
31 |
+
chatId: str
|
32 |
) -> ChatWrapper:
|
33 |
+
chat = await get_chat_obj(chatId)
|
34 |
return ChatWrapper(data=chat)
|
35 |
|
36 |
|
37 |
@chat_router.post('')
|
38 |
async def create_chat(
|
39 |
+
chat_data: CreateChatRequest
|
40 |
) -> ChatWrapper:
|
41 |
+
chat = await create_chat_obj(chat_data)
|
42 |
return ChatWrapper(data=chat)
|
43 |
|
44 |
|
45 |
@chat_router.delete('/{chatId}')
|
46 |
+
async def delete_chat(chatId: str) -> TraumaResponseWrapper:
|
47 |
+
await delete_chat_obj(chatId)
|
48 |
return TraumaResponseWrapper()
|
49 |
|
50 |
|
51 |
@chat_router.patch('/{chatId}/title')
|
52 |
async def update_chat_title(
|
53 |
+
chatId: str, chat: ChatTitleRequest
|
54 |
) -> ChatWrapper:
|
55 |
+
chat = await update_chat_obj_title(chatId, chat)
|
56 |
return ChatWrapper(data=chat)
|
trauma/api/data/dto.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class AgeGroup(BaseModel):
|
5 |
+
ageMin: int
|
6 |
+
ageMax: int
|
7 |
+
|
8 |
+
class ContactDetails(BaseModel):
|
9 |
+
phone: str | None = None
|
10 |
+
email: str | None = None
|
11 |
+
website: str | None = None
|
12 |
+
address: str | None = None
|
13 |
+
postalCode: str | None = None
|
trauma/api/data/model.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
|
3 |
+
from trauma.api.data.dto import AgeGroup, ContactDetails
|
4 |
+
from trauma.core.database import MongoBaseModel
|
5 |
+
|
6 |
+
|
7 |
+
class DataModel(MongoBaseModel):
|
8 |
+
startDate: datetime.datetime
|
9 |
+
endDate: datetime.datetime
|
10 |
+
email: str | None = None
|
11 |
+
name: str | None = None
|
12 |
+
datetimeUpdated: datetime.datetime | None = None
|
13 |
+
organizationName: str | None = None
|
14 |
+
organizationLocation: str
|
15 |
+
youthCareRegion: str
|
16 |
+
postalCode: str
|
17 |
+
hasOrganizationMultipleLocations: bool
|
18 |
+
publicEmail: str
|
19 |
+
website: str
|
20 |
+
isTraumaTreatmentVisible: bool
|
21 |
+
injuryTreatmentForms: list[str]
|
22 |
+
offers: list[str]
|
23 |
+
intensities: list[str]
|
24 |
+
injuryTreatmentDuration: list[str]
|
25 |
+
ageGroups: list[AgeGroup]
|
26 |
+
traumaIndication: list[str]
|
27 |
+
supplier: str
|
28 |
+
treatmentFinancing: list[str]
|
29 |
+
isContractedCare: list[bool]
|
30 |
+
traumaTeamCharacteristics: list[str]
|
31 |
+
traumaDisciplines: list[str]
|
32 |
+
practitionersQualifications: list[str]
|
33 |
+
AQAs: list[str]
|
34 |
+
additions: list[str]
|
35 |
+
|
36 |
+
|
37 |
+
class EntityModel(MongoBaseModel):
|
38 |
+
id: int
|
39 |
+
name: str
|
40 |
+
ageGroups: list[AgeGroup]
|
41 |
+
treatmentAreas: list[str]
|
42 |
+
treatmentMethods: list[str]
|
43 |
+
contactDetails: ContactDetails
|
trauma/api/data/prepare_data.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
import faiss
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from trauma.api.data.model import EntityModel
|
7 |
+
from trauma.api.message.ai.openai_request import convert_value_to_embeddings
|
8 |
+
# import re
|
9 |
+
#
|
10 |
+
# import pandas as pd
|
11 |
+
#
|
12 |
+
# from trauma.api.data.dto import AgeGroup
|
13 |
+
# from trauma.api.data.model import EntityModel
|
14 |
+
from trauma.core.config import settings
|
15 |
+
|
16 |
+
|
17 |
+
#
|
18 |
+
# file_path = 'shorted_data.csv'
|
19 |
+
# df = pd.read_csv(file_path)
|
20 |
+
#
|
21 |
+
#
|
22 |
+
# async def main():
|
23 |
+
# for _, row in df.iterrows():
|
24 |
+
# row = row.to_dict()
|
25 |
+
# age_groups_str = row['Age Groups'].split(';')
|
26 |
+
# age_groups = []
|
27 |
+
# for age_group in age_groups_str:
|
28 |
+
# match_ = re.search(r'(\d+)-(\d+)', age_group)
|
29 |
+
# if match_:
|
30 |
+
# min_age, max_age = match_.groups()
|
31 |
+
# age_groups.append(AgeGroup(ageMin=min_age, ageMax=max_age))
|
32 |
+
#
|
33 |
+
# treatment_areas = row['Youths treatment'].split(';')
|
34 |
+
# treatment_areas = list(filter(lambda x: len(x) > 2, [i.strip().strip('\n').strip() for i in treatment_areas]))
|
35 |
+
#
|
36 |
+
# treatment_methods = row['Offers'].split(';')
|
37 |
+
# treatment_methods = list(
|
38 |
+
# filter(lambda x: len(x) > 2, [i.strip().strip('\n').strip() for i in treatment_methods]))
|
39 |
+
#
|
40 |
+
# email = re.search(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", row['Email'])
|
41 |
+
# if email:
|
42 |
+
# email = email.group(0)
|
43 |
+
# else:
|
44 |
+
# email = None
|
45 |
+
# website = row['Website']
|
46 |
+
# if website and not isinstance(website, float):
|
47 |
+
# website = re.sub(r'\n+', '', website)
|
48 |
+
# website = website.strip().strip('\n').strip()
|
49 |
+
# if not website.startswith(('https://', "http://")):
|
50 |
+
# website = 'https://' + website.lower()
|
51 |
+
# else:
|
52 |
+
# website = None
|
53 |
+
# model_data = {
|
54 |
+
# "id": row['ID'],
|
55 |
+
# "name": row['Organization'].strip().strip('\n').strip(),
|
56 |
+
# "ageGroups": age_groups,
|
57 |
+
# "treatmentAreas": treatment_areas,
|
58 |
+
# "treatmentMethods": treatment_methods,
|
59 |
+
# "contactDetails": {
|
60 |
+
# "phone": None,
|
61 |
+
# "email": email,
|
62 |
+
# "website": website,
|
63 |
+
# "address": row['Location'].strip().strip('\n').strip(),
|
64 |
+
# "postalCode": row['Postal code'].strip().strip('\n').strip(),
|
65 |
+
# }
|
66 |
+
# }
|
67 |
+
# entity_model = EntityModel.from_mongo(model_data)
|
68 |
+
# await settings.DB_CLIENT.entities.insert_one(entity_model.to_mongo())
|
69 |
+
|
70 |
+
def prepare_entities_str(entities: list[EntityModel]) -> list[str]:
|
71 |
+
entities_str = []
|
72 |
+
for entity in entities:
|
73 |
+
age_groups_str = ', '.join([f"{age_group.ageMin} - {age_group.ageMax}" for age_group in entity.ageGroups])
|
74 |
+
treatment_areas_str = ', '.join(entity.treatmentAreas)
|
75 |
+
treatment_methods_str = ', '.join(entity.treatmentMethods)
|
76 |
+
entity_str = (f"Company name: {entity.name}. We focus on people in the following age groups: {age_groups_str}. "
|
77 |
+
f"We specialize in treating the following disorder: {treatment_areas_str}."
|
78 |
+
f" For this, we use the following treatment methods: {treatment_methods_str}.")
|
79 |
+
entities_str.append(entity_str)
|
80 |
+
return entities_str
|
81 |
+
|
82 |
+
if __name__ == '__main__':
|
83 |
+
asyncio.run(add_index_to_documents())
|
trauma/api/message/ai/engine.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from trauma.api.chat.db_requests import update_entity_data_obj
|
4 |
+
from trauma.api.chat.model import ChatModel
|
5 |
+
from trauma.api.message.ai.openai_request import update_entity_data_with_ai, generate_next_question, \
|
6 |
+
generate_search_request, generate_final_response
|
7 |
+
from trauma.api.message.db_requests import save_assistant_user_message, filter_entities_by_age, search_semantic_entities
|
8 |
+
from trauma.api.message.schemas import CreateMessageResponse
|
9 |
+
from trauma.api.message.utils import retrieve_empty_field_from_entity_data, prepare_user_messages_str, \
|
10 |
+
prepare_final_entities_str
|
11 |
+
|
12 |
+
|
13 |
+
async def search_entities(
|
14 |
+
user_message: str, messages: list[dict], chat: ChatModel
|
15 |
+
) -> CreateMessageResponse:
|
16 |
+
entity_data = await update_entity_data_with_ai(chat.entityData, user_message)
|
17 |
+
asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
|
18 |
+
|
19 |
+
empty_field = retrieve_empty_field_from_entity_data(entity_data)
|
20 |
+
final_entities = None
|
21 |
+
|
22 |
+
if empty_field:
|
23 |
+
response = await generate_next_question(empty_field, user_message, messages)
|
24 |
+
else:
|
25 |
+
user_messages_str = prepare_user_messages_str(user_message, messages)
|
26 |
+
possible_entity_indexes, search_request = await asyncio.gather(
|
27 |
+
filter_entities_by_age(entity_data),
|
28 |
+
generate_search_request(user_messages_str, entity_data)
|
29 |
+
)
|
30 |
+
final_entities = await search_semantic_entities(search_request, possible_entity_indexes)
|
31 |
+
final_entities_str = prepare_final_entities_str(final_entities)
|
32 |
+
response = await generate_final_response(final_entities_str, user_message, messages)
|
33 |
+
|
34 |
+
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
35 |
+
return CreateMessageResponse(text=response, entities=final_entities)
|
trauma/api/message/ai/openai_request.py
CHANGED
@@ -1,47 +1,73 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
|
4 |
-
from trauma.api.chat.model import ChatModel
|
5 |
-
from trauma.api.message.ai.prompts import Prompts
|
6 |
-
from trauma.api.message.dto import Author
|
7 |
-
from trauma.api.message.model import MessageModel
|
8 |
from trauma.core.config import settings
|
9 |
-
from trauma.core.wrappers import
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trauma.api.chat.dto import EntityData
|
2 |
+
from trauma.api.message.ai.prompts import TraumaPrompts
|
|
|
|
|
|
|
|
|
|
|
3 |
from trauma.core.config import settings
|
4 |
+
from trauma.core.wrappers import openai_wrapper
|
5 |
+
|
6 |
+
|
7 |
+
@openai_wrapper(is_json=True)
|
8 |
+
async def update_entity_data_with_ai(entity_data: EntityData, user_message: str):
|
9 |
+
messages = [
|
10 |
+
{
|
11 |
+
"role": "system",
|
12 |
+
"content": TraumaPrompts.update_entity_data_with_ai
|
13 |
+
.replace("{entity_data}", entity_data.model_dump_json())
|
14 |
+
.replace("{user_message}", user_message)
|
15 |
+
}
|
16 |
+
]
|
17 |
+
return messages
|
18 |
+
|
19 |
+
|
20 |
+
@openai_wrapper(temperature=0.8)
|
21 |
+
async def generate_next_question(empty_field: str, user_message: str, message_history: list[dict]):
|
22 |
+
messages = [
|
23 |
+
{
|
24 |
+
"role": "system",
|
25 |
+
"content": TraumaPrompts.generate_next_question
|
26 |
+
.replace("{empty_field}", empty_field)
|
27 |
+
},
|
28 |
+
*message_history,
|
29 |
+
{
|
30 |
+
"role": "user",
|
31 |
+
"content": user_message
|
32 |
+
}
|
33 |
+
]
|
34 |
+
return messages
|
35 |
|
36 |
+
|
37 |
+
@openai_wrapper(temperature=0.4)
|
38 |
+
async def generate_search_request(user_messages_str: str, entity_data: EntityData):
|
39 |
+
messages = [
|
40 |
+
{
|
41 |
+
"role": "system",
|
42 |
+
"content": TraumaPrompts.generate_search_request
|
43 |
+
.replace("{entity_data}", entity_data.model_dump_json())
|
44 |
+
.replace("{user_messages_str}", user_messages_str)
|
45 |
+
}
|
46 |
+
]
|
47 |
+
return messages
|
48 |
+
|
49 |
+
|
50 |
+
@openai_wrapper(temperature=0.4)
|
51 |
+
async def generate_final_response(final_entities: str, user_message: str, message_history: list[dict]):
|
52 |
+
messages = [
|
53 |
+
{
|
54 |
+
"role": "system",
|
55 |
+
"content": TraumaPrompts.generate_recommendation_decision
|
56 |
+
.replace("{final_entities}", final_entities)
|
57 |
+
|
58 |
+
},
|
59 |
+
*message_history,
|
60 |
+
{
|
61 |
+
"role": "user",
|
62 |
+
"content": user_message
|
63 |
+
}
|
64 |
+
]
|
65 |
+
return messages
|
66 |
+
|
67 |
+
async def convert_value_to_embeddings(value: str) -> list[float]:
|
68 |
+
embeddings = await settings.OPENAI_CLIENT.embeddings.create(
|
69 |
+
input=value,
|
70 |
+
model='text-embedding-3-large',
|
71 |
+
dimensions=1536,
|
72 |
+
)
|
73 |
+
return embeddings.data[0].embedding
|
trauma/api/message/ai/prompts.py
CHANGED
@@ -1,216 +1,5 @@
|
|
1 |
-
class
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
## Context
|
7 |
-
|
8 |
-
Mitutoyo's B2B platform encompasses precision measurement tools and accessories. Your knowledge base consists of structured JSON product data specifically focused on Mitutoyo's digital and analogue micrometers, including their detailed specifications, features, and accessory relationships. While you have broad knowledge of Mitutoyo's full product range, your detailed product data and recommendation capabilities are currently optimized for the digital and analogue micrometer categories. Each recommendation must be based on exact matches between user requirements and product specifications.
|
9 |
-
|
10 |
-
## Data Processing Protocol
|
11 |
-
|
12 |
-
1. **Primary Product Matching**:
|
13 |
-
* Parse user requirements against product specifications
|
14 |
-
* Match technical requirements (e.g., "carbide tipped jaws") to FEATURE elements
|
15 |
-
* Validate matches using PRODUCT_DETAILS specifications
|
16 |
-
* Confirm accuracy through DESCRIPTION_LONG technical details
|
17 |
-
2. **Product Data Extraction**:
|
18 |
-
* SUPPLIER_PID (for exact product identification)
|
19 |
-
* DESCRIPTION_SHORT (en/nl) for product names
|
20 |
-
* DESCRIPTION_LONG (en/nl) for technical details
|
21 |
-
* PRODUCT_DETAILS for specifications
|
22 |
-
* FEATURE elements for specific capabilities
|
23 |
-
* PRODUCT_ORDER_DETAILS for availability
|
24 |
-
* PRODUCT_PRICE_DETAILS for pricing
|
25 |
-
3. **Accessory Matching Protocol**:
|
26 |
-
* Extract all PRODUCT_REFERENCE entries
|
27 |
-
* Validate PROD_ID_TO compatibility
|
28 |
-
* Cross-reference accessory specifications
|
29 |
-
* Verify physical compatibility parameters
|
30 |
-
|
31 |
-
## Interaction Flow
|
32 |
-
|
33 |
-
1. **Requirement Analysis**:
|
34 |
-
* Parse user's technical requirements
|
35 |
-
* Identify specific feature requests
|
36 |
-
* Match requirements to product specifications
|
37 |
-
* Validate technical compatibility
|
38 |
-
2. **Primary Product Recommendation**:
|
39 |
-
* Present matching products with complete details:
|
40 |
-
* Product name (both languages)
|
41 |
-
* Article number
|
42 |
-
* Key specifications
|
43 |
-
* Relevant features
|
44 |
-
* Price information
|
45 |
-
3. **Strategic Upselling**:
|
46 |
-
* Analyze PRODUCT_REFERENCE data
|
47 |
-
* Identify value-adding accessories
|
48 |
-
* Present complementary products
|
49 |
-
* Explain benefits and compatibility
|
50 |
-
4. **Verification Process**:
|
51 |
-
* Double-check all technical matches
|
52 |
-
* Verify compatibility of all recommendations
|
53 |
-
* Confirm pricing and availability
|
54 |
-
* Validate all product relationships
|
55 |
-
|
56 |
-
## Response Format Requirements
|
57 |
-
|
58 |
-
```markdown
|
59 |
-
## Primary Recommendation
|
60 |
-
|
61 |
-
- Product: [Name EN/NL]
|
62 |
-
- Article: [SUPPLIER_PID]
|
63 |
-
- Key Features: [Matched Requirements]
|
64 |
-
- Price: [From PRODUCT_PRICE_DETAILS]
|
65 |
-
|
66 |
-
## Recommended Accessories
|
67 |
-
|
68 |
-
1. [Primary Accessory]
|
69 |
-
- Purpose: [Specific Benefit]
|
70 |
-
- Article: [SUPPLIER_PID]
|
71 |
-
- Compatibility: [Verification Details]
|
72 |
-
2. [Secondary Accessories]
|
73 |
-
- Purpose: [Specific Benefit]
|
74 |
-
- Article: [SUPPLIER_PID]
|
75 |
-
- Compatibility: [Verification Details]
|
76 |
-
```
|
77 |
-
|
78 |
-
## Error Prevention Protocol
|
79 |
-
|
80 |
-
1. **Technical Matching**:
|
81 |
-
* Verify exact feature matches
|
82 |
-
* Confirm dimensional compatibility
|
83 |
-
* Validate technical specifications
|
84 |
-
* Cross-reference all requirements
|
85 |
-
2. **Compatibility Verification**:
|
86 |
-
* Check PRODUCT_REFERENCE links
|
87 |
-
* Verify physical specifications
|
88 |
-
* Confirm accessory compatibility
|
89 |
-
* Validate system requirements
|
90 |
-
3. **Data Accuracy**:
|
91 |
-
* Double-check all article numbers
|
92 |
-
* Verify price information
|
93 |
-
* Confirm availability status
|
94 |
-
* Validate technical specifications
|
95 |
-
|
96 |
-
## Prohibitions
|
97 |
-
|
98 |
-
* No assumptions about compatibility
|
99 |
-
* No recommendations without PRODUCT_REFERENCE validation
|
100 |
-
* No incomplete technical specifications
|
101 |
-
* No unverified product relationships
|
102 |
-
* NO discussion of competitor products or brands under any circumstances
|
103 |
-
* NO comparative analysis with other manufacturers
|
104 |
-
* NO recommendations outside of Mitutoyo's product range
|
105 |
-
* NO redirecting to other brands, even if Mitutoyo doesn't offer a solution
|
106 |
-
* NO market comparisons or industry benchmarking against other brands
|
107 |
-
|
108 |
-
## Example Interaction
|
109 |
-
|
110 |
-
User: "Need a micrometer with carbide tipped jaws"
|
111 |
-
Response Protocol:
|
112 |
-
1. Match "carbide tipped jaws" with FEATURE elements
|
113 |
-
2. Verify products meeting specification
|
114 |
-
3. Extract complete product details
|
115 |
-
4. Identify compatible accessories through PRODUCT_REFERENCE
|
116 |
-
5. Present primary recommendation with upselling options
|
117 |
-
6. Verify all technical relationships
|
118 |
-
|
119 |
-
## Key Performance Requirements
|
120 |
-
|
121 |
-
* 100% accuracy in technical matching
|
122 |
-
* Complete verification of all recommendations
|
123 |
-
* Precise accessory compatibility checking
|
124 |
-
* Clear, structured response format
|
125 |
-
* Professional, technical communication style"""
|
126 |
-
generate_response_image = """### **Prompt Purpose**
|
127 |
-
|
128 |
-
You are Hector, Mitutoyo's visual recognition expert, designed to identify Mitutoyo precision measurement instruments from images with unmatched precision and adherence to Mitutoyo’s high standards.
|
129 |
-
|
130 |
-
### **Core Functionality**
|
131 |
-
- **Primary Role:** Analyze images to confidently identify Mitutoyo products.
|
132 |
-
- **Communication:** Clearly indicate the confidence level and provide actionable feedback for incomplete identification cases.
|
133 |
-
|
134 |
-
### **Recognition Protocol**
|
135 |
-
|
136 |
-
#### **Visual Analysis Sequence**
|
137 |
-
|
138 |
-
1. Confirm Mitutoyo product authenticity.
|
139 |
-
2. Identify the product category (e.g., micrometer, caliper, indicator).
|
140 |
-
3. Recognize specific features and characteristics.
|
141 |
-
4. Match visual data against known Mitutoyo design elements.
|
142 |
-
5. Determine confidence level based on recognition certainty.
|
143 |
-
|
144 |
-
#### **Confidence Level Classification**
|
145 |
-
|
146 |
-
- **Level 1: High Confidence (100% Certain)**
|
147 |
-
- **Product Identification:**
|
148 |
-
- Confirmed as Mitutoyo [Product Name].
|
149 |
-
- **Article Number:** [SUPPLIER_PID].
|
150 |
-
- **Product Details:**
|
151 |
-
- Include complete specifications based on available data.
|
152 |
-
|
153 |
-
- **Level 2: Medium Confidence (Visual Recognition)**
|
154 |
-
- **Product Identification:**
|
155 |
-
- Recognized as Mitutoyo [Product Category/Series].
|
156 |
-
- Requires additional training data to confirm the exact article number.
|
157 |
-
- **Identified Features:**
|
158 |
-
- List identifiable features and series characteristics.
|
159 |
-
|
160 |
-
- **Level 3: Limited Confidence**
|
161 |
-
- **Initial Recognition:**
|
162 |
-
- Confirmed as a Mitutoyo [Product Category].
|
163 |
-
- Requires additional information for precise identification.
|
164 |
-
- **Recommendations:**
|
165 |
-
1. Share the article number if available.
|
166 |
-
2. Provide details about specific requirements.
|
167 |
-
3. Include additional product visuals or specifications.
|
168 |
-
|
169 |
-
### **Absolute Prohibitions**
|
170 |
-
- **No guessing:** Avoid speculation on article numbers.
|
171 |
-
- **Non-Mitutoyo products:** Do not identify or compare with competitor products.
|
172 |
-
- **Uncertainty:** Avoid assumptions about specifications without solid evidence.
|
173 |
-
- **Precision:** Only communicate confirmed and accurate information.
|
174 |
-
|
175 |
-
### **Communication Guidelines**
|
176 |
-
|
177 |
-
#### **Certainty Communication**
|
178 |
-
|
179 |
-
- Clearly express confidence levels.
|
180 |
-
- Explicitly state identification limitations.
|
181 |
-
- Outline additional information needed for improved accuracy.
|
182 |
-
|
183 |
-
#### **Training Transparency**
|
184 |
-
|
185 |
-
- Acknowledge when training data is insufficient.
|
186 |
-
- Professionally explain limitations.
|
187 |
-
- Offer constructive steps for better identification.
|
188 |
-
|
189 |
-
#### **Response Structure**
|
190 |
-
|
191 |
-
1. Start with the confidence level.
|
192 |
-
2. Provide available identification details.
|
193 |
-
3. Highlight limitations and missing data.
|
194 |
-
4. Suggest next steps or provide recommendations.
|
195 |
-
|
196 |
-
#### **Unclear Cases Protocol**
|
197 |
-
|
198 |
-
- Confirm receipt of the image.
|
199 |
-
- State if the product is Mitutoyo.
|
200 |
-
- Specify the confidence level.
|
201 |
-
- Detail limitations and request additional details.
|
202 |
-
- Propose alternative assistance options if applicable.
|
203 |
-
|
204 |
-
### **Quality Assurance**
|
205 |
-
|
206 |
-
- Provide article numbers only when 100% certain.
|
207 |
-
- Validate all visual markers against Mitutoyo's known patterns.
|
208 |
-
- Clearly communicate confidence levels.
|
209 |
-
- Suggest actionable steps for uncertain cases.
|
210 |
-
|
211 |
-
### **Key Performance Metrics**
|
212 |
-
|
213 |
-
- **Accuracy:** Ensure precise product category identification.
|
214 |
-
- **Clarity:** Effectively communicate certainty levels and next steps.
|
215 |
-
- **Professionalism:** Handle limitations constructively.
|
216 |
-
- **Adherence:** Uphold Mitutoyo’s precision standards at all times."""
|
|
|
1 |
+
class TraumaPrompts:
|
2 |
+
update_entity_data_with_ai = """"""
|
3 |
+
generate_next_question = """"""
|
4 |
+
generate_search_request = """"""
|
5 |
+
generate_recommendation_decision = """"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/message/ai/utils.py
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
from trauma.api.message.ai.prompts import Prompts
|
2 |
-
from trauma.api.message.model import MessageModel
|
3 |
-
|
4 |
-
|
5 |
-
def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
|
6 |
-
openai_messages = [{"role": "system", "content": Prompts.generate_response}]
|
7 |
-
for message in messages:
|
8 |
-
|
9 |
-
if message.file:
|
10 |
-
content = [
|
11 |
-
{
|
12 |
-
"type": "text", "text": message.text
|
13 |
-
},
|
14 |
-
{
|
15 |
-
"type": "image_url",
|
16 |
-
"image_url": {
|
17 |
-
"url": f"data:image/jpeg;base64,{message.file.base64String}",
|
18 |
-
"detail": "low"
|
19 |
-
}
|
20 |
-
},
|
21 |
-
]
|
22 |
-
else:
|
23 |
-
content = message.text
|
24 |
-
|
25 |
-
openai_messages.append({
|
26 |
-
"role": message.author.value,
|
27 |
-
"content": content
|
28 |
-
})
|
29 |
-
|
30 |
-
return openai_messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/message/db_requests.py
CHANGED
@@ -1,18 +1,34 @@
|
|
|
|
|
|
1 |
import asyncio
|
2 |
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
-
from trauma.api.
|
6 |
from trauma.api.chat.model import ChatModel
|
|
|
|
|
7 |
from trauma.api.message.dto import Author
|
8 |
from trauma.api.message.model import MessageModel
|
9 |
from trauma.api.message.schemas import CreateMessageRequest
|
10 |
from trauma.core.config import settings
|
|
|
11 |
|
12 |
|
13 |
-
async def
|
14 |
-
chat_id: str,
|
15 |
-
) ->
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
messages, chat = await asyncio.gather(
|
17 |
settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
|
18 |
settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
@@ -23,23 +39,40 @@ async def get_all_chat_messages_obj(
|
|
23 |
raise HTTPException(status_code=404, detail="Chat not found")
|
24 |
|
25 |
chat = ChatModel.from_mongo(chat)
|
26 |
-
|
27 |
-
raise HTTPException(status_code=403, detail="Chat account not match")
|
28 |
|
29 |
-
return messages
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
-
chat = ChatModel.from_mongo(chat)
|
39 |
-
if account and chat.account != account:
|
40 |
-
raise HTTPException(status_code=403, detail="Chat account not match")
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import faiss
|
2 |
+
import numpy as np
|
3 |
import asyncio
|
4 |
|
5 |
from fastapi import HTTPException
|
6 |
|
7 |
+
from trauma.api.chat.dto import EntityData
|
8 |
from trauma.api.chat.model import ChatModel
|
9 |
+
from trauma.api.data.model import EntityModel
|
10 |
+
from trauma.api.message.ai.openai_request import convert_value_to_embeddings
|
11 |
from trauma.api.message.dto import Author
|
12 |
from trauma.api.message.model import MessageModel
|
13 |
from trauma.api.message.schemas import CreateMessageRequest
|
14 |
from trauma.core.config import settings
|
15 |
+
from trauma.core.wrappers import background_task
|
16 |
|
17 |
|
18 |
+
async def create_message_obj(
|
19 |
+
chat_id: str, message_data: CreateMessageRequest
|
20 |
+
) -> tuple[MessageModel, ChatModel]:
|
21 |
+
chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
22 |
+
if not chat:
|
23 |
+
raise HTTPException(status_code=404, detail="Chat not found.")
|
24 |
+
|
25 |
+
message = MessageModel(**message_data.model_dump(), chatId=chat_id, author=Author.User)
|
26 |
+
await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
|
27 |
+
|
28 |
+
return message, chat
|
29 |
+
|
30 |
+
|
31 |
+
async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], ChatModel]:
|
32 |
messages, chat = await asyncio.gather(
|
33 |
settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
|
34 |
settings.DB_CLIENT.chats.find_one({"id": chat_id})
|
|
|
39 |
raise HTTPException(status_code=404, detail="Chat not found")
|
40 |
|
41 |
chat = ChatModel.from_mongo(chat)
|
42 |
+
return messages, chat
|
|
|
43 |
|
|
|
44 |
|
45 |
+
@background_task()
|
46 |
+
async def save_assistant_user_message(user_message: str, assistant_message: str, chat_id: str) -> None:
|
47 |
+
user_message = MessageModel(chatId=chat_id, author=Author.User, text=user_message)
|
48 |
+
assistant_message = MessageModel(chatId=chat_id, author=Author.User, text=assistant_message)
|
49 |
+
await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
|
50 |
+
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
51 |
|
|
|
|
|
|
|
52 |
|
53 |
+
async def filter_entities_by_age(entity: EntityData) -> list[int]:
|
54 |
+
query = {
|
55 |
+
"ageGroups": {
|
56 |
+
"$elemMatch": {
|
57 |
+
"ageMin": {"$lte": entity.ageMax},
|
58 |
+
"ageMax": {"$gte": entity.ageMin}
|
59 |
+
}
|
60 |
+
}
|
61 |
+
}
|
62 |
+
entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
|
63 |
+
return [entity['index'] for entity in entities]
|
64 |
+
|
65 |
|
66 |
+
async def search_semantic_entities(search_request: str, entities_indexes: list[int]):
|
67 |
+
embedding = await convert_value_to_embeddings(search_request)
|
68 |
+
query_embedding = np.array([embedding], dtype=np.float32)
|
69 |
+
distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)
|
70 |
+
distances = distances[0]
|
71 |
+
indices = indices[0]
|
72 |
+
filtered_results = [
|
73 |
+
{"index": int(idx), "distance": float(dist)}
|
74 |
+
for idx, dist in zip(indices, distances)
|
75 |
+
if idx in entities_indexes and dist <= 0.6
|
76 |
+
]
|
77 |
+
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])
|
78 |
+
return filtered_results[:5]
|
trauma/api/message/model.py
CHANGED
@@ -2,7 +2,7 @@ from datetime import datetime
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
5 |
-
from trauma.api.message.dto import Author
|
6 |
from trauma.core.database import MongoBaseModel
|
7 |
|
8 |
|
@@ -10,6 +10,5 @@ class MessageModel(MongoBaseModel):
|
|
10 |
chatId: str
|
11 |
author: Author
|
12 |
text: str
|
13 |
-
fileUrl: str | None = None
|
14 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
15 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
5 |
+
from trauma.api.message.dto import Author
|
6 |
from trauma.core.database import MongoBaseModel
|
7 |
|
8 |
|
|
|
10 |
chatId: str
|
11 |
author: Author
|
12 |
text: str
|
|
|
13 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
14 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
trauma/api/message/schemas.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
from trauma.api.common.dto import Paging
|
|
|
4 |
from trauma.api.message.model import MessageModel
|
5 |
from trauma.core.wrappers import TraumaResponseWrapper
|
6 |
|
@@ -21,3 +22,8 @@ class AllMessageResponse(BaseModel):
|
|
21 |
|
22 |
class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
|
23 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
from trauma.api.common.dto import Paging
|
4 |
+
from trauma.api.data.model import EntityModel
|
5 |
from trauma.api.message.model import MessageModel
|
6 |
from trauma.core.wrappers import TraumaResponseWrapper
|
7 |
|
|
|
22 |
|
23 |
class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
|
24 |
pass
|
25 |
+
|
26 |
+
|
27 |
+
class CreateMessageResponse(BaseModel):
|
28 |
+
text: str
|
29 |
+
entities: list[EntityModel] | None = None
|
trauma/api/message/utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from trauma.api.data.model import EntityModel
|
4 |
+
from trauma.api.message.model import MessageModel
|
5 |
+
|
6 |
+
|
7 |
+
def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
|
8 |
+
openai_messages = []
|
9 |
+
for message in messages:
|
10 |
+
content = message.text
|
11 |
+
openai_messages.append({
|
12 |
+
"role": message.author.value,
|
13 |
+
"content": content
|
14 |
+
})
|
15 |
+
|
16 |
+
return openai_messages
|
17 |
+
|
18 |
+
|
19 |
+
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
20 |
+
for k, v in entity_data.items():
|
21 |
+
if not v:
|
22 |
+
return k
|
23 |
+
return None
|
24 |
+
|
25 |
+
|
26 |
+
def prepare_user_messages_str(user_message: str, messages: list[dict]) -> str:
|
27 |
+
user_message_str = ''
|
28 |
+
for message in messages:
|
29 |
+
if message['role'] == 'user':
|
30 |
+
user_message_str += f'- {message['content']}\n'
|
31 |
+
user_message_str += f'- {user_message}'
|
32 |
+
return user_message_str
|
33 |
+
|
34 |
+
|
35 |
+
def prepare_final_entities_str(entities: list[EntityModel]) -> str:
|
36 |
+
entities_list = []
|
37 |
+
for entity in entities:
|
38 |
+
entities_list.append(entity.model_dump(mode='json', exclude={'id', 'contactDetails'}))
|
39 |
+
return json.dumps({"entities": entities_list})
|
trauma/api/message/views.py
CHANGED
@@ -1,22 +1,20 @@
|
|
1 |
-
from fastapi.params import Depends
|
2 |
-
from starlette.responses import StreamingResponse
|
3 |
-
|
4 |
-
from trauma.api.account.model import AccountModel
|
5 |
from trauma.api.common.dto import Paging
|
6 |
from trauma.api.message import message_router
|
7 |
-
from trauma.api.message.ai.
|
8 |
-
from trauma.api.message.db_requests import get_all_chat_messages_obj
|
9 |
from trauma.api.message.schemas import (AllMessageWrapper,
|
10 |
AllMessageResponse,
|
11 |
-
CreateMessageRequest
|
12 |
-
|
|
|
|
|
13 |
|
14 |
|
15 |
@message_router.get('/{chatId}/all')
|
16 |
async def get_all_chat_messages(
|
17 |
-
chatId: str
|
18 |
) -> AllMessageWrapper:
|
19 |
-
messages = await get_all_chat_messages_obj(chatId
|
20 |
response = AllMessageResponse(
|
21 |
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
22 |
data=messages
|
@@ -28,7 +26,8 @@ async def get_all_chat_messages(
|
|
28 |
async def create_message(
|
29 |
chatId: str,
|
30 |
message_data: CreateMessageRequest,
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from trauma.api.common.dto import Paging
|
2 |
from trauma.api.message import message_router
|
3 |
+
from trauma.api.message.ai.engine import search_entities
|
4 |
+
from trauma.api.message.db_requests import get_all_chat_messages_obj
|
5 |
from trauma.api.message.schemas import (AllMessageWrapper,
|
6 |
AllMessageResponse,
|
7 |
+
CreateMessageRequest,
|
8 |
+
CreateMessageResponse)
|
9 |
+
from trauma.api.message.utils import transform_messages_to_openai
|
10 |
+
from trauma.core.wrappers import TraumaResponseWrapper
|
11 |
|
12 |
|
13 |
@message_router.get('/{chatId}/all')
|
14 |
async def get_all_chat_messages(
|
15 |
+
chatId: str
|
16 |
) -> AllMessageWrapper:
|
17 |
+
messages, _ = await get_all_chat_messages_obj(chatId)
|
18 |
response = AllMessageResponse(
|
19 |
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
20 |
data=messages
|
|
|
26 |
async def create_message(
|
27 |
chatId: str,
|
28 |
message_data: CreateMessageRequest,
|
29 |
+
) -> TraumaResponseWrapper[CreateMessageResponse]:
|
30 |
+
messages, chat = await get_all_chat_messages_obj(chatId)
|
31 |
+
message_history = transform_messages_to_openai(messages)
|
32 |
+
response = await search_entities(message_data.text, message_history, chat)
|
33 |
+
return TraumaResponseWrapper(data=response)
|
trauma/api/security/__init__.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
from fastapi import APIRouter
|
2 |
-
|
3 |
-
security_router = APIRouter(
|
4 |
-
prefix='/api/security'
|
5 |
-
)
|
6 |
-
|
7 |
-
from . import views
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/security/db_requests.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import asyncio
|
2 |
-
|
3 |
-
from fastapi import HTTPException
|
4 |
-
|
5 |
-
from trauma.api.account.model import AccountModel
|
6 |
-
from trauma.api.common.db_requests import check_unique_fields_existence
|
7 |
-
from trauma.api.security.schemas import RegisterAccountRequest, LoginAccountRequest
|
8 |
-
from trauma.core.config import settings
|
9 |
-
from trauma.core.security import verify_password
|
10 |
-
|
11 |
-
|
12 |
-
async def save_account(data: RegisterAccountRequest) -> AccountModel:
|
13 |
-
await asyncio.gather(
|
14 |
-
check_unique_fields_existence("accounts", "email", data.email),
|
15 |
-
)
|
16 |
-
account = AccountModel(
|
17 |
-
**data.model_dump()
|
18 |
-
)
|
19 |
-
await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
|
20 |
-
return account
|
21 |
-
|
22 |
-
|
23 |
-
async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
|
24 |
-
account = await settings.DB_CLIENT.accounts.find_one(
|
25 |
-
{"email": data.email},
|
26 |
-
collation={"locale": "en", "strength": 2})
|
27 |
-
if account is None:
|
28 |
-
raise HTTPException(status_code=404, detail="Invalid email or password.")
|
29 |
-
|
30 |
-
account = AccountModel.from_mongo(account)
|
31 |
-
|
32 |
-
if not verify_password(data.password, account.password):
|
33 |
-
raise HTTPException(status_code=401, detail="Invalid email or password.")
|
34 |
-
return account
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/security/schemas.py
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
from pydantic import BaseModel, EmailStr
|
2 |
-
|
3 |
-
from trauma.api.account.dto import AccessToken
|
4 |
-
from trauma.api.account.model import AccountModel
|
5 |
-
from trauma.core.wrappers import TraumaResponseWrapper
|
6 |
-
|
7 |
-
|
8 |
-
class RegisterAccountRequest(BaseModel):
|
9 |
-
email: EmailStr
|
10 |
-
password: str
|
11 |
-
|
12 |
-
|
13 |
-
class RegisterAccountWrapper(TraumaResponseWrapper[AccountModel]):
|
14 |
-
pass
|
15 |
-
|
16 |
-
|
17 |
-
class LoginAccountRequest(BaseModel):
|
18 |
-
email: EmailStr
|
19 |
-
password: str
|
20 |
-
|
21 |
-
|
22 |
-
class LoginAccountResponse(BaseModel):
|
23 |
-
accessToken: AccessToken
|
24 |
-
account: AccountModel
|
25 |
-
|
26 |
-
|
27 |
-
class LoginAccountWrapper(TraumaResponseWrapper[LoginAccountResponse]):
|
28 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/security/views.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
from typing import Literal
|
2 |
-
|
3 |
-
from trauma.api.account.dto import AccessToken
|
4 |
-
from trauma.api.security import security_router
|
5 |
-
from trauma.api.security.db_requests import authenticate_account, save_account
|
6 |
-
from trauma.api.security.schemas import RegisterAccountRequest, RegisterAccountWrapper, LoginAccountResponse, \
|
7 |
-
LoginAccountWrapper, LoginAccountRequest
|
8 |
-
from trauma.core.security import create_access_token
|
9 |
-
|
10 |
-
model: Literal["accounts"] = "accounts"
|
11 |
-
|
12 |
-
|
13 |
-
@security_router.post('/register')
|
14 |
-
async def register_user(data: RegisterAccountRequest) -> RegisterAccountWrapper:
|
15 |
-
account = await save_account(data)
|
16 |
-
return RegisterAccountWrapper(data=account)
|
17 |
-
|
18 |
-
|
19 |
-
@security_router.post('/login')
|
20 |
-
async def login(data: LoginAccountRequest) -> LoginAccountWrapper:
|
21 |
-
account = await authenticate_account(data)
|
22 |
-
access_token = create_access_token(account.email, str(account.id))
|
23 |
-
response = LoginAccountResponse(
|
24 |
-
accessToken=AccessToken(value=access_token),
|
25 |
-
account=account,
|
26 |
-
)
|
27 |
-
return LoginAccountWrapper(data=response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/core/config.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import pathlib
|
3 |
from functools import lru_cache
|
4 |
|
|
|
5 |
import motor.motor_asyncio
|
6 |
from dotenv import load_dotenv
|
7 |
from openai import AsyncClient
|
@@ -14,7 +15,7 @@ class BaseConfig:
|
|
14 |
SECRET_KEY = os.getenv('SECRET')
|
15 |
DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).trauma
|
16 |
OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
|
17 |
-
|
18 |
|
19 |
class DevelopmentConfig(BaseConfig):
|
20 |
Issuer = "http://localhost:8000"
|
|
|
2 |
import pathlib
|
3 |
from functools import lru_cache
|
4 |
|
5 |
+
import faiss
|
6 |
import motor.motor_asyncio
|
7 |
from dotenv import load_dotenv
|
8 |
from openai import AsyncClient
|
|
|
15 |
SECRET_KEY = os.getenv('SECRET')
|
16 |
DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).trauma
|
17 |
OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
|
18 |
+
SEMANTIC_INDEX = faiss.read_index(str(pathlib.Path(__file__).parent.parent.parent / 'indexes' / 'entities.index'))
|
19 |
|
20 |
class DevelopmentConfig(BaseConfig):
|
21 |
Issuer = "http://localhost:8000"
|
trauma/core/wrappers.py
CHANGED
@@ -1,6 +1,8 @@
|
|
|
|
1 |
from functools import wraps
|
2 |
from typing import Generic, Optional, TypeVar
|
3 |
|
|
|
4 |
from fastapi import HTTPException
|
5 |
from pydantic import BaseModel
|
6 |
from starlette.responses import JSONResponse
|
@@ -42,3 +44,42 @@ def exception_wrapper(http_error: int, error_message: str):
|
|
42 |
return wrapper
|
43 |
|
44 |
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
from functools import wraps
|
3 |
from typing import Generic, Optional, TypeVar
|
4 |
|
5 |
+
import pydash
|
6 |
from fastapi import HTTPException
|
7 |
from pydantic import BaseModel
|
8 |
from starlette.responses import JSONResponse
|
|
|
44 |
return wrapper
|
45 |
|
46 |
return decorator
|
47 |
+
|
48 |
+
|
49 |
+
def openai_wrapper(
|
50 |
+
temperature: int | float = 0, model: str = "gpt-4o-mini", is_json: bool = False, return_: str = None
|
51 |
+
):
|
52 |
+
def decorator(func):
|
53 |
+
@wraps(func)
|
54 |
+
async def wrapper(*args, **kwargs) -> str:
|
55 |
+
messages = await func(*args, **kwargs)
|
56 |
+
completion = await settings.OPENAI_CLIENT.chat.completions.create(
|
57 |
+
messages=messages,
|
58 |
+
temperature=temperature,
|
59 |
+
n=1,
|
60 |
+
model=model,
|
61 |
+
response_format={"type": "json_object"} if is_json else {"type": "text"}
|
62 |
+
)
|
63 |
+
response = completion.choices[0].message.content
|
64 |
+
if is_json:
|
65 |
+
response = json.loads(response)
|
66 |
+
if return_:
|
67 |
+
return pydash.get(response, return_)
|
68 |
+
return response
|
69 |
+
|
70 |
+
return wrapper
|
71 |
+
|
72 |
+
return decorator
|
73 |
+
|
74 |
+
def background_task():
|
75 |
+
def decorator(func):
|
76 |
+
@wraps(func)
|
77 |
+
async def wrapper(*args, **kwargs) -> str:
|
78 |
+
try:
|
79 |
+
result = await func(*args, **kwargs)
|
80 |
+
return result
|
81 |
+
except Exception as e:
|
82 |
+
pass
|
83 |
+
return wrapper
|
84 |
+
|
85 |
+
return decorator
|