brestok committed on
Commit
e754e5a
·
1 Parent(s): 52bd2ca

added highlights

Browse files
trauma/api/data/dto.py CHANGED
@@ -10,4 +10,5 @@ class ContactDetails(BaseModel):
10
  email: str | None = None
11
  website: str | None = None
12
  address: str | None = None
13
- postalCode: str | None = None
 
 
10
  email: str | None = None
11
  website: str | None = None
12
  address: str | None = None
13
+ postalCode: str | None = None
14
+
trauma/api/data/model.py CHANGED
@@ -42,3 +42,9 @@ class EntityModel(MongoBaseModel):
42
  treatmentMethods: list[str]
43
  description: str = ''
44
  contactDetails: ContactDetails
 
 
 
 
 
 
 
42
  treatmentMethods: list[str]
43
  description: str = ''
44
  contactDetails: ContactDetails
45
+
46
+
47
class EntityModelExtended(EntityModel):
    # An EntityModel carrying three extra "highlight" fields. In this change
    # they are populated by extended_entity_with_highlights() in
    # trauma/api/message/ai/engine.py.

    # Age group of the entity selected by find_matching_age_group().
    highlightedAgeGroup: AgeGroup
    # Result of choose_closest_treatment_area() for this entity.
    highlightedTreatmentArea: str
    # Result of choose_closest_treatment_method() for this entity.
    highlightedTreatmentMethod: str
trauma/api/message/ai/engine.py CHANGED
@@ -1,19 +1,24 @@
1
  import asyncio
2
 
 
 
 
3
  from trauma.api.chat.model import ChatModel
 
4
  from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
5
  generate_next_question,
6
  generate_search_request,
7
- generate_final_response)
 
8
  from trauma.api.message.db_requests import (save_assistant_user_message,
9
  filter_entities_by_age,
10
- search_semantic_entities,
11
- update_entity_data_obj)
12
  from trauma.api.message.schemas import CreateMessageResponse
13
  from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
14
  prepare_user_messages_str,
15
  prepare_final_entities_str,
16
- pick_empty_field_instructions)
 
17
 
18
 
19
  async def search_entities(
@@ -34,9 +39,44 @@ async def search_entities(
34
  filter_entities_by_age(entity_data),
35
  generate_search_request(user_messages_str, entity_data)
36
  )
37
- final_entities = await search_semantic_entities(search_request, possible_entity_indexes)
38
  final_entities_str = prepare_final_entities_str(final_entities)
39
  response = await generate_final_response(final_entities_str, user_message, messages)
40
 
41
  asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
42
  return CreateMessageResponse(text=response, entities=final_entities)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
 
3
+ import numpy as np
4
+
5
+ from trauma.api.chat.dto import EntityData
6
  from trauma.api.chat.model import ChatModel
7
+ from trauma.api.data.model import EntityModel, EntityModelExtended
8
  from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
9
  generate_next_question,
10
  generate_search_request,
11
+ generate_final_response, convert_value_to_embeddings,
12
+ choose_closest_treatment_method, choose_closest_treatment_area)
13
  from trauma.api.message.db_requests import (save_assistant_user_message,
14
  filter_entities_by_age,
15
+ update_entity_data_obj, get_entity_by_index)
 
16
  from trauma.api.message.schemas import CreateMessageResponse
17
  from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
18
  prepare_user_messages_str,
19
  prepare_final_entities_str,
20
+ pick_empty_field_instructions, find_matching_age_group)
21
+ from trauma.core.config import settings
22
 
23
 
24
  async def search_entities(
 
39
  filter_entities_by_age(entity_data),
40
  generate_search_request(user_messages_str, entity_data)
41
  )
42
+ final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
43
  final_entities_str = prepare_final_entities_str(final_entities)
44
  response = await generate_final_response(final_entities_str, user_message, messages)
45
 
46
  asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
47
  return CreateMessageResponse(text=response, entities=final_entities)
