Spaces:
Running
Running
add location postal code searching
Browse files
trauma/api/account/db_requests.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import asyncio
|
2 |
|
|
|
3 |
from trauma.api.account.model import AccountModel
|
4 |
from trauma.api.account.schemas import CreateAccountRequest
|
5 |
from trauma.api.common.db_requests import check_unique_fields_existence
|
|
|
1 |
import asyncio
|
2 |
|
3 |
+
from trauma.api.account.dto import AccountType
|
4 |
from trauma.api.account.model import AccountModel
|
5 |
from trauma.api.account.schemas import CreateAccountRequest
|
6 |
from trauma.api.common.db_requests import check_unique_fields_existence
|
trauma/api/chat/dto.py
CHANGED
@@ -13,3 +13,5 @@ class EntityData(BaseModel):
|
|
13 |
age: int | None = None
|
14 |
treatmentArea: str | None = None
|
15 |
treatmentMethod: str | None = None
|
|
|
|
|
|
13 |
age: int | None = None
|
14 |
treatmentArea: str | None = None
|
15 |
treatmentMethod: str | None = None
|
16 |
+
location: str | None = None
|
17 |
+
postalCode: str | None = None
|
trauma/api/data/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from fastapi.routing import APIRouter
|
2 |
|
3 |
facility_router = APIRouter(
|
4 |
-
prefix="/api/facility", tags=["
|
5 |
)
|
6 |
|
7 |
from . import views
|
|
|
1 |
from fastapi.routing import APIRouter
|
2 |
|
3 |
facility_router = APIRouter(
|
4 |
+
prefix="/api/facility", tags=["facility"]
|
5 |
)
|
6 |
|
7 |
from . import views
|
trauma/api/data/model.py
CHANGED
@@ -1,39 +1,7 @@
|
|
1 |
-
import datetime
|
2 |
-
|
3 |
from trauma.api.data.dto import AgeGroup, ContactDetails
|
4 |
from trauma.core.database import MongoBaseModel
|
5 |
|
6 |
|
7 |
-
class DataModel(MongoBaseModel):
|
8 |
-
startDate: datetime.datetime
|
9 |
-
endDate: datetime.datetime
|
10 |
-
email: str | None = None
|
11 |
-
name: str | None = None
|
12 |
-
datetimeUpdated: datetime.datetime | None = None
|
13 |
-
organizationName: str | None = None
|
14 |
-
organizationLocation: str
|
15 |
-
youthCareRegion: str
|
16 |
-
postalCode: str
|
17 |
-
hasOrganizationMultipleLocations: bool
|
18 |
-
publicEmail: str
|
19 |
-
website: str
|
20 |
-
isTraumaTreatmentVisible: bool
|
21 |
-
injuryTreatmentForms: list[str]
|
22 |
-
offers: list[str]
|
23 |
-
intensities: list[str]
|
24 |
-
injuryTreatmentDuration: list[str]
|
25 |
-
ageGroups: list[AgeGroup]
|
26 |
-
traumaIndication: list[str]
|
27 |
-
supplier: str
|
28 |
-
treatmentFinancing: list[str]
|
29 |
-
isContractedCare: list[bool]
|
30 |
-
traumaTeamCharacteristics: list[str]
|
31 |
-
traumaDisciplines: list[str]
|
32 |
-
practitionersQualifications: list[str]
|
33 |
-
AQAs: list[str]
|
34 |
-
additions: list[str]
|
35 |
-
|
36 |
-
|
37 |
class EntityModel(MongoBaseModel):
|
38 |
id: str
|
39 |
name: str
|
|
|
|
|
|
|
1 |
from trauma.api.data.dto import AgeGroup, ContactDetails
|
2 |
from trauma.core.database import MongoBaseModel
|
3 |
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
class EntityModel(MongoBaseModel):
|
6 |
id: str
|
7 |
name: str
|
trauma/api/message/ai/engine.py
CHANGED
@@ -12,7 +12,7 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
|
12 |
choose_closest_treatment_method, choose_closest_treatment_area,
|
13 |
check_is_valid_request, generate_invalid_response, set_entity_score)
|
14 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
15 |
-
|
16 |
update_entity_data_obj, get_entity_by_index)
|
17 |
from trauma.api.message.schemas import CreateMessageResponse
|
18 |
from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
|
@@ -41,7 +41,7 @@ async def search_entities(
|
|
41 |
else:
|
42 |
user_messages_str = prepare_user_messages_str(user_message, messages)
|
43 |
possible_entity_indexes, search_request = await asyncio.gather(
|
44 |
-
|
45 |
generate_search_request(user_messages_str, entity_data)
|
46 |
)
|
47 |
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
@@ -67,12 +67,12 @@ async def search_semantic_entities(
|
|
67 |
]
|
68 |
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
|
69 |
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
|
70 |
-
final_entities_extended = await
|
71 |
final_entities_scored = await set_entities_score(final_entities_extended, search_request)
|
72 |
return final_entities_scored
|
73 |
|
74 |
|
75 |
-
async def
|
76 |
EntityModelExtended]:
|
77 |
async def choose_closest(entity_: EntityModel) -> tuple:
|
78 |
treatment_area, treatment_method = await asyncio.gather(
|
|
|
12 |
choose_closest_treatment_method, choose_closest_treatment_area,
|
13 |
check_is_valid_request, generate_invalid_response, set_entity_score)
|
14 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
15 |
+
filter_entities_by_age_location,
|
16 |
update_entity_data_obj, get_entity_by_index)
|
17 |
from trauma.api.message.schemas import CreateMessageResponse
|
18 |
from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
|
|
|
41 |
else:
|
42 |
user_messages_str = prepare_user_messages_str(user_message, messages)
|
43 |
possible_entity_indexes, search_request = await asyncio.gather(
|
44 |
+
filter_entities_by_age_location(entity_data),
|
45 |
generate_search_request(user_messages_str, entity_data)
|
46 |
)
|
47 |
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
|
|
67 |
]
|
68 |
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
|
69 |
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
|
70 |
+
final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
|
71 |
final_entities_scored = await set_entities_score(final_entities_extended, search_request)
|
72 |
return final_entities_scored
|
73 |
|
74 |
|
75 |
+
async def extend_entities_with_highlights(entities: list[EntityModel], entity_data: dict) -> list[
|
76 |
EntityModelExtended]:
|
77 |
async def choose_closest(entity_: EntityModel) -> tuple:
|
78 |
treatment_area, treatment_method = await asyncio.gather(
|
trauma/api/message/ai/prompts.py
CHANGED
@@ -30,13 +30,17 @@ Je verzamelt informatie over een patiënt, hun ziekte en de behandelmethode zoda
|
|
30 |
{
|
31 |
“age”: integer,
|
32 |
“treatmentArea”: “string”,
|
33 |
-
“treatmentMethod”: “string
|
|
|
|
|
34 |
}
|
35 |
```
|
36 |
|
37 |
-
- **
|
38 |
-
- **
|
39 |
-
- **
|
|
|
|
|
40 |
|
41 |
## Regels voor het bijwerken van Entity Data
|
42 |
|
@@ -139,6 +143,7 @@ The field is considered valid (`is_valid = true`) if:
|
|
139 |
- The user describes the patient, their data, illness, treatment method, etc.
|
140 |
- The user's request relates to a medical topic.
|
141 |
- The user's request is a valid response to the assistant's question.
|
|
|
142 |
|
143 |
[/INST]"""
|
144 |
generate_invalid_response = """## Taak
|
|
|
30 |
{
|
31 |
“age”: integer,
|
32 |
“treatmentArea”: “string”,
|
33 |
+
“treatmentMethod”: “string”,
|
34 |
+
"location": "string",
|
35 |
+
"postalCode": "string",
|
36 |
}
|
37 |
```
|
38 |
|
39 |
+
- **age**: leeftijd van de patiënt.
|
40 |
+
- **treatmentArea**: Het type mentale of fysieke ziekte/stoornis.
|
41 |
+
- **treatmentMethod**: Een methode voor het behandelen van de ziekte of stoornis.
|
42 |
+
- **location**: Stad of adres waar de facility zich bevindt
|
43 |
+
- **postalCode**: Postcode van de facility..
|
44 |
|
45 |
## Regels voor het bijwerken van Entity Data
|
46 |
|
|
|
143 |
- The user describes the patient, their data, illness, treatment method, etc.
|
144 |
- The user's request relates to a medical topic.
|
145 |
- The user's request is a valid response to the assistant's question.
|
146 |
+
- The user's request describes desired facility.
|
147 |
|
148 |
[/INST]"""
|
149 |
generate_invalid_response = """## Taak
|
trauma/api/message/db_requests.py
CHANGED
@@ -60,15 +60,26 @@ async def save_assistant_user_message(user_message: str, assistant_message: str,
|
|
60 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
61 |
|
62 |
|
63 |
-
async def
|
64 |
query = {
|
65 |
"ageGroups": {
|
66 |
"$elemMatch": {
|
67 |
"ageMin": {"$lte": entity_data['age']},
|
68 |
"ageMax": {"$gte": entity_data['age']}
|
69 |
}
|
70 |
-
}
|
71 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
|
73 |
return [entity['index'] for entity in entities]
|
74 |
|
|
|
60 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
61 |
|
62 |
|
63 |
+
async def filter_entities_by_age_location(entity_data: dict) -> list[int]:
|
64 |
query = {
|
65 |
"ageGroups": {
|
66 |
"$elemMatch": {
|
67 |
"ageMin": {"$lte": entity_data['age']},
|
68 |
"ageMax": {"$gte": entity_data['age']}
|
69 |
}
|
70 |
+
},
|
71 |
}
|
72 |
+
if entity_data.get('location'):
|
73 |
+
query["contactDetails.address"] = {
|
74 |
+
"$regex": f".*{entity_data['location']}.*",
|
75 |
+
"$options": "i"
|
76 |
+
}
|
77 |
+
|
78 |
+
if entity_data.get('postalCode'):
|
79 |
+
query["contactDetails.postalCode"] = {
|
80 |
+
"$regex": f".*{entity_data['postalCode']}.*",
|
81 |
+
"$options": "i"
|
82 |
+
}
|
83 |
entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
|
84 |
return [entity['index'] for entity in entities]
|
85 |
|
trauma/api/message/utils.py
CHANGED
@@ -20,7 +20,7 @@ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
|
|
20 |
|
21 |
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
22 |
for k, v in entity_data.items():
|
23 |
-
if not v:
|
24 |
return k
|
25 |
return None
|
26 |
|
|
|
20 |
|
21 |
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
22 |
for k, v in entity_data.items():
|
23 |
+
if k not in ('location', 'postalCode') and not v:
|
24 |
return k
|
25 |
return None
|
26 |
|