brestok commited on
Commit
211d915
·
1 Parent(s): aabf8ec

add location postal code searching

Browse files
trauma/api/account/db_requests.py CHANGED
@@ -1,5 +1,6 @@
1
  import asyncio
2
 
 
3
  from trauma.api.account.model import AccountModel
4
  from trauma.api.account.schemas import CreateAccountRequest
5
  from trauma.api.common.db_requests import check_unique_fields_existence
 
1
  import asyncio
2
 
3
+ from trauma.api.account.dto import AccountType
4
  from trauma.api.account.model import AccountModel
5
  from trauma.api.account.schemas import CreateAccountRequest
6
  from trauma.api.common.db_requests import check_unique_fields_existence
trauma/api/chat/dto.py CHANGED
@@ -13,3 +13,5 @@ class EntityData(BaseModel):
13
  age: int | None = None
14
  treatmentArea: str | None = None
15
  treatmentMethod: str | None = None
 
 
 
13
  age: int | None = None
14
  treatmentArea: str | None = None
15
  treatmentMethod: str | None = None
16
+ location: str | None = None
17
+ postalCode: str | None = None
trauma/api/data/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi.routing import APIRouter
2
 
3
  facility_router = APIRouter(
4
- prefix="/api/facility", tags=["chat"]
5
  )
6
 
7
  from . import views
 
1
  from fastapi.routing import APIRouter
2
 
3
  facility_router = APIRouter(
4
+ prefix="/api/facility", tags=["facility"]
5
  )
6
 
7
  from . import views
trauma/api/data/model.py CHANGED
@@ -1,39 +1,7 @@
1
- import datetime
2
-
3
  from trauma.api.data.dto import AgeGroup, ContactDetails
4
  from trauma.core.database import MongoBaseModel
5
 
6
 
7
- class DataModel(MongoBaseModel):
8
- startDate: datetime.datetime
9
- endDate: datetime.datetime
10
- email: str | None = None
11
- name: str | None = None
12
- datetimeUpdated: datetime.datetime | None = None
13
- organizationName: str | None = None
14
- organizationLocation: str
15
- youthCareRegion: str
16
- postalCode: str
17
- hasOrganizationMultipleLocations: bool
18
- publicEmail: str
19
- website: str
20
- isTraumaTreatmentVisible: bool
21
- injuryTreatmentForms: list[str]
22
- offers: list[str]
23
- intensities: list[str]
24
- injuryTreatmentDuration: list[str]
25
- ageGroups: list[AgeGroup]
26
- traumaIndication: list[str]
27
- supplier: str
28
- treatmentFinancing: list[str]
29
- isContractedCare: list[bool]
30
- traumaTeamCharacteristics: list[str]
31
- traumaDisciplines: list[str]
32
- practitionersQualifications: list[str]
33
- AQAs: list[str]
34
- additions: list[str]
35
-
36
-
37
  class EntityModel(MongoBaseModel):
38
  id: str
39
  name: str
 
 
 
1
  from trauma.api.data.dto import AgeGroup, ContactDetails
2
  from trauma.core.database import MongoBaseModel
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class EntityModel(MongoBaseModel):
6
  id: str
7
  name: str
trauma/api/message/ai/engine.py CHANGED
@@ -12,7 +12,7 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
12
  choose_closest_treatment_method, choose_closest_treatment_area,
13
  check_is_valid_request, generate_invalid_response, set_entity_score)
14
  from trauma.api.message.db_requests import (save_assistant_user_message,
15
- filter_entities_by_age,
16
  update_entity_data_obj, get_entity_by_index)
17
  from trauma.api.message.schemas import CreateMessageResponse
18
  from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
@@ -41,7 +41,7 @@ async def search_entities(
41
  else:
42
  user_messages_str = prepare_user_messages_str(user_message, messages)
43
  possible_entity_indexes, search_request = await asyncio.gather(
44
- filter_entities_by_age(entity_data),
45
  generate_search_request(user_messages_str, entity_data)
46
  )
47
  final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
@@ -67,12 +67,12 @@ async def search_semantic_entities(
67
  ]
68
  filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
69
  final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
70
- final_entities_extended = await extended_entities_with_highlights(final_entities, entity_data)
71
  final_entities_scored = await set_entities_score(final_entities_extended, search_request)
72
  return final_entities_scored
73
 
74
 
75
- async def extended_entities_with_highlights(entities: list[EntityModel], entity_data: dict) -> list[
76
  EntityModelExtended]:
77
  async def choose_closest(entity_: EntityModel) -> tuple:
78
  treatment_area, treatment_method = await asyncio.gather(
 
12
  choose_closest_treatment_method, choose_closest_treatment_area,
13
  check_is_valid_request, generate_invalid_response, set_entity_score)
14
  from trauma.api.message.db_requests import (save_assistant_user_message,
15
+ filter_entities_by_age_location,
16
  update_entity_data_obj, get_entity_by_index)
17
  from trauma.api.message.schemas import CreateMessageResponse
18
  from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
 
41
  else:
42
  user_messages_str = prepare_user_messages_str(user_message, messages)
43
  possible_entity_indexes, search_request = await asyncio.gather(
44
+ filter_entities_by_age_location(entity_data),
45
  generate_search_request(user_messages_str, entity_data)
46
  )
47
  final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
 
67
  ]
