brestok commited on
Commit
9150f8e
·
1 Parent(s): 50553ea

finish backend

Browse files
.gitignore CHANGED
@@ -5,7 +5,7 @@ venv/
5
  .idea/
6
  *.log
7
  *.egg-info/
8
- pip-wheel-metadata/
9
  .env
10
  .DS_Store
11
  static/
 
5
  .idea/
6
  *.log
7
  *.egg-info/
8
+ pip-wheel-EntityData/
9
  .env
10
  .DS_Store
11
  static/
requirements.txt CHANGED
@@ -33,4 +33,4 @@ typing_extensions==4.12.2
33
  uvicorn==0.32.1
34
  uvloop==0.21.0
35
  watchfiles==1.0.0
36
- websockets==14.1
 
33
  uvicorn==0.32.1
34
  uvloop==0.21.0
35
  watchfiles==1.0.0
36
+ websockets==14.1
test.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from translate import Translator
3
+
4
+
5
+ def convert_and_translate_headers(input_file: str, output_file: str, sheet_name: str = None):
6
+ try:
7
+ # Читаем файл Excel
8
+ data = pd.read_excel(input_file, sheet_name=sheet_name)
9
+
10
+ # Инициализируем переводчик
11
+ translator = Translator(from_lang='nl', to_lang='en')
12
+
13
+ # Переводим названия колонок
14
+ translated_columns = {col: translator.translate(col) for col in data.columns}
15
+ data.rename(columns=translated_columns, inplace=True)
16
+
17
+ # Сохраняем преобразованные данные в CSV
18
+ data.to_csv(output_file, index=False)
19
+ print(f"Файл успешно конвертирован и сохранен: {output_file}")
20
+ except Exception as e:
21
+ print(f"Произошла ошибка: {e}")
22
+
23
+
24
+ input_xlsx = "test.xlsx" # Путь к входному .xlsx файлу
25
+ output_csv = "translated_output.csv" # Путь к выходному .csv файлу
26
+ sheet = "Sheet1" # Укажите имя листа, если нужно
27
+
28
+ convert_and_translate_headers(input_xlsx, output_csv, sheet)
trauma/__init__.py CHANGED
@@ -12,17 +12,12 @@ from trauma.core.wrappers import TraumaResponseWrapper, ErrorTraumaResponse
12
  def create_app() -> FastAPI:
13
  app = FastAPI()
14
 
15
- from trauma.api.account import account_router
16
- app.include_router(account_router, tags=['account'])
17
-
18
  from trauma.api.chat import chat_router
19
  app.include_router(chat_router, tags=['chat'])
20
 
21
  from trauma.api.message import message_router
22
  app.include_router(message_router, tags=['message'])
23
 
24
- from trauma.api.security import security_router
25
- app.include_router(security_router, tags=['security'])
26
 