48
+
49
+
50
async def search_semantic_entities(
        search_request: str, entity_data: EntityData, entities_indexes: list[int]
) -> list[EntityModelExtended]:
    """Return the best semantic matches for a search request, with highlights.

    The request text is embedded, the whole ``settings.SEMANTIC_INDEX`` is
    searched, and hits are kept only when their index is listed in
    ``entities_indexes`` and their distance does not exceed the cut-off.
    The closest hits are loaded from the database and each one is extended
    with highlight fields derived from ``entity_data``.

    Args:
        search_request: Free-text query to embed and search for.
        entity_data: Collected patient data, forwarded unchanged to
            ``extended_entity_with_highlights``.
        entities_indexes: Indexes of the entities that are allowed to appear
            in the result (the pre-filtered candidates).

    Returns:
        At most ``max_results`` extended entities, ordered by ascending
        distance. Empty when there are no candidates or no hit passes the
        distance cut-off.
    """
    max_distance = 1.3  # NOTE(review): metric/scale depends on how the index was built - confirm
    max_results = 5

    # Nothing can pass the membership filter below, so skip the embeddings
    # request and the index search altogether.
    if not entities_indexes:
        return []

    # Set membership is O(1); the filter below runs once per vector in the index.
    allowed_indexes = set(entities_indexes)

    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)

    # One query row was submitted, so only row 0 of each result is relevant.
    filtered_results = [
        {"index": int(idx), "distance": float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if int(idx) in allowed_indexes and dist <= max_distance
    ]
    filtered_results = sorted(filtered_results, key=lambda hit: hit["distance"])[:max_results]

    final_entities = await asyncio.gather(
        *[get_entity_by_index(hit["index"]) for hit in filtered_results]
    )
    final_entities_extended = await asyncio.gather(
        *[extended_entity_with_highlights(entity, entity_data) for entity in final_entities]
    )
    return final_entities_extended
69
+
70
+
71
async def extended_entity_with_highlights(entity: EntityModel, entity_data: dict) -> EntityModelExtended:
    """Build an ``EntityModelExtended`` from ``entity`` plus its three highlights.

    The highlighted age group comes from ``find_matching_age_group``; the
    highlighted treatment area and method come from the two
    ``choose_closest_*`` helpers, which are awaited concurrently.

    NOTE(review): ``entity_data`` is annotated ``dict`` here, while
    ``search_semantic_entities`` annotates the same value as ``EntityData``.
    It is accessed with ``[...]`` below - confirm which type is intended.
    """
    matched_age_group = find_matching_age_group(entity, entity_data)

    closest_area, closest_method = await asyncio.gather(
        choose_closest_treatment_area(entity.treatmentAreas, entity_data['treatmentArea']),
        choose_closest_treatment_method(entity.treatmentMethods, entity_data['treatmentMethod']),
    )

    base_fields = entity.to_mongo()
    return EntityModelExtended(
        **base_fields,
        highlightedAgeGroup=matched_age_group,
        highlightedTreatmentArea=closest_area,
        highlightedTreatmentMethod=closest_method,
    )
trauma/api/message/ai/openai_request.py CHANGED
@@ -74,4 +74,30 @@ async def convert_value_to_embeddings(value: str) -> list[float]:
74
  model='text-embedding-3-large',
75
  dimensions=1536,
76
  )
77
- return embeddings.data[0].embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  model='text-embedding-3-large',
75
  dimensions=1536,
76
  )