68
  filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
69
  final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
70
+ final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
71
  final_entities_scored = await set_entities_score(final_entities_extended, search_request)
72
  return final_entities_scored
73
 
74
 
75
+ async def extend_entities_with_highlights(entities: list[EntityModel], entity_data: dict) -> list[
76
  EntityModelExtended]:
77
  async def choose_closest(entity_: EntityModel) -> tuple:
78
  treatment_area, treatment_method = await asyncio.gather(
trauma/api/message/ai/prompts.py CHANGED
@@ -30,13 +30,17 @@ Je verzamelt informatie over een patiënt, hun ziekte en de behandelmethode zoda
30
  {
31
  “age”: integer,
32
  “treatmentArea”: “string”,
33
- “treatmentMethod”: “string
 
 
34
  }
35
  ```
36
 
37
- - **[age]**: leeftijd van de patiënt.
38
- - **[treatmentArea]**: Het type mentale of fysieke ziekte/stoornis.
39
- - **[treatmentMethod]**: Een methode voor het behandelen van de ziekte of stoornis.
 
 
40
 
41
  ## Regels voor het bijwerken van Entity Data
42
 
@@ -139,6 +143,7 @@ The field is considered valid (`is_valid = true`) if:
139
  - The user describes the patient, their data, illness, treatment method, etc.
140
  - The user's request relates to a medical topic.
141
  - The user's request is a valid response to the assistant's question.
 
142
 
143
  [/INST]"""
144
  generate_invalid_response = """## Taak
 
30
  {
31
  “age”: integer,
32
  “treatmentArea”: “string”,
33
+ “treatmentMethod”: “string”,
34
+ "location": "string",
35
+ "postalCode": "string",
36
  }
37
  ```
38
 
39
+ - **age**: leeftijd van de patiënt.
40
+ - **treatmentArea**: Het type mentale of fysieke ziekte/stoornis.
41
+ - **treatmentMethod**: Een methode voor het behandelen van de ziekte of stoornis.
42
+ - **location**: Stad of adres waar de facility zich bevindt
43
+ - **postalCode**: Postcode van de facility..
44
 
45
  ## Regels voor het bijwerken van Entity Data
46
 
 
143
  - The user describes the patient, their data, illness, treatment method, etc.
144
  - The user's request relates to a medical topic.
145
  - The user's request is a valid response to the assistant's question.
146
+ - The user's request describes desired facility.
147
 
148
  [/INST]"""
149
  generate_invalid_response = """## Taak
trauma/api/message/db_requests.py CHANGED
@@ -60,15 +60,26 @@ async def save_assistant_user_message(user_message: str, assistant_message: str,
60
  await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
61
 
62
 
63
- async def filter_entities_by_age(entity_data: dict) -> list[int]:
64
  query = {
65
  "ageGroups": {
66
  "$elemMatch": {
67
  "ageMin": {"$lte": entity_data['age']},
68
  "ageMax": {"$gte": entity_data['age']}
69
  }
70
- }
71
  }
 
 
 
 
 
 
 
 
 
 
 
72
  entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
73
  return [entity['index'] for entity in entities]
74
 
 
60
  await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
61
 
62
 
63
+ async def filter_entities_by_age_location(entity_data: dict) -> list[int]:
64
  query = {
65
  "ageGroups": {
66
  "$elemMatch": {
67
  "ageMin": {"$lte": entity_data['age']},
68
  "ageMax": {"$gte": entity_data['age']}
69
  }
70
+ },
71
  }
72
+ if entity_data.get('location'):
73
+ query["contactDetails.address"] = {
74
+ "$regex": f".*{entity_data['location']}.*",
75
+ "$options": "i"
76
+ }
77
+
78
+ if entity_data.get('postalCode'):
79
+ query["contactDetails.postalCode"] = {
80
+ "$regex": f".*{entity_data['postalCode']}.*",
81
+ "$options": "i"
82
+ }
83
  entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
84
  return [entity['index'] for entity in entities]
85
 
trauma/api/message/utils.py CHANGED
@@ -20,7 +20,7 @@ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
20
 
21
  def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
22
  for k, v in entity_data.items():
23
- if not v:
24
  return k
25
  return None
26
 
 
20
 
21
  def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
22
  for k, v in entity_data.items():
23
+ if k not in ('location', 'postalCode') and not v:
24
  return k
25
  return None
26