27
  app.add_middleware(
28
  CORSMiddleware,
 
12
  def create_app() -> FastAPI:
13
  app = FastAPI()
14
 
 
 
 
15
  from trauma.api.chat import chat_router
16
  app.include_router(chat_router, tags=['chat'])
17
 
18
  from trauma.api.message import message_router
19
  app.include_router(message_router, tags=['message'])
20
 
 
 
21
 
22
  app.add_middleware(
23
  CORSMiddleware,
trauma/api/account/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from fastapi import APIRouter
2
-
3
- account_router = APIRouter(
4
- prefix='/api/account'
5
- )
6
-
7
- from . import views
 
 
 
 
 
 
 
 
trauma/api/account/dto.py DELETED
@@ -1,6 +0,0 @@
1
- from pydantic import BaseModel
2
-
3
-
4
- class AccessToken(BaseModel):
5
- type: str = "Bearer"
6
- value: str
 
 
 
 
 
 
 
trauma/api/account/model.py DELETED
@@ -1,32 +0,0 @@
1
- from datetime import datetime
2
-
3
- from pydantic import field_validator, Field, EmailStr
4
- from passlib.context import CryptContext
5
-
6
- from trauma.core.database import MongoBaseModel
7
-
8
-
9
- class AccountModel(MongoBaseModel):
10
- email: EmailStr
11
- password: str | None = Field(exclude=True, default=None)
12
-
13
- datetimeInserted: datetime = Field(default_factory=datetime.now)
14
- datetimeUpdated: datetime = Field(default_factory=datetime.now)
15
-
16
- @field_validator("datetimeUpdated", mode="before", check_fields=False)
17
- @classmethod
18
- def validate_datetimeUpdated(cls, v):
19
- return v or datetime.now()
20
-
21
- @field_validator('password', mode='before', check_fields=False)
22
- @classmethod
23
- def set_password_hash(cls, v):
24
- if not v.startswith("$2b$"):
25
- return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v)
26
- return v
27
-
28
- class Config:
29
- arbitrary_types_allowed = True
30
- json_encoders = {
31
- datetime: lambda dt: dt.isoformat()
32
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/account/schemas.py DELETED
@@ -1,6 +0,0 @@
1
- from trauma.api.account.model import AccountModel
2
- from trauma.core.wrappers import TraumaResponseWrapper
3
-
4
-
5
- class AccountWrapper(TraumaResponseWrapper[AccountModel]):
6
- pass
 
 
 
 
 
 
 
trauma/api/account/views.py DELETED
@@ -1,13 +0,0 @@
1
- from fastapi import Depends
2
-
3
- from trauma.api.account import account_router
4
- from trauma.api.account.model import AccountModel
5
- from trauma.api.account.schemas import AccountWrapper
6
- from trauma.core.security import PermissionDependency
7
-
8
-
9
- @account_router.get('')
10
- async def get_account(
11
- account: AccountModel = Depends(PermissionDependency())
12
- ) -> AccountWrapper:
13
- return AccountWrapper(data=account)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/chat/db_requests.py CHANGED
@@ -2,62 +2,65 @@ import asyncio
2
 
3
  from fastapi import HTTPException
4
 
5
- from trauma.api.account.model import AccountModel
6
  from trauma.api.chat.model import ChatModel
7
  from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
8
  from trauma.core.config import settings
 
9
 
10
 
11
- async def get_chat_obj(chat_id: str, account: AccountModel | None) -> ChatModel:
12
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
13
  if not chat:
14
  raise HTTPException(status_code=404, detail="Chat not found")
15
  chat = ChatModel.from_mongo(chat)
16
- if account and chat.account != account:
17
- raise HTTPException(status_code=403, detail="Chat account not match")
18
  return chat
19
 
20
 
21
- async def create_chat_obj(chat_request: CreateChatRequest, account: AccountModel | None) -> ChatModel:
22
- chat = ChatModel(model=chat_request.model, account=account)
23
  await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
24
  return chat
25
 
26
 
27
- async def delete_chat_obj(chat_id: str, account: AccountModel | None) -> None:
28
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
29
  if not chat:
30
  raise HTTPException(status_code=404, detail="Chat not found")
31
  chat = ChatModel.from_mongo(chat)
32
- if account and chat.account != account:
33
- raise HTTPException(status_code=403, detail="Chat account not match")
34
  await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
35
 
36
 
37
- async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest, account: AccountModel | None) -> ChatModel:
38
  chat = await settings.DB_CLIENT.chats.find_one({"id": chatId})
39
  if not chat:
40
  raise HTTPException(status_code=404, detail="Chat not found")
41
 
42
  chat = ChatModel.from_mongo(chat)
43
- if account and chat.account != account:
44
- raise HTTPException(status_code=403, detail="Chat account not match")
45
 
46
  chat.title = chat_request.title
47
  await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
48
  return chat
49
 
50
 
51
- async def get_all_chats_obj(page_size: int, page_index: int, account: AccountModel) -> tuple[list[ChatModel], int]:
52
- query = {"account.id": account.id}
53
  skip = page_size * page_index
54
  objects, total_count = await asyncio.gather(
55
  settings.DB_CLIENT.chats
56
- .find(query)
57
  .sort("_id", -1)
58
  .skip(skip)
59
  .limit(page_size)
60
  .to_list(length=page_size),
61
- settings.DB_CLIENT.chats.count_documents(query),
62
  )
63
  return objects, total_count
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from fastapi import HTTPException
4
 
 
5
  from trauma.api.chat.model import ChatModel
6
  from trauma.api.chat.schemas import CreateChatRequest, ChatTitleRequest
7
  from trauma.core.config import settings
8
+ from trauma.core.wrappers import background_task
9
 
10
 
11
+ async def get_chat_obj(chat_id: str) -> ChatModel:
12
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
13
  if not chat:
14
  raise HTTPException(status_code=404, detail="Chat not found")
15
  chat = ChatModel.from_mongo(chat)
 
 
16
  return chat
17
 
18
 
19
+ async def create_chat_obj(chat_request: CreateChatRequest) -> ChatModel:
20
+ chat = ChatModel(model=chat_request.model)
21
  await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
22
  return chat
23
 
24
 
25
+ async def delete_chat_obj(chat_id: str) -> None:
26
  chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
27
  if not chat:
28
  raise HTTPException(status_code=404, detail="Chat not found")
29
  chat = ChatModel.from_mongo(chat)
 
 
30
  await settings.DB_CLIENT.chats.delete_one({"id": chat_id})
31
 
32
 
33
+ async def update_chat_obj_title(chatId: str, chat_request: ChatTitleRequest) -> ChatModel:
34
  chat = await settings.DB_CLIENT.chats.find_one({"id": chatId})
35
  if not chat:
36
  raise HTTPException(status_code=404, detail="Chat not found")
37
 
38
  chat = ChatModel.from_mongo(chat)
 
 
39
 
40
  chat.title = chat_request.title
41
  await settings.DB_CLIENT.chats.update_one({"id": chatId}, {"$set": chat.to_mongo()})
42
  return chat
43
 
44
 
45
+ async def get_all_chats_obj(page_size: int, page_index: int) -> tuple[list[ChatModel], int]:
 
46
  skip = page_size * page_index
47
  objects, total_count = await asyncio.gather(
48
  settings.DB_CLIENT.chats
49
+ .find({})
50
  .sort("_id", -1)
51
  .skip(skip)
52
  .limit(page_size)
53
  .to_list(length=page_size),
54
+ settings.DB_CLIENT.chats.count_documents({}),
55
  )
56
  return objects, total_count
57
+
58
+
59
+ @background_task()
60
+ async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
61
+ await settings.DB_CLIENT.chats.update_one(
62
+ {"id": chat_id},
63
+ {"$set": {
64
+ "entityData": entity_data,
65
+ }}
66
+ )
trauma/api/chat/dto.py CHANGED
@@ -1,8 +1,16 @@
1
  from enum import Enum
2
 
 
 
3
 
4
  class ModelType(Enum):
5
  gpt_4o = "gpt-4o"
6
  gpt_4o_mini = "gpt-4o-mini"
7
  o1_mini = "o1-mini"
8
  o1_preview = "o1-preview"
 
 
 
 
 
 
 
1
  from enum import Enum
2
 
3
+ from pydantic import BaseModel
4
+
5
 
6
  class ModelType(Enum):
7
  gpt_4o = "gpt-4o"
8
  gpt_4o_mini = "gpt-4o-mini"
9
  o1_mini = "o1-mini"
10
  o1_preview = "o1-preview"
11
+
12
+ class EntityData(BaseModel):
13
+ ageMin: int | None = None
14
+ ageMax: int | None = None
15
+ treatmentAreas: str | None = None
16
+ treatmentMethods: str | None = None
trauma/api/chat/model.py CHANGED
@@ -2,14 +2,13 @@ from datetime import datetime
2
 
3
  from pydantic import Field
4
 
5
- from trauma.api.account.model import AccountModel
6
- from trauma.api.chat.dto import ModelType
7
  from trauma.core.database import MongoBaseModel
8
 
9
 
10
  class ChatModel(MongoBaseModel):
11
  title: str = 'New Chat'
12
  model: ModelType
13
- account: AccountModel | None = None
14
  datetimeInserted: datetime = Field(default_factory=datetime.now)
15
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
 
2
 
3
  from pydantic import Field
4
 
5
+ from trauma.api.chat.dto import ModelType, EntityData
 
6
  from trauma.core.database import MongoBaseModel
7
 
8
 
9
  class ChatModel(MongoBaseModel):
10
  title: str = 'New Chat'
11
  model: ModelType
12
+ entityData: EntityData | None = None
13
  datetimeInserted: datetime = Field(default_factory=datetime.now)
14
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
trauma/api/chat/views.py CHANGED
@@ -1,9 +1,7 @@
1
  from typing import Optional
2
 
3
  from fastapi import Query
4
- from fastapi.params import Depends
5
 
6
- from trauma.api.account.model import AccountModel
7
  from trauma.api.chat import chat_router
8
  from trauma.api.chat.db_requests import (get_chat_obj,
9
  create_chat_obj,
@@ -12,7 +10,6 @@ from trauma.api.chat.db_requests import (get_chat_obj,
12
  get_all_chats_obj)
13
  from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
14
  from trauma.api.common.dto import Paging
15
- from trauma.core.security import PermissionDependency
16
  from trauma.core.wrappers import TraumaResponseWrapper
17
 
18
 
@@ -20,9 +17,8 @@ from trauma.core.wrappers import TraumaResponseWrapper
20
  async def get_all_chats(
21
  pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
22
  pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
23
- account: AccountModel = Depends(PermissionDependency())
24
  ) -> AllChatWrapper:
25
- chats, total_count = await get_all_chats_obj(pageSize, pageIndex, account)
26
  response = AllChatResponse(
27
  paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
28
  data=chats
@@ -32,29 +28,29 @@ async def get_all_chats(
32
 
33
  @chat_router.get('/{chatId}')
34
  async def get_chat(
35
- chatId: str, account: AccountModel = Depends(PermissionDependency(is_public=True))
36
  ) -> ChatWrapper:
37
- chat = await get_chat_obj(chatId, account)
38
  return ChatWrapper(data=chat)
39
 
40
 
41
  @chat_router.post('')
42
  async def create_chat(
43
- chat_data: CreateChatRequest, account: AccountModel = Depends(PermissionDependency(is_public=True))
44
  ) -> ChatWrapper:
45
- chat = await create_chat_obj(chat_data, account)
46
  return ChatWrapper(data=chat)
47
 
48
 
49
  @chat_router.delete('/{chatId}')
50
- async def delete_chat(chatId: str, account: AccountModel = Depends(PermissionDependency())) -> TraumaResponseWrapper:
51
- await delete_chat_obj(chatId, account)
52
  return TraumaResponseWrapper()
53
 
54
 
55
  @chat_router.patch('/{chatId}/title')
56
  async def update_chat_title(
57
- chatId: str, chat: ChatTitleRequest, account: AccountModel = Depends(PermissionDependency())
58
  ) -> ChatWrapper:
59
- chat = await update_chat_obj_title(chatId, chat, account)
60
  return ChatWrapper(data=chat)
 
1
  from typing import Optional
2
 
3
  from fastapi import Query
 
4
 
 
5
  from trauma.api.chat import chat_router
6
  from trauma.api.chat.db_requests import (get_chat_obj,
7
  create_chat_obj,
 
10
  get_all_chats_obj)
11
  from trauma.api.chat.schemas import ChatWrapper, AllChatWrapper, CreateChatRequest, AllChatResponse, ChatTitleRequest
12
  from trauma.api.common.dto import Paging
 
13
  from trauma.core.wrappers import TraumaResponseWrapper
14
 
15
 
 
17
  async def get_all_chats(
18
  pageSize: Optional[int] = Query(10, description="Number of objects to return per page"),
19
  pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
 
20
  ) -> AllChatWrapper:
21
+ chats, total_count = await get_all_chats_obj(pageSize, pageIndex)
22
  response = AllChatResponse(
23
  paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
24
  data=chats
 
28
 
29
  @chat_router.get('/{chatId}')
30
  async def get_chat(
31
+ chatId: str
32
  ) -> ChatWrapper:
33
+ chat = await get_chat_obj(chatId)
34
  return ChatWrapper(data=chat)
35
 
36
 
37
  @chat_router.post('')
38
  async def create_chat(
39
+ chat_data: CreateChatRequest
40
  ) -> ChatWrapper:
41
+ chat = await create_chat_obj(chat_data)
42
  return ChatWrapper(data=chat)
43
 
44
 
45
  @chat_router.delete('/{chatId}')
46
+ async def delete_chat(chatId: str) -> TraumaResponseWrapper:
47
+ await delete_chat_obj(chatId)
48
  return TraumaResponseWrapper()
49
 
50
 
51
  @chat_router.patch('/{chatId}/title')
52
  async def update_chat_title(
53
+ chatId: str, chat: ChatTitleRequest
54
  ) -> ChatWrapper:
55
+ chat = await update_chat_obj_title(chatId, chat)
56
  return ChatWrapper(data=chat)
trauma/api/data/dto.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
+ class AgeGroup(BaseModel):
5
+ ageMin: int
6
+ ageMax: int
7
+
8
+ class ContactDetails(BaseModel):
9
+ phone: str | None = None
10
+ email: str | None = None
11
+ website: str | None = None
12
+ address: str | None = None
13
+ postalCode: str | None = None
trauma/api/data/model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+
3
+ from trauma.api.data.dto import AgeGroup, ContactDetails
4
+ from trauma.core.database import MongoBaseModel
5
+
6
+
7
+ class DataModel(MongoBaseModel):
8
+ startDate: datetime.datetime
9
+ endDate: datetime.datetime
10
+ email: str | None = None
11
+ name: str | None = None
12
+ datetimeUpdated: datetime.datetime | None = None
13
+ organizationName: str | None = None
14
+ organizationLocation: str
15
+ youthCareRegion: str
16
+ postalCode: str
17
+ hasOrganizationMultipleLocations: bool
18
+ publicEmail: str
19
+ website: str
20
+ isTraumaTreatmentVisible: bool
21
+ injuryTreatmentForms: list[str]
22
+ offers: list[str]
23
+ intensities: list[str]
24
+ injuryTreatmentDuration: list[str]
25
+ ageGroups: list[AgeGroup]
26
+ traumaIndication: list[str]
27
+ supplier: str
28
+ treatmentFinancing: list[str]
29
+ isContractedCare: list[bool]
30
+ traumaTeamCharacteristics: list[str]
31
+ traumaDisciplines: list[str]
32
+ practitionersQualifications: list[str]
33
+ AQAs: list[str]
34
+ additions: list[str]
35
+
36
+
37
+ class EntityModel(MongoBaseModel):
38
+ id: int
39
+ name: str
40
+ ageGroups: list[AgeGroup]
41
+ treatmentAreas: list[str]
42
+ treatmentMethods: list[str]
43
+ contactDetails: ContactDetails
trauma/api/data/prepare_data.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import faiss
4
+ import numpy as np
5
+
6
+ from trauma.api.data.model import EntityModel
7
+ from trauma.api.message.ai.openai_request import convert_value_to_embeddings
8
+ # import re
9
+ #
10
+ # import pandas as pd
11
+ #
12
+ # from trauma.api.data.dto import AgeGroup
13
+ # from trauma.api.data.model import EntityModel
14
+ from trauma.core.config import settings
15
+
16
+
17
+ #
18
+ # file_path = 'shorted_data.csv'
19
+ # df = pd.read_csv(file_path)
20
+ #
21
+ #
22
+ # async def main():
23
+ # for _, row in df.iterrows():
24
+ # row = row.to_dict()
25
+ # age_groups_str = row['Age Groups'].split(';')
26
+ # age_groups = []
27
+ # for age_group in age_groups_str:
28
+ # match_ = re.search(r'(\d+)-(\d+)', age_group)
29
+ # if match_:
30
+ # min_age, max_age = match_.groups()
31
+ # age_groups.append(AgeGroup(ageMin=min_age, ageMax=max_age))
32
+ #
33
+ # treatment_areas = row['Youths treatment'].split(';')
34
+ # treatment_areas = list(filter(lambda x: len(x) > 2, [i.strip().strip('\n').strip() for i in treatment_areas]))
35
+ #
36
+ # treatment_methods = row['Offers'].split(';')
37
+ # treatment_methods = list(
38
+ # filter(lambda x: len(x) > 2, [i.strip().strip('\n').strip() for i in treatment_methods]))
39
+ #
40
+ # email = re.search(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", row['Email'])
41
+ # if email:
42
+ # email = email.group(0)
43
+ # else:
44
+ # email = None
45
+ # website = row['Website']
46
+ # if website and not isinstance(website, float):
47
+ # website = re.sub(r'\n+', '', website)
48
+ # website = website.strip().strip('\n').strip()
49
+ # if not website.startswith(('https://', "http://")):
50
+ # website = 'https://' + website.lower()
51
+ # else:
52
+ # website = None
53
+ # model_data = {
54
+ # "id": row['ID'],
55
+ # "name": row['Organization'].strip().strip('\n').strip(),
56
+ # "ageGroups": age_groups,
57
+ # "treatmentAreas": treatment_areas,
58
+ # "treatmentMethods": treatment_methods,
59
+ # "contactDetails": {
60
+ # "phone": None,
61
+ # "email": email,
62
+ # "website": website,
63
+ # "address": row['Location'].strip().strip('\n').strip(),
64
+ # "postalCode": row['Postal code'].strip().strip('\n').strip(),
65
+ # }
66
+ # }
67
+ # entity_model = EntityModel.from_mongo(model_data)
68
+ # await settings.DB_CLIENT.entities.insert_one(entity_model.to_mongo())
69
+
70
+ def prepare_entities_str(entities: list[EntityModel]) -> list[str]:
71
+ entities_str = []
72
+ for entity in entities:
73
+ age_groups_str = ', '.join([f"{age_group.ageMin} - {age_group.ageMax}" for age_group in entity.ageGroups])
74
+ treatment_areas_str = ', '.join(entity.treatmentAreas)
75
+ treatment_methods_str = ', '.join(entity.treatmentMethods)
76
+ entity_str = (f"Company name: {entity.name}. We focus on people in the following age groups: {age_groups_str}. "
77
+ f"We specialize in treating the following disorder: {treatment_areas_str}."
78
+ f" For this, we use the following treatment methods: {treatment_methods_str}.")
79
+ entities_str.append(entity_str)
80
+ return entities_str
81
+
82
+ if __name__ == '__main__':
83
+ asyncio.run(add_index_to_documents())
trauma/api/message/ai/engine.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from trauma.api.chat.db_requests import update_entity_data_obj
4
+ from trauma.api.chat.model import ChatModel
5
+ from trauma.api.message.ai.openai_request import update_entity_data_with_ai, generate_next_question, \
6
+ generate_search_request, generate_final_response
7
+ from trauma.api.message.db_requests import save_assistant_user_message, filter_entities_by_age, search_semantic_entities
8
+ from trauma.api.message.schemas import CreateMessageResponse
9
+ from trauma.api.message.utils import retrieve_empty_field_from_entity_data, prepare_user_messages_str, \
10
+ prepare_final_entities_str
11
+
12
+
13
+ async def search_entities(
14
+ user_message: str, messages: list[dict], chat: ChatModel
15
+ ) -> CreateMessageResponse:
16
+ entity_data = await update_entity_data_with_ai(chat.entityData, user_message)
17
+ asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
18
+
19
+ empty_field = retrieve_empty_field_from_entity_data(entity_data)
20
+ final_entities = None
21
+
22
+ if empty_field:
23
+ response = await generate_next_question(empty_field, user_message, messages)
24
+ else:
25
+ user_messages_str = prepare_user_messages_str(user_message, messages)
26
+ possible_entity_indexes, search_request = await asyncio.gather(
27
+ filter_entities_by_age(entity_data),
28
+ generate_search_request(user_messages_str, entity_data)
29
+ )
30
+ final_entities = await search_semantic_entities(search_request, possible_entity_indexes)
31
+ final_entities_str = prepare_final_entities_str(final_entities)
32
+ response = await generate_final_response(final_entities_str, user_message, messages)
33
+
34
+ asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
35
+ return CreateMessageResponse(text=response, entities=final_entities)
trauma/api/message/ai/openai_request.py CHANGED
@@ -1,47 +1,73 @@
1
- import base64
2
- import io
3
-
4
- from trauma.api.chat.model import ChatModel
5
- from trauma.api.message.ai.prompts import Prompts
6
- from trauma.api.message.dto import Author
7
- from trauma.api.message.model import MessageModel
8
  from trauma.core.config import settings
9
- from trauma.core.wrappers import TraumaResponseWrapper
10
-
11
-
12
- async def prepare_content(user_message: MessageModel) -> list | str:
13
- if user_message.fileUrl is None:
14
- return user_message.text
15
- else:
16
- path = str(settings.BASE_DIR) + user_message.fileUrl.replace(settings.Issuer, '')
17
- file = await settings.OPENAI_CLIENT.files.create(
18
- file=open(path, 'rb'),
19
- purpose='vision'
20
- )
21
- return [{"type": "image_file", "image_file": {"file_id": file.id, "detail": "low"}}]
22
-
23
-
24
- async def response_generator(chat: ChatModel, user_message: MessageModel):
25
- content = await prepare_content(user_message)
26
- await settings.OPENAI_CLIENT.beta.threads.messages.create(
27
- thread_id=chat.threadId,
28
- role=Author.User.value,
29
- content=content
30
- )
 
 
 
 
 
 
 
 
 
31
 
32
- full_response = ''
33
-
34
- async with settings.OPENAI_CLIENT.beta.threads.runs.create_and_stream(
35
- thread_id=chat.threadId,
36
- assistant_id=settings.ASSISTANT_ID,
37
- instructions=Prompts.generate_response if not user_message.fileUrl else Prompts.generate_response_image,
38
- model=chat.model.value
39
- ) as stream:
40
- async for chunk in stream.text_deltas:
41
- if chunk:
42
- full_response += chunk
43
- mini_data = {"text": chunk}
44
- yield f"data: {TraumaResponseWrapper(data=mini_data).model_dump_json()}\n\n"
45
-
46
- message_obj = MessageModel(chatId=chat.id, author=Author.Assistant, text=full_response)
47
- await settings.DB_CLIENT.messages.insert_one(message_obj.to_mongo())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from trauma.api.chat.dto import EntityData
2
+ from trauma.api.message.ai.prompts import TraumaPrompts
 
 
 
 
 
3
  from trauma.core.config import settings
4
+ from trauma.core.wrappers import openai_wrapper
5
+
6
+
7
+ @openai_wrapper(is_json=True)
8
+ async def update_entity_data_with_ai(entity_data: EntityData, user_message: str):
9
+ messages = [
10
+ {
11
+ "role": "system",
12
+ "content": TraumaPrompts.update_entity_data_with_ai
13
+ .replace("{entity_data}", entity_data.model_dump_json())
14
+ .replace("{user_message}", user_message)
15
+ }
16
+ ]
17
+ return messages
18
+
19
+
20
+ @openai_wrapper(temperature=0.8)
21
+ async def generate_next_question(empty_field: str, user_message: str, message_history: list[dict]):
22
+ messages = [
23
+ {
24
+ "role": "system",
25
+ "content": TraumaPrompts.generate_next_question
26
+ .replace("{empty_field}", empty_field)
27
+ },
28
+ *message_history,
29
+ {
30
+ "role": "user",
31
+ "content": user_message
32
+ }
33
+ ]
34
+ return messages
35
 
36
+
37
+ @openai_wrapper(temperature=0.4)
38
+ async def generate_search_request(user_messages_str: str, entity_data: EntityData):
39
+ messages = [
40
+ {
41
+ "role": "system",
42
+ "content": TraumaPrompts.generate_search_request
43
+ .replace("{entity_data}", entity_data.model_dump_json())
44
+ .replace("{user_messages_str}", user_messages_str)
45
+ }
46
+ ]
47
+ return messages
48
+
49
+
50
+ @openai_wrapper(temperature=0.4)
51
+ async def generate_final_response(final_entities: str, user_message: str, message_history: list[dict]):
52
+ messages = [
53
+ {
54
+ "role": "system",
55
+ "content": TraumaPrompts.generate_recommendation_decision
56
+ .replace("{final_entities}", final_entities)
57
+
58
+ },
59
+ *message_history,
60
+ {
61
+ "role": "user",
62
+ "content": user_message
63
+ }
64
+ ]
65
+ return messages
66
+
67
+ async def convert_value_to_embeddings(value: str) -> list[float]:
68
+ embeddings = await settings.OPENAI_CLIENT.embeddings.create(
69
+ input=value,
70
+ model='text-embedding-3-large',
71
+ dimensions=1536,
72
+ )
73
+ return embeddings.data[0].embedding
trauma/api/message/ai/prompts.py CHANGED
@@ -1,216 +1,5 @@
1
- class Prompts:
2
- generate_response = """## Objective
3
-
4
- You are Hector, the Mitutoyo virtual assistant, engineered for precision in the measurement industry. Your core function is two-fold: first, to match users with ideal products based on their specific requirements, and second, to intelligently recommend compatible accessories and complementary products for upselling opportunities.
5
-
6
- ## Context
7
-
8
- Mitutoyo's B2B platform encompasses precision measurement tools and accessories. Your knowledge base consists of structured JSON product data specifically focused on Mitutoyo's digital and analogue micrometers, including their detailed specifications, features, and accessory relationships. While you have broad knowledge of Mitutoyo's full product range, your detailed product data and recommendation capabilities are currently optimized for the digital and analogue micrometer categories. Each recommendation must be based on exact matches between user requirements and product specifications.
9
-
10
- ## Data Processing Protocol
11
-
12
- 1. **Primary Product Matching**:
13
- * Parse user requirements against product specifications
14
- * Match technical requirements (e.g., "carbide tipped jaws") to FEATURE elements
15
- * Validate matches using PRODUCT_DETAILS specifications
16
- * Confirm accuracy through DESCRIPTION_LONG technical details
17
- 2. **Product Data Extraction**:
18
- * SUPPLIER_PID (for exact product identification)
19
- * DESCRIPTION_SHORT (en/nl) for product names
20
- * DESCRIPTION_LONG (en/nl) for technical details
21
- * PRODUCT_DETAILS for specifications
22
- * FEATURE elements for specific capabilities
23
- * PRODUCT_ORDER_DETAILS for availability
24
- * PRODUCT_PRICE_DETAILS for pricing
25
- 3. **Accessory Matching Protocol**:
26
- * Extract all PRODUCT_REFERENCE entries
27
- * Validate PROD_ID_TO compatibility
28
- * Cross-reference accessory specifications
29
- * Verify physical compatibility parameters
30
-
31
- ## Interaction Flow
32
-
33
- 1. **Requirement Analysis**:
34
- * Parse user's technical requirements
35
- * Identify specific feature requests
36
- * Match requirements to product specifications
37
- * Validate technical compatibility
38
- 2. **Primary Product Recommendation**:
39
- * Present matching products with complete details:
40
- * Product name (both languages)
41
- * Article number
42
- * Key specifications
43
- * Relevant features
44
- * Price information
45
- 3. **Strategic Upselling**:
46
- * Analyze PRODUCT_REFERENCE data
47
- * Identify value-adding accessories
48
- * Present complementary products
49
- * Explain benefits and compatibility
50
- 4. **Verification Process**:
51
- * Double-check all technical matches
52
- * Verify compatibility of all recommendations
53
- * Confirm pricing and availability
54
- * Validate all product relationships
55
-
56
- ## Response Format Requirements
57
-
58
- ```markdown
59
- ## Primary Recommendation
60
-
61
- - Product: [Name EN/NL]
62
- - Article: [SUPPLIER_PID]
63
- - Key Features: [Matched Requirements]
64
- - Price: [From PRODUCT_PRICE_DETAILS]
65
-
66
- ## Recommended Accessories
67
-
68
- 1. [Primary Accessory]
69
- - Purpose: [Specific Benefit]
70
- - Article: [SUPPLIER_PID]
71
- - Compatibility: [Verification Details]
72
- 2. [Secondary Accessories]
73
- - Purpose: [Specific Benefit]
74
- - Article: [SUPPLIER_PID]
75
- - Compatibility: [Verification Details]
76
- ```
77
-
78
- ## Error Prevention Protocol
79
-
80
- 1. **Technical Matching**:
81
- * Verify exact feature matches
82
- * Confirm dimensional compatibility
83
- * Validate technical specifications
84
- * Cross-reference all requirements
85
- 2. **Compatibility Verification**:
86
- * Check PRODUCT_REFERENCE links
87
- * Verify physical specifications
88
- * Confirm accessory compatibility
89
- * Validate system requirements
90
- 3. **Data Accuracy**:
91
- * Double-check all article numbers
92
- * Verify price information
93
- * Confirm availability status
94
- * Validate technical specifications
95
-
96
- ## Prohibitions
97
-
98
- * No assumptions about compatibility
99
- * No recommendations without PRODUCT_REFERENCE validation
100
- * No incomplete technical specifications
101
- * No unverified product relationships
102
- * NO discussion of competitor products or brands under any circumstances
103
- * NO comparative analysis with other manufacturers
104
- * NO recommendations outside of Mitutoyo's product range
105
- * NO redirecting to other brands, even if Mitutoyo doesn't offer a solution
106
- * NO market comparisons or industry benchmarking against other brands
107
-
108
- ## Example Interaction
109
-
110
- User: "Need a micrometer with carbide tipped jaws"
111
- Response Protocol:
112
- 1. Match "carbide tipped jaws" with FEATURE elements
113
- 2. Verify products meeting specification
114
- 3. Extract complete product details
115
- 4. Identify compatible accessories through PRODUCT_REFERENCE
116
- 5. Present primary recommendation with upselling options
117
- 6. Verify all technical relationships
118
-
119
- ## Key Performance Requirements
120
-
121
- * 100% accuracy in technical matching
122
- * Complete verification of all recommendations
123
- * Precise accessory compatibility checking
124
- * Clear, structured response format
125
- * Professional, technical communication style"""
126
- generate_response_image = """### **Prompt Purpose**
127
-
128
- You are Hector, Mitutoyo's visual recognition expert, designed to identify Mitutoyo precision measurement instruments from images with unmatched precision and adherence to Mitutoyo’s high standards.
129
-
130
- ### **Core Functionality**
131
- - **Primary Role:** Analyze images to confidently identify Mitutoyo products.
132
- - **Communication:** Clearly indicate the confidence level and provide actionable feedback for incomplete identification cases.
133
-
134
- ### **Recognition Protocol**
135
-
136
- #### **Visual Analysis Sequence**
137
-
138
- 1. Confirm Mitutoyo product authenticity.
139
- 2. Identify the product category (e.g., micrometer, caliper, indicator).
140
- 3. Recognize specific features and characteristics.
141
- 4. Match visual data against known Mitutoyo design elements.
142
- 5. Determine confidence level based on recognition certainty.
143
-
144
- #### **Confidence Level Classification**
145
-
146
- - **Level 1: High Confidence (100% Certain)**
147
- - **Product Identification:**
148
- - Confirmed as Mitutoyo [Product Name].
149
- - **Article Number:** [SUPPLIER_PID].
150
- - **Product Details:**
151
- - Include complete specifications based on available data.
152
-
153
- - **Level 2: Medium Confidence (Visual Recognition)**
154
- - **Product Identification:**
155
- - Recognized as Mitutoyo [Product Category/Series].
156
- - Requires additional training data to confirm the exact article number.
157
- - **Identified Features:**
158
- - List identifiable features and series characteristics.
159
-
160
- - **Level 3: Limited Confidence**
161
- - **Initial Recognition:**
162
- - Confirmed as a Mitutoyo [Product Category].
163
- - Requires additional information for precise identification.
164
- - **Recommendations:**
165
- 1. Share the article number if available.
166
- 2. Provide details about specific requirements.
167
- 3. Include additional product visuals or specifications.
168
-
169
- ### **Absolute Prohibitions**
170
- - **No guessing:** Avoid speculation on article numbers.
171
- - **Non-Mitutoyo products:** Do not identify or compare with competitor products.
172
- - **Uncertainty:** Avoid assumptions about specifications without solid evidence.
173
- - **Precision:** Only communicate confirmed and accurate information.
174
-
175
- ### **Communication Guidelines**
176
-
177
- #### **Certainty Communication**
178
-
179
- - Clearly express confidence levels.
180
- - Explicitly state identification limitations.
181
- - Outline additional information needed for improved accuracy.
182
-
183
- #### **Training Transparency**
184
-
185
- - Acknowledge when training data is insufficient.
186
- - Professionally explain limitations.
187
- - Offer constructive steps for better identification.
188
-
189
- #### **Response Structure**
190
-
191
- 1. Start with the confidence level.
192
- 2. Provide available identification details.
193
- 3. Highlight limitations and missing data.
194
- 4. Suggest next steps or provide recommendations.
195
-
196
- #### **Unclear Cases Protocol**
197
-
198
- - Confirm receipt of the image.
199
- - State if the product is Mitutoyo.
200
- - Specify the confidence level.
201
- - Detail limitations and request additional details.
202
- - Propose alternative assistance options if applicable.
203
-
204
- ### **Quality Assurance**
205
-
206
- - Provide article numbers only when 100% certain.
207
- - Validate all visual markers against Mitutoyo's known patterns.
208
- - Clearly communicate confidence levels.
209
- - Suggest actionable steps for uncertain cases.
210
-
211
- ### **Key Performance Metrics**
212
-
213
- - **Accuracy:** Ensure precise product category identification.
214
- - **Clarity:** Effectively communicate certainty levels and next steps.
215
- - **Professionalism:** Handle limitations constructively.
216
- - **Adherence:** Uphold Mitutoyo’s precision standards at all times."""
 
1
+ class TraumaPrompts:
2
+ update_entity_data_with_ai = """"""
3
+ generate_next_question = """"""
4
+ generate_search_request = """"""
5
+ generate_recommendation_decision = """"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/message/ai/utils.py DELETED
@@ -1,30 +0,0 @@
1
- from trauma.api.message.ai.prompts import Prompts
2
- from trauma.api.message.model import MessageModel
3
-
4
-
5
- def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
6
- openai_messages = [{"role": "system", "content": Prompts.generate_response}]
7
- for message in messages:
8
-
9
- if message.file:
10
- content = [
11
- {
12
- "type": "text", "text": message.text
13
- },
14
- {
15
- "type": "image_url",
16
- "image_url": {
17
- "url": f"data:image/jpeg;base64,{message.file.base64String}",
18
- "detail": "low"
19
- }
20
- },
21
- ]
22
- else:
23
- content = message.text
24
-
25
- openai_messages.append({
26
- "role": message.author.value,
27
- "content": content
28
- })
29
-
30
- return openai_messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/message/db_requests.py CHANGED
@@ -1,18 +1,34 @@
 
 
1
  import asyncio
2
 
3
  from fastapi import HTTPException
4
 
5
- from trauma.api.account.model import AccountModel
6
  from trauma.api.chat.model import ChatModel
 
 
7
  from trauma.api.message.dto import Author
8
  from trauma.api.message.model import MessageModel
9
  from trauma.api.message.schemas import CreateMessageRequest
10
  from trauma.core.config import settings
 
11
 
12
 
13
- async def get_all_chat_messages_obj(
14
- chat_id: str, account: AccountModel
15
- ) -> list[MessageModel]:
 
 
 
 
 
 
 
 
 
 
 
16
  messages, chat = await asyncio.gather(
17
  settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
18
  settings.DB_CLIENT.chats.find_one({"id": chat_id})
@@ -23,23 +39,40 @@ async def get_all_chat_messages_obj(
23
  raise HTTPException(status_code=404, detail="Chat not found")
24
 
25
  chat = ChatModel.from_mongo(chat)
26
- if account and chat.account != account:
27
- raise HTTPException(status_code=403, detail="Chat account not match")
28
 
29
- return messages
30
 
31
- async def create_message_obj(
32
- chat_id: str, message_data: CreateMessageRequest, account: AccountModel
33
- ) -> tuple[MessageModel, ChatModel]:
34
- chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
35
- if not chat:
36
- raise HTTPException(status_code=404, detail="Chat not found")
37
 
38
- chat = ChatModel.from_mongo(chat)
39
- if account and chat.account != account:
40
- raise HTTPException(status_code=403, detail="Chat account not match")
41
 
42
- message = MessageModel(**message_data.model_dump(), chatId=chat_id, author=Author.User)
43
- await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
 
 
 
 
 
 
 
 
 
 
44
 
45
- return message, chat
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import numpy as np
3
  import asyncio
4
 
5
  from fastapi import HTTPException
6
 
7
+ from trauma.api.chat.dto import EntityData
8
  from trauma.api.chat.model import ChatModel
9
+ from trauma.api.data.model import EntityModel
10
+ from trauma.api.message.ai.openai_request import convert_value_to_embeddings
11
  from trauma.api.message.dto import Author
12
  from trauma.api.message.model import MessageModel
13
  from trauma.api.message.schemas import CreateMessageRequest
14
  from trauma.core.config import settings
15
+ from trauma.core.wrappers import background_task
16
 
17
 
18
+ async def create_message_obj(
19
+ chat_id: str, message_data: CreateMessageRequest
20
+ ) -> tuple[MessageModel, ChatModel]:
21
+ chat = await settings.DB_CLIENT.chats.find_one({"id": chat_id})
22
+ if not chat:
23
+ raise HTTPException(status_code=404, detail="Chat not found.")
24
+
25
+ message = MessageModel(**message_data.model_dump(), chatId=chat_id, author=Author.User)
26
+ await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
27
+
28
+ return message, chat
29
+
30
+
31
+ async def get_all_chat_messages_obj(chat_id: str) -> tuple[list[MessageModel], ChatModel]:
32
  messages, chat = await asyncio.gather(
33
  settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
34
  settings.DB_CLIENT.chats.find_one({"id": chat_id})
 
39
  raise HTTPException(status_code=404, detail="Chat not found")
40
 
41
  chat = ChatModel.from_mongo(chat)
42
+ return messages, chat
 
43
 
 
44
 
45
+ @background_task()
46
+ async def save_assistant_user_message(user_message: str, assistant_message: str, chat_id: str) -> None:
47
+ user_message = MessageModel(chatId=chat_id, author=Author.User, text=user_message)
48
+ assistant_message = MessageModel(chatId=chat_id, author=Author.User, text=assistant_message)
49
+ await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
50
+ await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
51
 
 
 
 
52
 
53
+ async def filter_entities_by_age(entity: EntityData) -> list[int]:
54
+ query = {
55
+ "ageGroups": {
56
+ "$elemMatch": {
57
+ "ageMin": {"$lte": entity.ageMax},
58
+ "ageMax": {"$gte": entity.ageMin}
59
+ }
60
+ }
61
+ }
62
+ entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
63
+ return [entity['index'] for entity in entities]
64
+
65
 
66
+ async def search_semantic_entities(search_request: str, entities_indexes: list[int]):
67
+ embedding = await convert_value_to_embeddings(search_request)
68
+ query_embedding = np.array([embedding], dtype=np.float32)
69
+ distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)
70
+ distances = distances[0]
71
+ indices = indices[0]
72
+ filtered_results = [
73
+ {"index": int(idx), "distance": float(dist)}
74
+ for idx, dist in zip(indices, distances)
75
+ if idx in entities_indexes and dist <= 0.6
76
+ ]
77
+ filtered_results = sorted(filtered_results, key=lambda x: x["distance"])
78
+ return filtered_results[:5]
trauma/api/message/model.py CHANGED
@@ -2,7 +2,7 @@ from datetime import datetime
2
 
3
  from pydantic import Field
4
 
5
- from trauma.api.message.dto import Author, File
6
  from trauma.core.database import MongoBaseModel
7
 
8
 
@@ -10,6 +10,5 @@ class MessageModel(MongoBaseModel):
10
  chatId: str
11
  author: Author
12
  text: str
13
- fileUrl: str | None = None
14
  datetimeInserted: datetime = Field(default_factory=datetime.now)
15
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
 
2
 
3
  from pydantic import Field
4
 
5
+ from trauma.api.message.dto import Author
6
  from trauma.core.database import MongoBaseModel
7
 
8
 
 
10
  chatId: str
11
  author: Author
12
  text: str
 
13
  datetimeInserted: datetime = Field(default_factory=datetime.now)
14
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
trauma/api/message/schemas.py CHANGED
@@ -1,6 +1,7 @@
1
  from pydantic import BaseModel
2
 
3
  from trauma.api.common.dto import Paging
 
4
  from trauma.api.message.model import MessageModel
5
  from trauma.core.wrappers import TraumaResponseWrapper
6
 
@@ -21,3 +22,8 @@ class AllMessageResponse(BaseModel):
21
 
22
  class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
23
  pass
 
 
 
 
 
 
1
  from pydantic import BaseModel
2
 
3
  from trauma.api.common.dto import Paging
4
+ from trauma.api.data.model import EntityModel
5
  from trauma.api.message.model import MessageModel
6
  from trauma.core.wrappers import TraumaResponseWrapper
7
 
 
22
 
23
  class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
24
  pass
25
+
26
+
27
+ class CreateMessageResponse(BaseModel):
28
+ text: str
29
+ entities: list[EntityModel] | None = None
trauma/api/message/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from trauma.api.data.model import EntityModel
4
+ from trauma.api.message.model import MessageModel
5
+
6
+
7
+ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
8
+ openai_messages = []
9
+ for message in messages:
10
+ content = message.text
11
+ openai_messages.append({
12
+ "role": message.author.value,
13
+ "content": content
14
+ })
15
+
16
+ return openai_messages
17
+
18
+
19
+ def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
20
+ for k, v in entity_data.items():
21
+ if not v:
22
+ return k
23
+ return None
24
+
25
+
26
+ def prepare_user_messages_str(user_message: str, messages: list[dict]) -> str:
27
+ user_message_str = ''
28
+ for message in messages:
29
+ if message['role'] == 'user':
30
+ user_message_str += f'- {message['content']}\n'
31
+ user_message_str += f'- {user_message}'
32
+ return user_message_str
33
+
34
+
35
+ def prepare_final_entities_str(entities: list[EntityModel]) -> str:
36
+ entities_list = []
37
+ for entity in entities:
38
+ entities_list.append(entity.model_dump(mode='json', exclude={'id', 'contactDetails'}))
39
+ return json.dumps({"entities": entities_list})
trauma/api/message/views.py CHANGED
@@ -1,22 +1,20 @@
1
- from fastapi.params import Depends
2
- from starlette.responses import StreamingResponse
3
-
4
- from trauma.api.account.model import AccountModel
5
  from trauma.api.common.dto import Paging
6
  from trauma.api.message import message_router
7
- from trauma.api.message.ai.openai_request import response_generator
8
- from trauma.api.message.db_requests import get_all_chat_messages_obj, create_message_obj
9
  from trauma.api.message.schemas import (AllMessageWrapper,
10
  AllMessageResponse,
11
- CreateMessageRequest)
12
- from trauma.core.security import PermissionDependency
 
 
13
 
14
 
15
  @message_router.get('/{chatId}/all')
16
  async def get_all_chat_messages(
17
- chatId: str, account: AccountModel = Depends(PermissionDependency(is_public=True))
18
  ) -> AllMessageWrapper:
19
- messages = await get_all_chat_messages_obj(chatId, account)
20
  response = AllMessageResponse(
21
  paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
22
  data=messages
@@ -28,7 +26,8 @@ async def get_all_chat_messages(
28
  async def create_message(
29
  chatId: str,
30
  message_data: CreateMessageRequest,
31
- account: AccountModel = Depends(PermissionDependency(is_public=True))
32
- ) -> StreamingResponse:
33
- user_message, chat = await create_message_obj(chatId, message_data, account)
34
- return StreamingResponse(response_generator(chat, user_message), media_type='text/event-stream')
 
 
 
 
 
 
1
  from trauma.api.common.dto import Paging
2
  from trauma.api.message import message_router
3
+ from trauma.api.message.ai.engine import search_entities
4
+ from trauma.api.message.db_requests import get_all_chat_messages_obj
5
  from trauma.api.message.schemas import (AllMessageWrapper,
6
  AllMessageResponse,
7
+ CreateMessageRequest,
8
+ CreateMessageResponse)
9
+ from trauma.api.message.utils import transform_messages_to_openai
10
+ from trauma.core.wrappers import TraumaResponseWrapper
11
 
12
 
13
  @message_router.get('/{chatId}/all')
14
  async def get_all_chat_messages(
15
+ chatId: str
16
  ) -> AllMessageWrapper:
17
+ messages, _ = await get_all_chat_messages_obj(chatId)
18
  response = AllMessageResponse(
19
  paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
20
  data=messages
 
26
  async def create_message(
27
  chatId: str,
28
  message_data: CreateMessageRequest,
29
+ ) -> TraumaResponseWrapper[CreateMessageResponse]:
30
+ messages, chat = await get_all_chat_messages_obj(chatId)
31
+ message_history = transform_messages_to_openai(messages)
32
+ response = await search_entities(message_data.text, message_history, chat)
33
+ return TraumaResponseWrapper(data=response)
trauma/api/security/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from fastapi import APIRouter
2
-
3
- security_router = APIRouter(
4
- prefix='/api/security'
5
- )
6
-
7
- from . import views
 
 
 
 
 
 
 
 
trauma/api/security/db_requests.py DELETED
@@ -1,34 +0,0 @@
1
- import asyncio
2
-
3
- from fastapi import HTTPException
4
-
5
- from trauma.api.account.model import AccountModel
6
- from trauma.api.common.db_requests import check_unique_fields_existence
7
- from trauma.api.security.schemas import RegisterAccountRequest, LoginAccountRequest
8
- from trauma.core.config import settings
9
- from trauma.core.security import verify_password
10
-
11
-
12
- async def save_account(data: RegisterAccountRequest) -> AccountModel:
13
- await asyncio.gather(
14
- check_unique_fields_existence("accounts", "email", data.email),
15
- )
16
- account = AccountModel(
17
- **data.model_dump()
18
- )
19
- await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
20
- return account
21
-
22
-
23
- async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
24
- account = await settings.DB_CLIENT.accounts.find_one(
25
- {"email": data.email},
26
- collation={"locale": "en", "strength": 2})
27
- if account is None:
28
- raise HTTPException(status_code=404, detail="Invalid email or password.")
29
-
30
- account = AccountModel.from_mongo(account)
31
-
32
- if not verify_password(data.password, account.password):
33
- raise HTTPException(status_code=401, detail="Invalid email or password.")
34
- return account
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/security/schemas.py DELETED
@@ -1,28 +0,0 @@
1
- from pydantic import BaseModel, EmailStr
2
-
3
- from trauma.api.account.dto import AccessToken
4
- from trauma.api.account.model import AccountModel
5
- from trauma.core.wrappers import TraumaResponseWrapper
6
-
7
-
8
- class RegisterAccountRequest(BaseModel):
9
- email: EmailStr
10
- password: str
11
-
12
-
13
- class RegisterAccountWrapper(TraumaResponseWrapper[AccountModel]):
14
- pass
15
-
16
-
17
- class LoginAccountRequest(BaseModel):
18
- email: EmailStr
19
- password: str
20
-
21
-
22
- class LoginAccountResponse(BaseModel):
23
- accessToken: AccessToken
24
- account: AccountModel
25
-
26
-
27
- class LoginAccountWrapper(TraumaResponseWrapper[LoginAccountResponse]):
28
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/security/views.py DELETED
@@ -1,27 +0,0 @@
1
- from typing import Literal
2
-
3
- from trauma.api.account.dto import AccessToken
4
- from trauma.api.security import security_router
5
- from trauma.api.security.db_requests import authenticate_account, save_account
6
- from trauma.api.security.schemas import RegisterAccountRequest, RegisterAccountWrapper, LoginAccountResponse, \
7
- LoginAccountWrapper, LoginAccountRequest
8
- from trauma.core.security import create_access_token
9
-
10
- model: Literal["accounts"] = "accounts"
11
-
12
-
13
- @security_router.post('/register')
14
- async def register_user(data: RegisterAccountRequest) -> RegisterAccountWrapper:
15
- account = await save_account(data)
16
- return RegisterAccountWrapper(data=account)
17
-
18
-
19
- @security_router.post('/login')
20
- async def login(data: LoginAccountRequest) -> LoginAccountWrapper:
21
- account = await authenticate_account(data)
22
- access_token = create_access_token(account.email, str(account.id))
23
- response = LoginAccountResponse(
24
- accessToken=AccessToken(value=access_token),
25
- account=account,
26
- )
27
- return LoginAccountWrapper(data=response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/core/config.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import pathlib
3
  from functools import lru_cache
4
 
 
5
  import motor.motor_asyncio
6
  from dotenv import load_dotenv
7
  from openai import AsyncClient
@@ -14,7 +15,7 @@ class BaseConfig:
14
  SECRET_KEY = os.getenv('SECRET')
15
  DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).trauma
16
  OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
17
-
18
 
19
  class DevelopmentConfig(BaseConfig):
20
  Issuer = "http://localhost:8000"
 
2
  import pathlib
3
  from functools import lru_cache
4
 
5
+ import faiss
6
  import motor.motor_asyncio
7
  from dotenv import load_dotenv
8
  from openai import AsyncClient
 
15
  SECRET_KEY = os.getenv('SECRET')
16
  DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).trauma
17
  OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
18
+ SEMANTIC_INDEX = faiss.read_index(str(pathlib.Path(__file__).parent.parent.parent / 'indexes' / 'entities.index'))
19
 
20
  class DevelopmentConfig(BaseConfig):
21
  Issuer = "http://localhost:8000"
trauma/core/wrappers.py CHANGED
@@ -1,6 +1,8 @@
 
1
  from functools import wraps
2
  from typing import Generic, Optional, TypeVar
3
 
 
4
  from fastapi import HTTPException
5
  from pydantic import BaseModel
6
  from starlette.responses import JSONResponse
@@ -42,3 +44,42 @@ def exception_wrapper(http_error: int, error_message: str):
42
  return wrapper
43
 
44
  return decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  from functools import wraps
3
  from typing import Generic, Optional, TypeVar
4
 
5
+ import pydash
6
  from fastapi import HTTPException
7
  from pydantic import BaseModel
8
  from starlette.responses import JSONResponse
 
44
  return wrapper
45
 
46
  return decorator
47
+
48
+
49
+ def openai_wrapper(
50
+ temperature: int | float = 0, model: str = "gpt-4o-mini", is_json: bool = False, return_: str = None
51
+ ):
52
+ def decorator(func):
53
+ @wraps(func)
54
+ async def wrapper(*args, **kwargs) -> str:
55
+ messages = await func(*args, **kwargs)
56
+ completion = await settings.OPENAI_CLIENT.chat.completions.create(
57
+ messages=messages,
58
+ temperature=temperature,
59
+ n=1,
60
+ model=model,
61
+ response_format={"type": "json_object"} if is_json else {"type": "text"}
62
+ )
63
+ response = completion.choices[0].message.content
64
+ if is_json:
65
+ response = json.loads(response)
66
+ if return_:
67
+ return pydash.get(response, return_)
68
+ return response
69
+
70
+ return wrapper
71
+
72
+ return decorator
73
+
74
+ def background_task():
75
+ def decorator(func):
76
+ @wraps(func)
77
+ async def wrapper(*args, **kwargs) -> str:
78
+ try:
79
+ result = await func(*args, **kwargs)
80
+ return result
81
+ except Exception as e:
82
+ pass
83
+ return wrapper
84
+
85
+ return decorator