77
+ return embeddings.data[0].embedding
78
+
79
+
80
@openai_wrapper(is_json=True, return_='result')
async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str):
    """Build the chat messages for picking the closest treatment area.

    The function body only returns a single system message: the
    ``TraumaPrompts.choose_closest_treatment_area`` template with
    ``{treatment_areas}`` replaced by the comma-joined list and
    ``{treatment_area}`` replaced by the requested value.

    NOTE(review): presumably ``openai_wrapper`` sends these messages to the
    model and returns the ``result`` field of the JSON reply - verify against
    the decorator.
    """
    template = TraumaPrompts.choose_closest_treatment_area
    with_candidates = template.replace("{treatment_areas}", ", ".join(treatment_areas))
    system_prompt = with_candidates.replace("{treatment_area}", treatment_area)
    return [{"role": "system", "content": system_prompt}]
91
+
92
+
93
@openai_wrapper(is_json=True, return_='result')
async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str):
    """Build the chat messages for picking the closest treatment method.

    The function body only returns a single system message: the
    ``TraumaPrompts.choose_closest_treatment_method`` template with
    ``{treatment_methods}`` replaced by the comma-joined list and
    ``{treatment_method}`` replaced by the requested value.

    NOTE(review): presumably ``openai_wrapper`` sends these messages to the
    model and returns the ``result`` field of the JSON reply - verify against
    the decorator.
    """
    template = TraumaPrompts.choose_closest_treatment_method
    with_candidates = template.replace("{treatment_methods}", ", ".join(treatment_methods))
    system_prompt = with_candidates.replace("{treatment_method}", treatment_method)
    return [{"role": "system", "content": system_prompt}]
trauma/api/message/ai/prompts.py CHANGED
@@ -137,3 +137,60 @@ Je bent verplicht om een beschrijving voor een kliniek te genereren op basis van
137
  - De beschrijving moet beknopt en bondig zijn.
138
 
139
  [/INST]"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  - De beschrijving moet beknopt en bondig zijn.
138
 
