brestok commited on
Commit
97743c1
·
1 Parent(s): 20faa08

add top match

Browse files
trauma/api/data/model.py CHANGED
@@ -46,5 +46,7 @@ class EntityModel(MongoBaseModel):
46
 
47
  class EntityModelExtended(EntityModel):
48
  highlightedAgeGroup: AgeGroup
49
- highlightedTreatmentArea: str
50
- highlightedTreatmentMethod: str
 
 
 
46
 
47
  class EntityModelExtended(EntityModel):
48
  highlightedAgeGroup: AgeGroup
49
+ highlightedTreatmentArea: str | None
50
+ highlightedTreatmentMethod: str | None
51
+ topMatch: bool = False
52
+ score: float | None = None
trauma/api/message/ai/engine.py CHANGED
@@ -10,7 +10,7 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
10
  generate_search_request,
11
  generate_final_response, convert_value_to_embeddings,
12
  choose_closest_treatment_method, choose_closest_treatment_area,
13
- check_is_valid_request, generate_invalid_response)
14
  from trauma.api.message.db_requests import (save_assistant_user_message,
15
  filter_entities_by_age,
16
  update_entity_data_obj, get_entity_by_index)
@@ -67,21 +67,40 @@ async def search_semantic_entities(
67
  ]
68
  filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
69
  final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
70
- final_entities_extended = await asyncio.gather(
71
- *[extended_entity_with_highlights(entity, entity_data) for entity in final_entities]
72
- )
73
- return final_entities_extended
74
 
75
 
76
- async def extended_entity_with_highlights(entity: EntityModel, entity_data: dict) -> EntityModelExtended:
77
- age_group = find_matching_age_group(entity, entity_data)
78
- treatment_area, treatment_method = await asyncio.gather(
79
- choose_closest_treatment_area(entity.treatmentAreas, entity_data['treatmentArea']),
80
- choose_closest_treatment_method(entity.treatmentMethods, entity_data['treatmentMethod'])
81
- )
82
- return EntityModelExtended(
83
- **entity.to_mongo(),
84
- highlightedAgeGroup=age_group,
85
- highlightedTreatmentArea=treatment_area,
86
- highlightedTreatmentMethod=treatment_method
87
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  generate_search_request,
11
  generate_final_response, convert_value_to_embeddings,
12
  choose_closest_treatment_method, choose_closest_treatment_area,
13
+ check_is_valid_request, generate_invalid_response, set_entity_score)
14
  from trauma.api.message.db_requests import (save_assistant_user_message,
15
  filter_entities_by_age,
16
  update_entity_data_obj, get_entity_by_index)
 
67
  ]
68
  filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
69
  final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
70
+ final_entities_extended = await extended_entities_with_highlights(final_entities, entity_data)
71
+ final_entities_scored = await set_entities_score(final_entities_extended, search_request)
72
+ return final_entities_scored
 
73
 
74
 
75
+ async def extended_entities_with_highlights(entities: list[EntityModel], entity_data: dict) -> list[
76
+ EntityModelExtended]:
77
+ async def choose_closest(entity_: EntityModel) -> tuple:
78
+ treatment_area, treatment_method = await asyncio.gather(
79
+ choose_closest_treatment_area(entity_.treatmentAreas, entity_data['treatmentArea']),
80
+ choose_closest_treatment_method(entity_.treatmentMethods, entity_data['treatmentMethod'])
81
+ )
82
+ return treatment_area, treatment_method
83
+
84
+ results = await asyncio.gather(*[choose_closest(entity) for entity in entities])
85
+ final_entities = []
86
+ for treatment, entity in zip(results, entities):
87
+ age_group = find_matching_age_group(entity, entity_data)
88
+ final_entities.append(EntityModelExtended(
89
+ **entity.to_mongo(),
90
+ highlightedAgeGroup=age_group,
91
+ highlightedTreatmentArea=treatment[0],
92
+ highlightedTreatmentMethod=treatment[1]
93
+ ))
94
+ return final_entities
95
+
96
+
97
+ async def set_entities_score(entities: list[EntityModelExtended], search_request: str) -> list[EntityModelExtended]:
98
+ scores = await asyncio.gather(*[set_entity_score(entity, search_request) for entity in entities])
99
+ final_entities = []
100
+ for score, entity in zip(scores, entities):
101
+ if score > 0.9:
102
+ entity.topMatch = True
103
+ entity.score = score
104
+ if score > 0.75:
105
+ final_entities.append(entity)
106
+ return sorted(final_entities, key=lambda x: x.score, reverse=True)
trauma/api/message/ai/openai_request.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
 
3
  from trauma.api.chat.dto import EntityData
 
4
  from trauma.api.message.ai.prompts import TraumaPrompts
5
  from trauma.core.config import settings
6
  from trauma.core.wrappers import openai_wrapper
