Spaces:
Running
Running
fix
Browse files- test.py +8 -23
- trauma/api/data/data/prepare_data.py +23 -1
- trauma/api/data/db_requests.py +2 -2
- trauma/api/data/model.py +3 -3
- trauma/api/message/ai/engine.py +9 -2
- trauma/api/message/ai/openai_request.py +40 -2
- trauma/api/message/ai/prompts.py +19 -0
- trauma/api/message/db_requests.py +1 -1
- trauma/api/message/utils.py +0 -1
test.py
CHANGED
@@ -1,28 +1,13 @@
|
|
1 |
-
import
|
2 |
-
from translate import Translator
|
3 |
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
# Читаем файл Excel
|
8 |
-
data = pd.read_excel(input_file, sheet_name=sheet_name)
|
9 |
|
10 |
-
# Инициализируем переводчик
|
11 |
-
translator = Translator(from_lang='nl', to_lang='en')
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
data.rename(columns=translated_columns, inplace=True)
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
print(f"Файл успешно конвертирован и сохранен: {output_file}")
|
20 |
-
except Exception as e:
|
21 |
-
print(f"Произошла ошибка: {e}")
|
22 |
-
|
23 |
-
|
24 |
-
input_xlsx = "test.xlsx" # Путь к входному .xlsx файлу
|
25 |
-
output_csv = "translated_output.csv" # Путь к выходному .csv файлу
|
26 |
-
sheet = "Sheet1" # Укажите имя листа, если нужно
|
27 |
-
|
28 |
-
convert_and_translate_headers(input_xlsx, output_csv, sheet)
|
|
|
1 |
+
import asyncio
|
|
|
2 |
|
3 |
+
import numpy as np
|
4 |
|
5 |
+
from trauma.api.message.ai.openai_request import convert_value_to_embeddings
|
6 |
+
from trauma.core.config import settings
|
|
|
|
|
7 |
|
|
|
|
|
8 |
|
9 |
+
async def main():
|
10 |
+
entities = await settings.DB_CLIENT
|
|
|
11 |
|
12 |
+
if __name__ == '__main__':
|
13 |
+
asyncio.run(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trauma/api/data/data/prepare_data.py
CHANGED
@@ -130,5 +130,27 @@ async def generate_descriptions():
|
|
130 |
)
|
131 |
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
if __name__ == '__main__':
|
134 |
-
asyncio.run(
|
|
|
130 |
)
|
131 |
|
132 |
|
133 |
+
def split_array(array, max_size=10):
|
134 |
+
if max_size <= 0:
|
135 |
+
raise ValueError("max_size must be greater than 0")
|
136 |
+
return [array[i:i + max_size] for i in range(0, len(array), max_size)]
|
137 |
+
|
138 |
+
|
139 |
+
async def generate_embeddings():
|
140 |
+
entities = await settings.DB_CLIENT.entities.find({}, {"embedding": 0}).to_list()
|
141 |
+
entities = [EntityModel.from_mongo(entity) for entity in entities]
|
142 |
+
for entity in entities:
|
143 |
+
entity_str = entity.name
|
144 |
+
entity_emb = await settings.OPENAI_CLIENT.embeddings.create(
|
145 |
+
model='text-embedding-3-large',
|
146 |
+
dimensions=384,
|
147 |
+
input=entity_str,
|
148 |
+
)
|
149 |
+
await settings.DB_CLIENT.entities.update_one(
|
150 |
+
{"id": entity.id},
|
151 |
+
{"$set": {"embedding": entity_emb.data[0].embedding}},
|
152 |
+
)
|
153 |
+
print('hi')
|
154 |
+
|
155 |
if __name__ == '__main__':
|
156 |
+
asyncio.run(generate_embeddings())
|
trauma/api/data/db_requests.py
CHANGED
@@ -5,7 +5,7 @@ from trauma.core.config import settings
|
|
5 |
|
6 |
|
7 |
async def get_facility_by_id(facility_id: str) -> EntityModel:
|
8 |
-
facility = await settings.DB_CLIENT.entities.find_one({"id": facility_id})
|
9 |
if not facility:
|
10 |
raise HTTPException(status_code=404, detail="Country with specified id doesn't exists.")
|
11 |
return EntityModel.from_mongo(facility)
|
@@ -13,5 +13,5 @@ async def get_facility_by_id(facility_id: str) -> EntityModel:
|
|
13 |
|
14 |
async def get_all_model_obj() -> list[EntityModel]:
|
15 |
sort_v = -1
|
16 |
-
objects = await settings.DB_CLIENT.entities.find().sort("_id", sort_v).to_list(length=None)
|
17 |
return objects
|
|
|
5 |
|
6 |
|
7 |
async def get_facility_by_id(facility_id: str) -> EntityModel:
|
8 |
+
facility = await settings.DB_CLIENT.entities.find_one({"id": facility_id}, {"embedding": 0})
|
9 |
if not facility:
|
10 |
raise HTTPException(status_code=404, detail="Country with specified id doesn't exists.")
|
11 |
return EntityModel.from_mongo(facility)
|
|
|
13 |
|
14 |
async def get_all_model_obj() -> list[EntityModel]:
|
15 |
sort_v = -1
|
16 |
+
objects = await settings.DB_CLIENT.entities.find({}, {"embedding": 0}).sort("_id", sort_v).to_list(length=None)
|
17 |
return objects
|
trauma/api/data/model.py
CHANGED
@@ -13,8 +13,8 @@ class EntityModel(MongoBaseModel):
|
|
13 |
|
14 |
|
15 |
class EntityModelExtended(EntityModel):
|
16 |
-
highlightedAgeGroup: AgeGroup
|
17 |
-
highlightedTreatmentArea: str | None
|
18 |
-
highlightedTreatmentMethod: str | None
|
19 |
topMatch: bool = False
|
20 |
score: float | None = None
|
|
|
13 |
|
14 |
|
15 |
class EntityModelExtended(EntityModel):
|
16 |
+
highlightedAgeGroup: AgeGroup | None = None
|
17 |
+
highlightedTreatmentArea: str | None = None
|
18 |
+
highlightedTreatmentMethod: str | None = None
|
19 |
topMatch: bool = False
|
20 |
score: float | None = None
|
trauma/api/message/ai/engine.py
CHANGED
@@ -10,7 +10,8 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
|
10 |
generate_search_request,
|
11 |
generate_final_response, convert_value_to_embeddings,
|
12 |
choose_closest_treatment_method, choose_closest_treatment_area,
|
13 |
-
check_is_valid_request, generate_invalid_response, set_entity_score
|
|
|
14 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
15 |
filter_entities_by_age_location,
|
16 |
update_entity_data_obj, get_entity_by_index)
|
@@ -25,10 +26,16 @@ from trauma.core.config import settings
|
|
25 |
async def search_entities(
|
26 |
user_message: str, messages: list[dict], chat: ChatModel
|
27 |
) -> CreateMessageResponse:
|
28 |
-
entity_data, is_valid = await asyncio.gather(
|
|
|
29 |
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
30 |
check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
|
31 |
)
|
|
|
|
|
|
|
|
|
|
|
32 |
final_entities = None
|
33 |
if not is_valid:
|
34 |
response = await generate_invalid_response(user_message, messages)
|
|
|
10 |
generate_search_request,
|
11 |
generate_final_response, convert_value_to_embeddings,
|
12 |
choose_closest_treatment_method, choose_closest_treatment_area,
|
13 |
+
check_is_valid_request, generate_invalid_response, set_entity_score,
|
14 |
+
retrieve_semantic_answer, generate_searched_entity_response)
|
15 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
16 |
filter_entities_by_age_location,
|
17 |
update_entity_data_obj, get_entity_by_index)
|
|
|
26 |
async def search_entities(
|
27 |
user_message: str, messages: list[dict], chat: ChatModel
|
28 |
) -> CreateMessageResponse:
|
29 |
+
related_entity, entity_data, is_valid = await asyncio.gather(
|
30 |
+
retrieve_semantic_answer(user_message),
|
31 |
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
32 |
check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
|
33 |
)
|
34 |
+
if related_entity:
|
35 |
+
response = await generate_searched_entity_response(user_message, related_entity[0])
|
36 |
+
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
37 |
+
return CreateMessageResponse(text=response, entities=related_entity)
|
38 |
+
|
39 |
final_entities = None
|
40 |
if not is_valid:
|
41 |
response = await generate_invalid_response(user_message, messages)
|
trauma/api/message/ai/openai_request.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import json
|
2 |
|
3 |
from trauma.api.chat.dto import EntityData
|
@@ -73,11 +74,11 @@ async def generate_final_response(final_entities: str, user_message: str, messag
|
|
73 |
return messages
|
74 |
|
75 |
|
76 |
-
async def convert_value_to_embeddings(value: str) -> list[float]:
|
77 |
embeddings = await settings.OPENAI_CLIENT.embeddings.create(
|
78 |
input=value,
|
79 |
model='text-embedding-3-large',
|
80 |
-
dimensions=
|
81 |
)
|
82 |
return embeddings.data[0].embedding
|
83 |
|
@@ -150,3 +151,40 @@ async def set_entity_score(entity: EntityModelExtended, search_request: str):
|
|
150 |
}
|
151 |
]
|
152 |
return messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
import json
|
3 |
|
4 |
from trauma.api.chat.dto import EntityData
|
|
|
74 |
return messages
|
75 |
|
76 |
|
77 |
+
async def convert_value_to_embeddings(value: str, dimensions: int = 1536) -> list[float]:
|
78 |
embeddings = await settings.OPENAI_CLIENT.embeddings.create(
|
79 |
input=value,
|
80 |
model='text-embedding-3-large',
|
81 |
+
dimensions=dimensions,
|
82 |
)
|
83 |
return embeddings.data[0].embedding
|
84 |
|
|
|
151 |
}
|
152 |
]
|
153 |
return messages
|
154 |
+
|
155 |
+
|
156 |
+
async def retrieve_semantic_answer(user_query: str) -> list[EntityModelExtended] | None:
|
157 |
+
embedding = await settings.OPENAI_CLIENT.embeddings.create(input=user_query,
|
158 |
+
model='text-embedding-3-large',
|
159 |
+
dimensions=384)
|
160 |
+
response = await settings.DB_CLIENT.entities.aggregate([
|
161 |
+
{"$vectorSearch": {
|
162 |
+
"index": f"entityVectors",
|
163 |
+
"path": "embedding",
|
164 |
+
"queryVector": embedding.data[0].embedding,
|
165 |
+
"numCandidates": 20,
|
166 |
+
"limit": 1
|
167 |
+
}},
|
168 |
+
{"$project": {
|
169 |
+
"embedding": 0,
|
170 |
+
"score": {"$meta": "vectorSearchScore"}
|
171 |
+
}}
|
172 |
+
]).to_list(length=1)
|
173 |
+
return [EntityModelExtended(**response[0])] if response[0]['score'] > 0.83 else None
|
174 |
+
|
175 |
+
|
176 |
+
@openai_wrapper()
|
177 |
+
async def generate_searched_entity_response(user_query: str, facility: EntityModelExtended):
|
178 |
+
messages = [
|
179 |
+
{
|
180 |
+
"role": "system",
|
181 |
+
"content": TraumaPrompts.generate_searched_entity
|
182 |
+
.replace("{user_query}", user_query)
|
183 |
+
.replace("{entity}", facility.model_dump_json(indent=2))
|
184 |
+
}
|
185 |
+
]
|
186 |
+
return messages
|
187 |
+
|
188 |
+
|
189 |
+
if __name__ == '__main__':
|
190 |
+
asyncio.run(retrieve_semantic_answer('I want to know more about Praktijk Hermens'))
|
trauma/api/message/ai/prompts.py
CHANGED
@@ -167,6 +167,25 @@ De gebruiker zoekt naar een geschikte kliniek voor een patiënt en deelt hierbij
|
|
167 |
- Gebruik een vriendelijke en geruststellende toon, bijvoorbeeld: "Ik heb op basis van de ingevoerde gegevens geen kliniek kunnen vinden. Kunnen we samen kijken of we de informatie iets kunnen aanpassen om betere resultaten te krijgen?"
|
168 |
- Geef praktische suggesties, zoals: "Misschien helpt het om meer details over de locatie of de behandelmethode te delen."
|
169 |
- Stel open vragen om de gebruiker te begeleiden bij het verfijnen van de gegevens, zoals: "Zijn er andere belangrijke punten die we kunnen toevoegen?"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
generate_clinic_description = """## Taak
|
172 |
|
|
|
167 |
- Gebruik een vriendelijke en geruststellende toon, bijvoorbeeld: "Ik heb op basis van de ingevoerde gegevens geen kliniek kunnen vinden. Kunnen we samen kijken of we de informatie iets kunnen aanpassen om betere resultaten te krijgen?"
|
168 |
- Geef praktische suggesties, zoals: "Misschien helpt het om meer details over de locatie of de behandelmethode te delen."
|
169 |
- Stel open vragen om de gebruiker te begeleiden bij het verfijnen van de gegevens, zoals: "Zijn er andere belangrijke punten die we kunnen toevoegen?"""
|
170 |
+
generate_searched_entity = """## Taak
|
171 |
+
|
172 |
+
Je moet de gevraagde faciliteit beschrijven, waarvan de informatie wordt gegeven in de sectie `## Data`. Analyseer de gebruikersvraag en de informatie over de faciliteit, en geef een beknopt en bondig antwoord.
|
173 |
+
|
174 |
+
## Gegevens
|
175 |
+
|
176 |
+
**Gebruikersvraag**:
|
177 |
+
```
|
178 |
+
{user_query}
|
179 |
+
```
|
180 |
+
|
181 |
+
**Faciliteit**:
|
182 |
+
```
|
183 |
+
{entity}
|
184 |
+
```
|
185 |
+
|
186 |
+
## Belangrijke opmerkingen
|
187 |
+
|
188 |
+
- Je antwoord moet beknopt zijn."""
|
189 |
|
190 |
generate_clinic_description = """## Taak
|
191 |
|
trauma/api/message/db_requests.py
CHANGED
@@ -80,7 +80,7 @@ async def filter_entities_by_age_location(entity_data: dict) -> list[int]:
|
|
80 |
"$regex": f".*{entity_data['postalCode']}.*",
|
81 |
"$options": "i"
|
82 |
}
|
83 |
-
entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
|
84 |
return [entity['index'] for entity in entities]
|
85 |
|
86 |
|
|
|
80 |
"$regex": f".*{entity_data['postalCode']}.*",
|
81 |
"$options": "i"
|
82 |
}
|
83 |
+
entities = await settings.DB_CLIENT.entities.find(query, {"embedding": 0}).to_list(length=None)
|
84 |
return [entity['index'] for entity in entities]
|
85 |
|
86 |
|
trauma/api/message/utils.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import json
|
2 |
|
3 |
-
from trauma.api.chat.dto import EntityData
|
4 |
from trauma.api.data.dto import AgeGroup
|
5 |
from trauma.api.data.model import EntityModel
|
6 |
from trauma.api.message.model import MessageModel
|
|
|
1 |
import json
|
2 |
|
|
|
3 |
from trauma.api.data.dto import AgeGroup
|
4 |
from trauma.api.data.model import EntityModel
|
5 |
from trauma.api.message.model import MessageModel
|