Spaces:
Running
Running
added highlights
Browse files- trauma/api/data/dto.py +2 -1
- trauma/api/data/model.py +6 -0
- trauma/api/message/ai/engine.py +45 -5
- trauma/api/message/ai/openai_request.py +27 -1
- trauma/api/message/ai/prompts.py +57 -0
- trauma/api/message/db_requests.py +3 -20
- trauma/api/message/schemas.py +2 -2
- trauma/api/message/utils.py +17 -1
- trauma/core/config.py +1 -1
trauma/api/data/dto.py
CHANGED
@@ -10,4 +10,5 @@ class ContactDetails(BaseModel):
|
|
10 |
email: str | None = None
|
11 |
website: str | None = None
|
12 |
address: str | None = None
|
13 |
-
postalCode: str | None = None
|
|
|
|
10 |
email: str | None = None
|
11 |
website: str | None = None
|
12 |
address: str | None = None
|
13 |
+
postalCode: str | None = None
|
14 |
+
|
trauma/api/data/model.py
CHANGED
@@ -42,3 +42,9 @@ class EntityModel(MongoBaseModel):
|
|
42 |
treatmentMethods: list[str]
|
43 |
description: str = ''
|
44 |
contactDetails: ContactDetails
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
treatmentMethods: list[str]
|
43 |
description: str = ''
|
44 |
contactDetails: ContactDetails
|
45 |
+
|
46 |
+
|
47 |
+
class EntityModelExtended(EntityModel):
|
48 |
+
highlightedAgeGroup: AgeGroup
|
49 |
+
highlightedTreatmentArea: str
|
50 |
+
highlightedTreatmentMethod: str
|
trauma/api/message/ai/engine.py
CHANGED
@@ -1,19 +1,24 @@
|
|
1 |
import asyncio
|
2 |
|
|
|
|
|
|
|
3 |
from trauma.api.chat.model import ChatModel
|
|
|
4 |
from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
5 |
generate_next_question,
|
6 |
generate_search_request,
|
7 |
-
generate_final_response
|
|
|
8 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
9 |
filter_entities_by_age,
|
10 |
-
|
11 |
-
update_entity_data_obj)
|
12 |
from trauma.api.message.schemas import CreateMessageResponse
|
13 |
from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
|
14 |
prepare_user_messages_str,
|
15 |
prepare_final_entities_str,
|
16 |
-
pick_empty_field_instructions)
|
|
|
17 |
|
18 |
|
19 |
async def search_entities(
|
@@ -34,9 +39,44 @@ async def search_entities(
|
|
34 |
filter_entities_by_age(entity_data),
|
35 |
generate_search_request(user_messages_str, entity_data)
|
36 |
)
|
37 |
-
final_entities = await search_semantic_entities(search_request, possible_entity_indexes)
|
38 |
final_entities_str = prepare_final_entities_str(final_entities)
|
39 |
response = await generate_final_response(final_entities_str, user_message, messages)
|
40 |
|
41 |
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
42 |
return CreateMessageResponse(text=response, entities=final_entities)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from trauma.api.chat.dto import EntityData
|
6 |
from trauma.api.chat.model import ChatModel
|
7 |
+
from trauma.api.data.model import EntityModel, EntityModelExtended
|
8 |
from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
9 |
generate_next_question,
|
10 |
generate_search_request,
|
11 |
+
generate_final_response, convert_value_to_embeddings,
|
12 |
+
choose_closest_treatment_method, choose_closest_treatment_area)
|
13 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
14 |
filter_entities_by_age,
|
15 |
+
update_entity_data_obj, get_entity_by_index)
|
|
|
16 |
from trauma.api.message.schemas import CreateMessageResponse
|
17 |
from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
|
18 |
prepare_user_messages_str,
|
19 |
prepare_final_entities_str,
|
20 |
+
pick_empty_field_instructions, find_matching_age_group)
|
21 |
+
from trauma.core.config import settings
|
22 |
|
23 |
|
24 |
async def search_entities(
|
|
|
39 |
filter_entities_by_age(entity_data),
|
40 |
generate_search_request(user_messages_str, entity_data)
|
41 |
)
|
42 |
+
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
43 |
final_entities_str = prepare_final_entities_str(final_entities)
|
44 |
response = await generate_final_response(final_entities_str, user_message, messages)
|
45 |
|
46 |
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
47 |
return CreateMessageResponse(text=response, entities=final_entities)
|
48 |
+
|
49 |
+
|
50 |
+
async def search_semantic_entities(
|
51 |
+
search_request: str, entity_data: EntityData, entities_indexes: list[int]
|
52 |
+
) -> list[EntityModelExtended]:
|
53 |
+
embedding = await convert_value_to_embeddings(search_request)
|
54 |
+
query_embedding = np.array([embedding], dtype=np.float32)
|
55 |
+
distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)
|
56 |
+
distances = distances[0]
|
57 |
+
indices = indices[0]
|
58 |
+
filtered_results = [
|
59 |
+
{"index": int(idx), "distance": float(dist)}
|
60 |
+
for idx, dist in zip(indices, distances)
|
61 |
+
if idx in entities_indexes and dist <= 1.3
|
62 |
+
]
|
63 |
+
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
|
64 |
+
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
|
65 |
+
final_entities_extended = await asyncio.gather(
|
66 |
+
*[extended_entity_with_highlights(entity, entity_data) for entity in final_entities]
|
67 |
+
)
|
68 |
+
return final_entities_extended
|
69 |
+
|
70 |
+
|
71 |
+
async def extended_entity_with_highlights(entity: EntityModel, entity_data: dict) -> EntityModelExtended:
|
72 |
+
age_group = find_matching_age_group(entity, entity_data)
|
73 |
+
treatment_area, treatment_method = await asyncio.gather(
|
74 |
+
choose_closest_treatment_area(entity.treatmentAreas, entity_data['treatmentArea']),
|
75 |
+
choose_closest_treatment_method(entity.treatmentMethods, entity_data['treatmentMethod'])
|
76 |
+
)
|
77 |
+
return EntityModelExtended(
|
78 |
+
**entity.to_mongo(),
|
79 |
+
highlightedAgeGroup=age_group,
|
80 |
+
highlightedTreatmentArea=treatment_area,
|
81 |
+
highlightedTreatmentMethod=treatment_method
|
82 |
+
)
|
trauma/api/message/ai/openai_request.py
CHANGED
@@ -74,4 +74,30 @@ async def convert_value_to_embeddings(value: str) -> list[float]:
|
|
74 |
model='text-embedding-3-large',
|
75 |
dimensions=1536,
|
76 |
)
|
77 |
-
return embeddings.data[0].embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
model='text-embedding-3-large',
|
75 |
dimensions=1536,
|
76 |
)
|
77 |
+
return embeddings.data[0].embedding
|
78 |
+
|
79 |
+
|
80 |
+
@openai_wrapper(is_json=True, return_='result')
|
81 |
+
async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str):
|
82 |
+
messages = [
|
83 |
+
{
|
84 |
+
"role": "system",
|
85 |
+
"content": TraumaPrompts.choose_closest_treatment_area
|
86 |
+
.replace("{treatment_areas}", ", ".join(treatment_areas))
|
87 |
+
.replace("{treatment_area}", treatment_area)
|
88 |
+
}
|
89 |
+
]
|
90 |
+
return messages
|
91 |
+
|
92 |
+
|
93 |
+
@openai_wrapper(is_json=True, return_='result')
|
94 |
+
async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str):
|
95 |
+
messages = [
|
96 |
+
{
|
97 |
+
"role": "system",
|
98 |
+
"content": TraumaPrompts.choose_closest_treatment_method
|
99 |
+
.replace("{treatment_methods}", ", ".join(treatment_methods))
|
100 |
+
.replace("{treatment_method}", treatment_method)
|
101 |
+
}
|
102 |
+
]
|
103 |
+
return messages
|
trauma/api/message/ai/prompts.py
CHANGED
@@ -137,3 +137,60 @@ Je bent verplicht om een beschrijving voor een kliniek te genereren op basis van
|
|
137 |
- De beschrijving moet beknopt en bondig zijn.
|
138 |
|
139 |
[/INST]"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
- De beschrijving moet beknopt en bondig zijn.
|
138 |
|
139 |
[/INST]"""
|
140 |
+
|
141 |
+
choose_closest_treatment_area = """
|
142 |
+
|
143 |
+
## Task
|
144 |
+
|
145 |
+
You must determine the most semantically similar disorder or disease from the list of [treatment areas] to the requested disease [requested treatment area]. The most similar disease should be returned in the [result] field of the JSON.
|
146 |
+
|
147 |
+
## Data
|
148 |
+
|
149 |
+
**treatment areas**:
|
150 |
+
```
|
151 |
+
{treatment_areas}
|
152 |
+
```
|
153 |
+
|
154 |
+
**requested treatment area**:
|
155 |
+
```
|
156 |
+
{treatment_area}
|
157 |
+
```
|
158 |
+
|
159 |
+
## JSON Response format
|
160 |
+
|
161 |
+
```json
|
162 |
+
{
|
163 |
+
"result": "string"
|
164 |
+
}
|
165 |
+
```
|
166 |
+
|
167 |
+
## Instructions for filling JSON
|
168 |
+
|
169 |
+
- [result]: The item from the [treatment areas] list that is most semantically similar to the requested disease. The disease name in the result field must exactly match the name as it appears in the [treatment areas] list."""
|
170 |
+
choose_closest_treatment_method = """## Task
|
171 |
+
|
172 |
+
You must determine the most semantically similar treatment method from the list of [treatment methods] to the requested treatment method [requested treatment method]. The most similar treatment method should be returned in the [result] field of the JSON.
|
173 |
+
|
174 |
+
## Data
|
175 |
+
|
176 |
+
**treatment methods**:
|
177 |
+
```
|
178 |
+
{treatment_methods}
|
179 |
+
```
|
180 |
+
|
181 |
+
**requested treatment method**:
|
182 |
+
```
|
183 |
+
{treatment_method}
|
184 |
+
```
|
185 |
+
|
186 |
+
## JSON Response format
|
187 |
+
|
188 |
+
```json
|
189 |
+
{
|
190 |
+
"result": "string"
|
191 |
+
}
|
192 |
+
```
|
193 |
+
|
194 |
+
## Instructions for filling JSON
|
195 |
+
|
196 |
+
- [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
|
trauma/api/message/db_requests.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
import asyncio
|
2 |
|
3 |
-
import numpy as np
|
4 |
from fastapi import HTTPException
|
5 |
|
6 |
from trauma.api.chat.model import ChatModel
|
7 |
from trauma.api.data.model import EntityModel
|
8 |
-
from trauma.api.message.ai.openai_request import convert_value_to_embeddings
|
9 |
from trauma.api.message.dto import Author
|
10 |
from trauma.api.message.model import MessageModel
|
11 |
from trauma.api.message.schemas import CreateMessageRequest
|
@@ -58,12 +56,12 @@ async def save_assistant_user_message(user_message: str, assistant_message: str,
|
|
58 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
59 |
|
60 |
|
61 |
-
async def filter_entities_by_age(
|
62 |
query = {
|
63 |
"ageGroups": {
|
64 |
"$elemMatch": {
|
65 |
-
"ageMin": {"$lte":
|
66 |
-
"ageMax": {"$gte":
|
67 |
}
|
68 |
}
|
69 |
}
|
@@ -74,18 +72,3 @@ async def filter_entities_by_age(entity: dict) -> list[int]:
|
|
74 |
async def get_entity_by_index(index: int) -> EntityModel:
|
75 |
entity = await settings.DB_CLIENT.entities.find_one({"index": index})
|
76 |
return EntityModel.from_mongo(entity)
|
77 |
-
|
78 |
-
async def search_semantic_entities(search_request: str, entities_indexes: list[int]) -> list[EntityModel]:
|
79 |
-
embedding = await convert_value_to_embeddings(search_request)
|
80 |
-
query_embedding = np.array([embedding], dtype=np.float32)
|
81 |
-
distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)
|
82 |
-
distances = distances[0]
|
83 |
-
indices = indices[0]
|
84 |
-
filtered_results = [
|
85 |
-
{"index": int(idx), "distance": float(dist)}
|
86 |
-
for idx, dist in zip(indices, distances)
|
87 |
-
if idx in entities_indexes and dist <= 1.3
|
88 |
-
]
|
89 |
-
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
|
90 |
-
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
|
91 |
-
return final_entities
|
|
|
1 |
import asyncio
|
2 |
|
|
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
from trauma.api.chat.model import ChatModel
|
6 |
from trauma.api.data.model import EntityModel
|
|
|
7 |
from trauma.api.message.dto import Author
|
8 |
from trauma.api.message.model import MessageModel
|
9 |
from trauma.api.message.schemas import CreateMessageRequest
|
|
|
56 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
57 |
|
58 |
|
59 |
+
async def filter_entities_by_age(entity_data: dict) -> list[int]:
|
60 |
query = {
|
61 |
"ageGroups": {
|
62 |
"$elemMatch": {
|
63 |
+
"ageMin": {"$lte": entity_data['ageMax']},
|
64 |
+
"ageMax": {"$gte": entity_data['ageMin']}
|
65 |
}
|
66 |
}
|
67 |
}
|
|
|
72 |
async def get_entity_by_index(index: int) -> EntityModel:
|
73 |
entity = await settings.DB_CLIENT.entities.find_one({"index": index})
|
74 |
return EntityModel.from_mongo(entity)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/message/schemas.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
from trauma.api.common.dto import Paging
|
4 |
-
from trauma.api.data.model import
|
5 |
from trauma.api.message.model import MessageModel
|
6 |
from trauma.core.wrappers import TraumaResponseWrapper
|
7 |
|
@@ -24,4 +24,4 @@ class AllMessageWrapper(TraumaResponseWrapper[AllMessageResponse]):
|
|
24 |
|
25 |
class CreateMessageResponse(BaseModel):
|
26 |
text: str
|
27 |
-
entities: list[
|
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
from trauma.api.common.dto import Paging
|
4 |
+
from trauma.api.data.model import EntityModelExtended
|
5 |
from trauma.api.message.model import MessageModel
|
6 |
from trauma.core.wrappers import TraumaResponseWrapper
|
7 |
|
|
|
24 |
|
25 |
class CreateMessageResponse(BaseModel):
|
26 |
text: str
|
27 |
+
entities: list[EntityModelExtended] | None = None
|
trauma/api/message/utils.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import json
|
2 |
|
|
|
|
|
3 |
from trauma.api.data.model import EntityModel
|
4 |
from trauma.api.message.model import MessageModel
|
5 |
|
@@ -35,7 +37,9 @@ def prepare_user_messages_str(user_message: str, messages: list[dict]) -> str:
|
|
35 |
def prepare_final_entities_str(entities: list[EntityModel]) -> str:
|
36 |
entities_list = []
|
37 |
for entity in entities:
|
38 |
-
entities_list.append(entity.model_dump(mode='json', exclude={
|
|
|
|
|
39 |
return json.dumps({"klinieken": entities_list})
|
40 |
|
41 |
|
@@ -48,3 +52,15 @@ def pick_empty_field_instructions(empty_field: str) -> str:
|
|
48 |
return "Het type psychische of lichamelijke ziekte / stoornis."
|
49 |
elif empty_field == "treatmentMethod":
|
50 |
return "Een methode om de ziekte of stoornis te behandelen."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
|
3 |
+
from trauma.api.chat.dto import EntityData
|
4 |
+
from trauma.api.data.dto import AgeGroup
|
5 |
from trauma.api.data.model import EntityModel
|
6 |
from trauma.api.message.model import MessageModel
|
7 |
|
|
|
37 |
def prepare_final_entities_str(entities: list[EntityModel]) -> str:
|
38 |
entities_list = []
|
39 |
for entity in entities:
|
40 |
+
entities_list.append(entity.model_dump(mode='json', exclude={
|
41 |
+
'id', 'contactDetails', "highlightedAgeGroup", "highlightedTreatmentArea", "highlightedTreatmentMethod"
|
42 |
+
}))
|
43 |
return json.dumps({"klinieken": entities_list})
|
44 |
|
45 |
|
|
|
52 |
return "Het type psychische of lichamelijke ziekte / stoornis."
|
53 |
elif empty_field == "treatmentMethod":
|
54 |
return "Een methode om de ziekte of stoornis te behandelen."
|
55 |
+
|
56 |
+
|
57 |
+
def find_matching_age_group(entity: EntityModel, entity_data: dict) -> AgeGroup:
|
58 |
+
age_groups = entity.ageGroups
|
59 |
+
best_match = None
|
60 |
+
for age_group in age_groups:
|
61 |
+
if age_group.ageMin <= entity_data['ageMax'] and age_group.ageMax >= entity_data['ageMin']:
|
62 |
+
return age_group
|
63 |
+
if age_group.ageMax < entity_data['ageMin']:
|
64 |
+
if best_match is None or age_group.ageMax > best_match.ageMax:
|
65 |
+
best_match = age_group
|
66 |
+
return best_match
|
trauma/core/config.py
CHANGED
@@ -16,7 +16,7 @@ class BaseConfig:
|
|
16 |
DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).AtlasCluster
|
17 |
OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
|
18 |
SEMANTIC_INDEX = faiss.read_index(str(pathlib.Path(__file__).parent.parent.parent / 'indexes' / 'entities.index'))
|
19 |
-
INTRO_MESSAGE = """
|
20 |
|
21 |
class DevelopmentConfig(BaseConfig):
|
22 |
Issuer = "http://localhost:8000"
|
|
|
16 |
DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).AtlasCluster
|
17 |
OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
|
18 |
SEMANTIC_INDEX = faiss.read_index(str(pathlib.Path(__file__).parent.parent.parent / 'indexes' / 'entities.index'))
|
19 |
+
INTRO_MESSAGE = """Hallo! Ik ben een AI-assistent hier om te helpen bij het vinden van de perfecte kliniek voor elke patiënt. Deel de gegevens van de patiënt."""
|
20 |
|
21 |
class DevelopmentConfig(BaseConfig):
|
22 |
Issuer = "http://localhost:8000"
|