@@ -134,3 +135,18 @@ async def generate_invalid_response(user_message: str, message_history: list[dic
134
  }
135
  ]
136
  return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
 
3
  from trauma.api.chat.dto import EntityData
4
+ from trauma.api.data.model import EntityModelExtended
5
  from trauma.api.message.ai.prompts import TraumaPrompts
6
  from trauma.core.config import settings
7
  from trauma.core.wrappers import openai_wrapper
 
135
  }
136
  ]
137
  return messages
138
+
139
+
140
+ @openai_wrapper(is_json=True, return_='score')
141
+ async def set_entity_score(entity: EntityModelExtended, search_request: str):
142
+ messages = [
143
+ {
144
+ "role": "system",
145
+ "content": TraumaPrompts.set_entity_score
146
+ .replace("{entity}", entity.model_dump_json(exclude={
147
+ "ageGroups", "treatmentAreas", "treatmentMethods", "contactDetails"
148
+ }))
149
+ .replace("{search_request}", search_request)
150
+ }
151
+ ]
152
+ return messages
trauma/api/message/ai/prompts.py CHANGED
@@ -193,7 +193,7 @@ Je bent verplicht om een beschrijving voor een kliniek te genereren op basis van
193
 
194
  choose_closest_treatment_area = """## Task
195
 
196
- You must determine the most semantically similar disorder or disease from the list of [treatment areas] to the requested disease [requested treatment area]. The most similar disease should be returned in the [result] field of the JSON.
197
 
198
  ## Data
199
 
@@ -217,10 +217,11 @@ You must determine the most semantically similar disorder or disease from the li
217
 
218
  ## Instructions for filling JSON
219
 
220
- - [result]: The item from the [treatment areas] list that is most semantically similar to the requested disease. The disease name in the result field must exactly match the name as it appears in the [treatment areas] list."""
 
221
  choose_closest_treatment_method = """## Task
222
 
223
- You must determine the most semantically similar treatment method from the list of [treatment methods] to the requested treatment method [requested treatment method]. The most similar treatment method should be returned in the [result] field of the JSON.
224
 
225
  ## Data
226
 
@@ -244,5 +245,41 @@ You must determine the most semantically similar treatment method from the list
244
 
245
  ## Instructions for filling JSON
246
 
247
- - [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
 
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  choose_closest_treatment_area = """## Task
195
 
196
+ You must determine the most semantically similar disorder or disease from the list of [treatment areas] to the requested disease [requested treatment area]. The most similar disease should be returned in the [result] field of the JSON. If there is no similar disease, you must save `null`.
197
 
198
  ## Data
199
 
 
217
 
218
  ## Instructions for filling JSON
219
 
220
+ - [result]: The item from the [treatment areas] list that is most semantically similar to the requested disease. The disease name in the result field must exactly match the name as it appears in the [treatment areas] list. If there is no similar element from [treatment areas], you must save `null`.
221
+ """
222
  choose_closest_treatment_method = """## Task
223
 
224
+ You must determine the most semantically similar treatment method from the list of [treatment methods] to the requested treatment method [requested treatment method]. The most similar treatment method should be returned in the [result] field of the JSON. If there is no similar treatment method, you must save `null`.
225
 
226
  ## Data
227
 
 
245
 
246
  ## Instructions for filling JSON
247
 
248
+ - [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list. If there is no similar element from [treatment methods], you must save `null`."""
249
+ set_entity_score = """## Task
250
 
251
+ You must to assign a **relevance score** to a facility based on a given search request. The relevance score should range between **0.00 and 1.00**, where **1.00** indicates a perfect match and **0.00** indicates no relevance.
252
+
253
+ ## Evaluation Criteria
254
+
255
+ Analyze the entire `Facility` object with a focus on the following key fields:
256
+ - **highlightedAgeGroup**: The target age group that the facility serves.
257
+ - **highlightedTreatmentArea**: The primary area of treatment provided by the facility.
258
+ - **highlightedTreatmentMethod**: The main treatment method used at the facility.
259
+ - **description**: Any additional text that may indicate relevance to the search request.
260
+
261
+ ## Scoring Guidelines
262
+
263
+ - Assign **higher scores** when the `highlightedAgeGroup`, `highlightedTreatmentArea`, `highlightedTreatmentMethod`, and `description` closely **match** the user’s query.
264
+ - Apply **penalties** for mismatches, partial overlaps, or missing key attributes.
265
+
266
+ ## Input
267
+
268
+ **Search request**
269
+ ```
270
+ {search_request}
271
+ ```
272
+
273
+ **Facility**:
274
+ ```json
275
+ {entity}
276
+ ```
277
+
278
+ ## **Output Format**
279
+ Your response must be in the following JSON format:
280
+ ```json
281
+ {
282
+ "score": float
283
+ }
284
+ ```
285
+ - **score**: A floating-point number between **0.00 and 1.00**, representing the degree of relevance."""