139
  [/INST]"""
140
+
141
+ choose_closest_treatment_area = """
142
+
143
+ ## Task
144
+
145
+ You must determine the most semantically similar disorder or disease from the list of [treatment areas] to the requested disease [requested treatment area]. The most similar disease should be returned in the [result] field of the JSON.
146
+
147
+ ## Data
148
+
149
+ **treatment areas**:
150
+ ```
151
+ {treatment_areas}
152
+ ```
153
+
154
+ **requested treatment area**:
155
+ ```
156
+ {treatment_area}
157
+ ```
158
+
159
+ ## JSON Response format
160
+
161
+ ```json
162
+ {
163
+ "result": "string"
164
+ }
165
+ ```
166
+
167
+ ## Instructions for filling JSON
168
+
169
+ - [result]: The item from the [treatment areas] list that is most semantically similar to the requested disease. The disease name in the result field must exactly match the name as it appears in the [treatment areas] list."""
170
+ choose_closest_treatment_method = """## Task
171
+
172
+ You must determine the most semantically similar treatment method from the list of [treatment methods] to the requested treatment method [requested treatment method]. The most similar treatment method should be returned in the [result] field of the JSON.
173
+
174
+ ## Data
175
+
176
+ **treatment methods**:
177
+ ```
178
+ {treatment_methods}
179
+ ```
180
+
181
+ **requested treatment method**:
182
+ ```
183
+ {treatment_method}
184
+ ```
185
+
186
+ ## JSON Response format
187
+
188
+ ```json
189
+ {
190
+ "result": "string"
191
+ }
192
+ ```
193
+
194
+ ## Instructions for filling JSON
195
+
196
+ - [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
trauma/api/message/db_requests.py CHANGED
@@ -1,11 +1,9 @@
1
  import asyncio
2
 
3
- import numpy as np
4
  from fastapi import HTTPException
5
 
6
  from trauma.api.chat.model import ChatModel
7
  from trauma.api.data.model import EntityModel
8
- from trauma.api.message.ai.openai_request import convert_value_to_embeddings
9
  from trauma.api.message.dto import Author
10
  from trauma.api.message.model import MessageModel
11
  from trauma.api.message.schemas import CreateMessageRequest
@@ -58,12 +56,12 @@ async def save_assistant_user_message(user_message: str, assistant_message: str,
58
  await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
59
 
60
 
61
- async def filter_entities_by_age(entity: dict) -> list[int]:
62
  query = {
63
  "ageGroups": {
64
  "$elemMatch": {
65
- "ageMin": {"$lte": entity['ageMax']},
66
- "ageMax": {"$gte": entity['ageMin']}
67
  }
68
  }
69
  }
@@ -74,18 +72,3 @@ async def filter_entities_by_age(entity: dict) -> list[int]:
74
  async def get_entity_by_index(index: int) -> EntityModel:
75
  entity = await settings.DB_CLIENT.entities.find_one({"index": index})
76
  return EntityModel.from_mongo(entity)
77
-
78
- async def search_semantic_entities(search_request: str, entities_indexes: list[int]) -> list[EntityModel]:
79
- embedding = await convert_value_to_embeddings(search_request)
80
- query_embedding = np.array([embedding], dtype=np.float32)
81
- distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)
82
- distances = distances[0]
83
- indices = indices[0]
84
- filtered_results = [
85
- {"index": int(idx), "distance": float(dist)}
86
- for idx, dist in zip(indices, distances)
87
- if idx in entities_indexes and dist <= 1.3
88
- ]
89
- filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
90
- final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
91
- return final_entities
 
1
  import asyncio
2
 
 
3
  from fastapi import HTTPException
4
 
5
  from trauma.api.chat.model import ChatModel
6
  from trauma.api.data.model import EntityModel
 
7
  from trauma.api.message.dto import Author
8
  from trauma.api.message.model import MessageModel
9
  from trauma.api.message.schemas import CreateMessageRequest
 
56
  await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
57
 
58
 
59
+ async def filter_entities_by_age(entity_data: dict) -> list[int]:
60
  query = {
61
  "ageGroups": {
62
  "$elemMatch": {
63
+ "ageMin": {"$lte": entity_data['ageMax']},
64
+ "ageMax": {"$gte": entity_data['ageMin']}
65
  }
66
  }
67
  }
 
72
  async def get_entity_by_index(index: int) -> EntityModel:
73
  entity = await settings.DB_CLIENT.entities.find_one({"index": index})
74
  return EntityModel.from_mongo(entity)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trauma/api/message/schemas.py CHANGED
@@ -1,7 +1,7 @@
1
  from pydantic import BaseModel
2
 
3
  from trauma.api.common.dto import Paging
4
- from trauma.api.data.model import EntityModel
5
  from trauma.api.message.model import MessageModel
6
  from trauma.core.wrappers import TraumaResponseWrapper
7
 
@@ -24,4 +24,4 @@ class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
24
 
25
  class CreateMessageResponse(BaseModel):
26
  text: str
27
- entities: list[EntityModel] | None = None
 
1
  from pydantic import BaseModel
2
 
3
  from trauma.api.common.dto import Paging
4
+ from trauma.api.data.model import EntityModelExtended
5
  from trauma.api.message.model import MessageModel
6
  from trauma.core.wrappers import TraumaResponseWrapper
7
 
 
24
 
25
  class CreateMessageResponse(BaseModel):
26
  text: str
27
+ entities: list[EntityModelExtended] | None = None
trauma/api/message/utils.py CHANGED
@@ -1,5 +1,7 @@
1
  import json
2
 
 
 
3
  from trauma.api.data.model import EntityModel
4
  from trauma.api.message.model import MessageModel
5
 
@@ -35,7 +37,9 @@ def prepare_user_messages_str(user_message: str, messages: list[dict]) -> str:
35
  def prepare_final_entities_str(entities: list[EntityModel]) -> str:
36
  entities_list = []
37
  for entity in entities:
38
- entities_list.append(entity.model_dump(mode='json', exclude={'id', 'contactDetails'}))
 
 
39
  return json.dumps({"klinieken": entities_list})
40
 
41
 
@@ -48,3 +52,15 @@ def pick_empty_field_instructions(empty_field: str) -> str:
48
  return "Het type psychische of lichamelijke ziekte / stoornis."
49
  elif empty_field == "treatmentMethod":
50
  return "Een methode om de ziekte of stoornis te behandelen."
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
 
3
+ from trauma.api.chat.dto import EntityData
4
+ from trauma.api.data.dto import AgeGroup
5
  from trauma.api.data.model import EntityModel
6
  from trauma.api.message.model import MessageModel
7
 
 
37
  def prepare_final_entities_str(entities: list[EntityModel]) -> str:
38
  entities_list = []
39
  for entity in entities:
40
+ entities_list.append(entity.model_dump(mode='json', exclude={
41
+ 'id', 'contactDetails', "highlightedAgeGroup", "highlightedTreatmentArea", "highlightedTreatmentMethod"
42
+ }))
43
  return json.dumps({"klinieken": entities_list})
44
 
45
 
 
52
  return "Het type psychische of lichamelijke ziekte / stoornis."
53
  elif empty_field == "treatmentMethod":
54
  return "Een methode om de ziekte of stoornis te behandelen."
55
+
56
+
57
def find_matching_age_group(entity: EntityModel, entity_data: dict) -> AgeGroup | None:
    """Pick the age group of ``entity`` that best fits the requested age range.

    The requested range is ``entity_data['ageMin']`` .. ``entity_data['ageMax']``.
    Selection order:

    1. The first age group (in list order) that overlaps the requested range.
    2. Otherwise, among groups lying entirely below the range, the one with
       the highest ``ageMax``.
    3. Otherwise, among groups lying entirely above the range, the one with
       the lowest ``ageMin``. Previously this case returned ``None``.

    Args:
        entity: Object exposing ``ageGroups``, each item having ``ageMin`` and
            ``ageMax``.
        entity_data: Mapping with the requested ``ageMin`` and ``ageMax``.

    Returns:
        The selected age group, or ``None`` only when the entity has no age
        groups at all.
    """
    age_groups = entity.ageGroups
    if not age_groups:
        return None

    requested_min = entity_data['ageMin']
    requested_max = entity_data['ageMax']

    closest_below = None
    closest_above = None
    for age_group in age_groups:
        if age_group.ageMin <= requested_max and age_group.ageMax >= requested_min:
            return age_group
        if age_group.ageMax < requested_min:
            if closest_below is None or age_group.ageMax > closest_below.ageMax:
                closest_below = age_group
        # No overlap and not below the range, so the group starts after it.
        elif closest_above is None or age_group.ageMin < closest_above.ageMin:
            closest_above = age_group

    # Lower groups keep priority so results that used to be non-None are unchanged.
    return closest_below if closest_below is not None else closest_above
trauma/core/config.py CHANGED
@@ -16,7 +16,7 @@ class BaseConfig:
16
  DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).AtlasCluster
17
  OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
18
  SEMANTIC_INDEX = faiss.read_index(str(pathlib.Path(__file__).parent.parent.parent / 'indexes' / 'entities.index'))
19
- INTRO_MESSAGE = """Hello! I am an AI assistant here to help find the perfect clinic for every patient. Please share the patient’s age restrictions."""
20
 
21
  class DevelopmentConfig(BaseConfig):
22
  Issuer = "http://localhost:8000"
 
16
  DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).AtlasCluster
17
  OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
18
  SEMANTIC_INDEX = faiss.read_index(str(pathlib.Path(__file__).parent.parent.parent / 'indexes' / 'entities.index'))
19
+ INTRO_MESSAGE = """Hallo! Ik ben een AI-assistent hier om te helpen bij het vinden van de perfecte kliniek voor elke patiënt. Deel de gegevens van de patiënt."""
20
 
21
  class DevelopmentConfig(BaseConfig):
22
  Issuer = "http://localhost:8000"