Spaces:
Sleeping
Sleeping
Commit
·
c6f267d
1
Parent(s):
8704dff
optimized code
Browse files- REFACTORED_README.md +463 -0
- ai_med_extract/agents/patient_summary_agent.py +313 -1534
- ai_med_extract/api/model_management.py +397 -0
- ai_med_extract/api/routes.py +172 -386
- ai_med_extract/app.py +19 -92
- ai_med_extract/utils/model_config.py +165 -0
- ai_med_extract/utils/model_manager.py +408 -0
- test_refactored_system.py +321 -0
REFACTORED_README.md
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HNTAI Medical Data Extraction - Refactored System
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This project has been completely refactored to provide a unified, flexible model management system that supports **any model name and type**, including GGUF models for patient summary generation. The system now offers dynamic model loading, runtime model switching, and robust fallback mechanisms.
|
6 |
+
|
7 |
+
## 🚀 Key Features
|
8 |
+
|
9 |
+
### ✨ **Universal Model Support**
|
10 |
+
- **Any Model Name**: Use any Hugging Face model, local model, or custom model
|
11 |
+
- **Any Model Type**: Support for text-generation, summarization, NER, GGUF, OpenVINO, and more
|
12 |
+
- **Automatic Type Detection**: The system automatically detects model types from names
|
13 |
+
- **Dynamic Loading**: Load models at runtime without restarting the application
|
14 |
+
|
15 |
+
### 🔄 **GGUF Model Integration**
|
16 |
+
- **Seamless GGUF Support**: Full integration with llama.cpp for GGUF models
|
17 |
+
- **Patient Summary Generation**: Optimized for medical text summarization
|
18 |
+
- **Memory Efficient**: Ultra-conservative settings for Hugging Face Spaces
|
19 |
+
- **Fallback Mechanisms**: Automatic fallback when GGUF models fail
|
20 |
+
|
21 |
+
### 🧠 **Unified Model Manager**
|
22 |
+
- **Single Interface**: One manager handles all model types
|
23 |
+
- **Smart Caching**: Intelligent model caching with memory management
|
24 |
+
- **Fallback Chains**: Multiple fallback options for robustness
|
25 |
+
- **Performance Monitoring**: Built-in timing and memory tracking
|
26 |
+
|
27 |
+
## 🏗️ Architecture
|
28 |
+
|
29 |
+
### Core Components
|
30 |
+
|
31 |
+
1. **`UnifiedModelManager`** - Central model management system
|
32 |
+
2. **`BaseModelLoader`** - Abstract interface for all model loaders
|
33 |
+
3. **`TransformersModelLoader`** - Hugging Face Transformers models
|
34 |
+
4. **`GGUFModelLoader`** - GGUF models via llama.cpp
|
35 |
+
5. **`OpenVINOModelLoader`** - OpenVINO optimized models
|
36 |
+
6. **`PatientSummarizerAgent`** - Enhanced patient summary generation
|
37 |
+
|
38 |
+
### Model Type Support
|
39 |
+
|
40 |
+
| Model Type | Description | Example Models |
|
41 |
+
|------------|-------------|----------------|
|
42 |
+
| `text-generation` | Causal language models | `facebook/bart-base`, `microsoft/DialoGPT-medium` |
|
43 |
+
| `summarization` | Text summarization models | `Falconsai/medical_summarization`, `facebook/bart-large-cnn` |
|
44 |
+
| `ner` | Named Entity Recognition | `dslim/bert-base-NER`, `Jean-Baptiste/roberta-large-ner-english` |
|
45 |
+
| `gguf` | GGUF format models | `microsoft/Phi-3-mini-4k-instruct-gguf` |
|
46 |
+
| `openvino` | OpenVINO optimized models | `microsoft/Phi-3-mini-4k-instruct` |
|
47 |
+
|
48 |
+
## 🚀 Quick Start
|
49 |
+
|
50 |
+
### 1. Basic Usage
|
51 |
+
|
52 |
+
```python
|
53 |
+
from ai_med_extract.utils.model_manager import model_manager
|
54 |
+
|
55 |
+
# Load any model dynamically
|
56 |
+
loader = model_manager.get_model_loader(
|
57 |
+
model_name="microsoft/Phi-3-mini-4k-instruct-gguf",
|
58 |
+
model_type="gguf",
|
59 |
+
filename="Phi-3-mini-4k-instruct-q4.gguf"
|
60 |
+
)
|
61 |
+
|
62 |
+
# Generate text
|
63 |
+
result = loader.generate("Generate a medical summary for...")
|
64 |
+
```
|
65 |
+
|
66 |
+
### 2. Patient Summary Generation
|
67 |
+
|
68 |
+
```python
|
69 |
+
from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
|
70 |
+
|
71 |
+
# Create agent with any model
|
72 |
+
agent = PatientSummarizerAgent(
|
73 |
+
model_name="microsoft/Phi-3-mini-4k-instruct-gguf",
|
74 |
+
model_type="gguf"
|
75 |
+
)
|
76 |
+
|
77 |
+
# Generate clinical summary
|
78 |
+
summary = agent.generate_clinical_summary(patient_data)
|
79 |
+
```
|
80 |
+
|
81 |
+
### 3. Runtime Model Switching
|
82 |
+
|
83 |
+
```python
|
84 |
+
# Switch models at runtime
|
85 |
+
agent.update_model(
|
86 |
+
model_name="Falconsai/medical_summarization",
|
87 |
+
model_type="summarization"
|
88 |
+
)
|
89 |
+
```
|
90 |
+
|
91 |
+
## 📡 API Endpoints
|
92 |
+
|
93 |
+
### Model Management API
|
94 |
+
|
95 |
+
#### Load Model
|
96 |
+
```http
|
97 |
+
POST /api/models/load
|
98 |
+
Content-Type: application/json
|
99 |
+
|
100 |
+
{
|
101 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
102 |
+
"model_type": "gguf",
|
103 |
+
"filename": "Phi-3-mini-4k-instruct-q4.gguf",
|
104 |
+
"force_reload": false
|
105 |
+
}
|
106 |
+
```
|
107 |
+
|
108 |
+
#### Generate Text
|
109 |
+
```http
|
110 |
+
POST /api/models/generate
|
111 |
+
Content-Type: application/json
|
112 |
+
|
113 |
+
{
|
114 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
115 |
+
"model_type": "gguf",
|
116 |
+
"prompt": "Generate a medical summary for...",
|
117 |
+
"max_tokens": 512,
|
118 |
+
"temperature": 0.7
|
119 |
+
}
|
120 |
+
```
|
121 |
+
|
122 |
+
#### Switch Agent Model
|
123 |
+
```http
|
124 |
+
POST /api/models/switch
|
125 |
+
Content-Type: application/json
|
126 |
+
|
127 |
+
{
|
128 |
+
"agent_name": "patient_summarizer",
|
129 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
130 |
+
"model_type": "gguf"
|
131 |
+
}
|
132 |
+
```
|
133 |
+
|
134 |
+
#### Get Model Information
|
135 |
+
```http
|
136 |
+
GET /api/models/info?model_name=microsoft/Phi-3-mini-4k-instruct-gguf
|
137 |
+
```
|
138 |
+
|
139 |
+
#### Health Check
|
140 |
+
```http
|
141 |
+
GET /api/models/health
|
142 |
+
```
|
143 |
+
|
144 |
+
### Patient Summary API
|
145 |
+
|
146 |
+
#### Generate Patient Summary
|
147 |
+
```http
|
148 |
+
POST /generate_patient_summary
|
149 |
+
Content-Type: application/json
|
150 |
+
|
151 |
+
{
|
152 |
+
"patientid": "12345",
|
153 |
+
"token": "your_token",
|
154 |
+
"key": "your_api_key",
|
155 |
+
"patient_summarizer_model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
156 |
+
"patient_summarizer_model_type": "gguf"
|
157 |
+
}
|
158 |
+
```
|
159 |
+
|
160 |
+
## 🔧 Configuration
|
161 |
+
|
162 |
+
### Environment Variables
|
163 |
+
|
164 |
+
```bash
|
165 |
+
# Cache directories
|
166 |
+
HF_HOME=/tmp/huggingface
|
167 |
+
XDG_CACHE_HOME=/tmp
|
168 |
+
TORCH_HOME=/tmp/torch
|
169 |
+
WHISPER_CACHE=/tmp/whisper
|
170 |
+
|
171 |
+
# GGUF optimization
|
172 |
+
GGUF_N_THREADS=2
|
173 |
+
GGUF_N_BATCH=64
|
174 |
+
```
|
175 |
+
|
176 |
+
### Model Configuration
|
177 |
+
|
178 |
+
The system automatically uses optimized models for different environments:
|
179 |
+
|
180 |
+
- **Local Development**: Full model capabilities
|
181 |
+
- **Hugging Face Spaces**: Memory-optimized models
|
182 |
+
- **Production**: Configurable based on resources
|
183 |
+
|
184 |
+
## 🎯 Use Cases
|
185 |
+
|
186 |
+
### 1. **Medical Document Processing**
|
187 |
+
```python
|
188 |
+
# Extract medical data with any model
|
189 |
+
medical_data = model_manager.generate_text(
|
190 |
+
model_name="facebook/bart-base",
|
191 |
+
model_type="text-generation",
|
192 |
+
prompt="Extract medical entities from: " + document_text
|
193 |
+
)
|
194 |
+
```
|
195 |
+
|
196 |
+
### 2. **Patient Summary Generation**
|
197 |
+
```python
|
198 |
+
# Use GGUF model for patient summaries
|
199 |
+
summary = model_manager.generate_text(
|
200 |
+
model_name="microsoft/Phi-3-mini-4k-instruct-gguf",
|
201 |
+
model_type="gguf",
|
202 |
+
prompt=patient_data_prompt,
|
203 |
+
max_tokens=512
|
204 |
+
)
|
205 |
+
```
|
206 |
+
|
207 |
+
### 3. **Dynamic Model Switching**
|
208 |
+
```python
|
209 |
+
# Switch between models based on task requirements
|
210 |
+
if task == "summarization":
|
211 |
+
model_name = "Falconsai/medical_summarization"
|
212 |
+
model_type = "summarization"
|
213 |
+
elif task == "extraction":
|
214 |
+
model_name = "facebook/bart-base"
|
215 |
+
model_type = "text-generation"
|
216 |
+
|
217 |
+
loader = model_manager.get_model_loader(model_name, model_type)
|
218 |
+
```
|
219 |
+
|
220 |
+
## 🔒 Memory Management
|
221 |
+
|
222 |
+
### Hugging Face Spaces Optimization
|
223 |
+
|
224 |
+
The system automatically detects Hugging Face Spaces and applies ultra-conservative memory settings:
|
225 |
+
|
226 |
+
- **GGUF Models**: 1 thread, 16 batch size, 512 context
|
227 |
+
- **Transformers**: Float32 precision, minimal memory usage
|
228 |
+
- **Automatic Fallbacks**: Graceful degradation when memory is limited
|
229 |
+
|
230 |
+
### Memory Monitoring
|
231 |
+
|
232 |
+
```python
|
233 |
+
# Check memory usage
|
234 |
+
health = requests.get("/api/models/health").json()
|
235 |
+
print(f"GPU Memory: {health['gpu_info']['memory_allocated']}")
|
236 |
+
print(f"Loaded Models: {health['loaded_models_count']}")
|
237 |
+
```
|
238 |
+
|
239 |
+
## 🧪 Testing
|
240 |
+
|
241 |
+
### Test GGUF Models
|
242 |
+
|
243 |
+
```bash
|
244 |
+
# Test GGUF model loading
|
245 |
+
python test_gguf.py
|
246 |
+
|
247 |
+
# Test specific model
|
248 |
+
python -c "
|
249 |
+
from ai_med_extract.utils.model_manager import model_manager
|
250 |
+
loader = model_manager.get_model_loader('microsoft/Phi-3-mini-4k-instruct-gguf', 'gguf')
|
251 |
+
result = loader.generate('Test prompt')
|
252 |
+
print(f'Success: {len(result)} characters generated')
|
253 |
+
"
|
254 |
+
```
|
255 |
+
|
256 |
+
### Model Validation
|
257 |
+
|
258 |
+
```python
|
259 |
+
from ai_med_extract.utils.model_config import validate_model_config
|
260 |
+
|
261 |
+
# Validate model configuration
|
262 |
+
validation = validate_model_config(
|
263 |
+
model_name="microsoft/Phi-3-mini-4k-instruct-gguf",
|
264 |
+
model_type="gguf"
|
265 |
+
)
|
266 |
+
|
267 |
+
print(f"Valid: {validation['valid']}")
|
268 |
+
print(f"Warnings: {validation['warnings']}")
|
269 |
+
```
|
270 |
+
|
271 |
+
## 🚨 Error Handling
|
272 |
+
|
273 |
+
### Fallback Mechanisms
|
274 |
+
|
275 |
+
1. **Primary Model**: Attempts to load the specified model
|
276 |
+
2. **Fallback Model**: Uses predefined fallback for the model type
|
277 |
+
3. **Text Fallback**: Generates structured text responses
|
278 |
+
4. **Graceful Degradation**: Continues operation with reduced functionality
|
279 |
+
|
280 |
+
### Common Issues
|
281 |
+
|
282 |
+
#### GGUF Model Loading Fails
|
283 |
+
```python
|
284 |
+
# Check model file
|
285 |
+
if not os.path.exists(model_path):
|
286 |
+
# Download from Hugging Face
|
287 |
+
from huggingface_hub import hf_hub_download
|
288 |
+
model_path = hf_hub_download(repo_id, filename)
|
289 |
+
```
|
290 |
+
|
291 |
+
#### Memory Issues
|
292 |
+
```python
|
293 |
+
# Clear cache and reload
|
294 |
+
model_manager.clear_cache()
|
295 |
+
torch.cuda.empty_cache()
|
296 |
+
|
297 |
+
# Use smaller model
|
298 |
+
loader = model_manager.get_model_loader(
|
299 |
+
model_name="facebook/bart-base", # Smaller model
|
300 |
+
model_type="text-generation"
|
301 |
+
)
|
302 |
+
```
|
303 |
+
|
304 |
+
## 📊 Performance
|
305 |
+
|
306 |
+
### Benchmarking
|
307 |
+
|
308 |
+
```python
|
309 |
+
import time
|
310 |
+
|
311 |
+
# Time model loading
|
312 |
+
start = time.time()
|
313 |
+
loader = model_manager.get_model_loader(model_name, model_type)
|
314 |
+
load_time = time.time() - start
|
315 |
+
|
316 |
+
# Time generation
|
317 |
+
start = time.time()
|
318 |
+
result = loader.generate(prompt)
|
319 |
+
gen_time = time.time() - start
|
320 |
+
|
321 |
+
print(f"Load: {load_time:.2f}s, Generate: {gen_time:.2f}s")
|
322 |
+
```
|
323 |
+
|
324 |
+
### Optimization Tips
|
325 |
+
|
326 |
+
1. **Use Appropriate Model Size**: Smaller models for limited resources
|
327 |
+
2. **Enable Caching**: Models are cached after first load
|
328 |
+
3. **Batch Processing**: Process multiple requests together
|
329 |
+
4. **Memory Monitoring**: Regular health checks
|
330 |
+
|
331 |
+
## 🔮 Future Enhancements
|
332 |
+
|
333 |
+
### Planned Features
|
334 |
+
|
335 |
+
- **Model Quantization**: Automatic model optimization
|
336 |
+
- **Distributed Loading**: Load models across multiple devices
|
337 |
+
- **Model Versioning**: Track and manage model versions
|
338 |
+
- **Performance Analytics**: Detailed performance metrics
|
339 |
+
- **Auto-scaling**: Automatic model scaling based on load
|
340 |
+
|
341 |
+
### Extensibility
|
342 |
+
|
343 |
+
The system is designed for easy extension:
|
344 |
+
|
345 |
+
```python
|
346 |
+
class CustomModelLoader(BaseModelLoader):
|
347 |
+
def __init__(self, model_name: str):
|
348 |
+
self.model_name = model_name
|
349 |
+
|
350 |
+
def load(self):
|
351 |
+
# Custom loading logic
|
352 |
+
pass
|
353 |
+
|
354 |
+
def generate(self, prompt: str, **kwargs):
|
355 |
+
# Custom generation logic
|
356 |
+
pass
|
357 |
+
```
|
358 |
+
|
359 |
+
## 📝 Migration Guide
|
360 |
+
|
361 |
+
### From Old System
|
362 |
+
|
363 |
+
1. **Replace Hardcoded Models**:
|
364 |
+
```python
|
365 |
+
# Old
|
366 |
+
model = LazyModelLoader("facebook/bart-base", "text-generation")
|
367 |
+
|
368 |
+
# New
|
369 |
+
model = model_manager.get_model_loader("facebook/bart-base", "text-generation")
|
370 |
+
```
|
371 |
+
|
372 |
+
2. **Update Patient Summarizer**:
|
373 |
+
```python
|
374 |
+
# Old
|
375 |
+
agent = PatientSummarizerAgent()
|
376 |
+
|
377 |
+
# New
|
378 |
+
agent = PatientSummarizerAgent(
|
379 |
+
model_name="microsoft/Phi-3-mini-4k-instruct-gguf",
|
380 |
+
model_type="gguf"
|
381 |
+
)
|
382 |
+
```
|
383 |
+
|
384 |
+
3. **Use Dynamic Model Selection**:
|
385 |
+
```python
|
386 |
+
# Old: Fixed model types
|
387 |
+
# New: Dynamic model selection
|
388 |
+
model_type = request.form.get("model_type", "text-generation")
|
389 |
+
model_name = request.form.get("model_name", "facebook/bart-base")
|
390 |
+
```
|
391 |
+
|
392 |
+
## 🤝 Contributing
|
393 |
+
|
394 |
+
### Development Setup
|
395 |
+
|
396 |
+
```bash
|
397 |
+
# Clone repository
|
398 |
+
git clone <repository-url>
|
399 |
+
cd HNTAI
|
400 |
+
|
401 |
+
# Install dependencies
|
402 |
+
pip install -r requirements.txt
|
403 |
+
|
404 |
+
# Run tests
|
405 |
+
python -m pytest tests/
|
406 |
+
|
407 |
+
# Start development server
|
408 |
+
python -m ai_med_extract.app
|
409 |
+
```
|
410 |
+
|
411 |
+
### Adding New Model Types
|
412 |
+
|
413 |
+
1. **Create Loader Class**:
|
414 |
+
```python
|
415 |
+
class CustomModelLoader(BaseModelLoader):
|
416 |
+
# Implement required methods
|
417 |
+
pass
|
418 |
+
```
|
419 |
+
|
420 |
+
2. **Update Model Manager**:
|
421 |
+
```python
|
422 |
+
if model_type == "custom":
|
423 |
+
loader = CustomModelLoader(model_name)
|
424 |
+
```
|
425 |
+
|
426 |
+
3. **Add Configuration**:
|
427 |
+
```python
|
428 |
+
DEFAULT_MODELS["custom"] = {
|
429 |
+
"primary": "default/custom-model",
|
430 |
+
"fallback": "fallback/custom-model"
|
431 |
+
}
|
432 |
+
```
|
433 |
+
|
434 |
+
## 📄 License
|
435 |
+
|
436 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
437 |
+
|
438 |
+
## 🆘 Support
|
439 |
+
|
440 |
+
### Getting Help
|
441 |
+
|
442 |
+
- **Documentation**: This README and inline code comments
|
443 |
+
- **Issues**: GitHub Issues for bug reports
|
444 |
+
- **Discussions**: GitHub Discussions for questions
|
445 |
+
- **Examples**: See `test_gguf.py` and other test files
|
446 |
+
|
447 |
+
### Common Questions
|
448 |
+
|
449 |
+
**Q: Can I use my own GGUF model?**
|
450 |
+
A: Yes! Just provide the path to your .gguf file or upload it to Hugging Face.
|
451 |
+
|
452 |
+
**Q: How do I optimize for memory?**
|
453 |
+
A: Use smaller models, enable caching, and monitor memory usage via `/api/models/health`.
|
454 |
+
|
455 |
+
**Q: Can I switch models without restarting?**
|
456 |
+
A: Yes! Use the `/api/models/switch` endpoint to change models at runtime.
|
457 |
+
|
458 |
+
**Q: What if a model fails to load?**
|
459 |
+
A: The system automatically falls back to alternative models and provides detailed error information.
|
460 |
+
|
461 |
+
---
|
462 |
+
|
463 |
+
**🎉 Congratulations!** You now have a powerful, flexible system that can work with any model name and type, including GGUF models for patient summary generation. The system is designed to be robust, efficient, and easy to use while maintaining backward compatibility.
|
ai_med_extract/agents/patient_summary_agent.py
CHANGED
@@ -1,1559 +1,338 @@
|
|
1 |
-
|
2 |
-
# # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
-
# # import torch
|
4 |
-
|
5 |
-
# # from ai_med_extract.utils.patient_summary_utils import patient_chunk_text, flatten_to_string_list
|
6 |
-
|
7 |
-
# # class PatientSummarizerAgent:
|
8 |
-
# # def __init__(self):
|
9 |
-
# # # Device configuration
|
10 |
-
# # self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
11 |
-
# # if self.device == 'cuda':
|
12 |
-
# # torch.cuda.empty_cache()
|
13 |
-
|
14 |
-
# # # Load medical summarization model
|
15 |
-
# # self.MODEL_NAME = "Falconsai/medical_summarization"
|
16 |
-
# # try:
|
17 |
-
# # self.tokenizer, self.model = self.load_model(self.MODEL_NAME, self.device)
|
18 |
-
# # except RuntimeError as e:
|
19 |
-
# # exit()
|
20 |
-
|
21 |
-
# # def load_model(self, model_name: str, device: str):
|
22 |
-
# # try:
|
23 |
-
# # tokenizer = AutoTokenizer.from_pretrained(model_name)
|
24 |
-
# # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
25 |
-
# # if device == 'cuda':
|
26 |
-
# # model = model.half()
|
27 |
-
# # model.to(device)
|
28 |
-
# # model.eval()
|
29 |
-
# # return tokenizer, model
|
30 |
-
# # except Exception as e:
|
31 |
-
# # raise RuntimeError(f"Model loading failed: {str(e)}")
|
32 |
-
|
33 |
-
# # def summarize_chunk(self, text):
|
34 |
-
# # try:
|
35 |
-
# # inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
|
36 |
-
# # outputs = self.model.generate(
|
37 |
-
# # **inputs,
|
38 |
-
# # max_new_tokens=400,
|
39 |
-
# # num_beams=4,
|
40 |
-
# # temperature=0.7,
|
41 |
-
# # early_stopping=True
|
42 |
-
# # )
|
43 |
-
# # return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
44 |
-
# # except Exception as e:
|
45 |
-
# # return f"Error summarizing chunk: {str(e)}"
|
46 |
-
|
47 |
-
# # def generate_clinical_summary(self, patient_data: dict) -> str:
|
48 |
-
# # try:
|
49 |
-
# # # Use flattened data and chunking for summarization
|
50 |
-
# # flattened_lines = flatten_to_string_list(patient_data)
|
51 |
-
# # chunks = patient_chunk_text(flattened_lines, chunk_size=1500)
|
52 |
-
# # chunk_summaries = [self.summarize_chunk(chunk) for chunk in chunks]
|
53 |
-
# # raw_summary = " ".join(chunk_summaries)
|
54 |
-
# # return self.format_clinical_output(raw_summary, patient_data)
|
55 |
-
# # except Exception as e:
|
56 |
-
# # return f"Error generating summary: {str(e)}"
|
57 |
-
|
58 |
-
# # def format_clinical_output(self, raw_summary: str, patient_data: dict) -> str:
|
59 |
-
# # current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
60 |
-
# # result = patient_data['result']
|
61 |
-
# # formatted = f"\n--- CLINICAL DECISION SUMMARY ---\n"
|
62 |
-
# # formatted += f"Summary Generated On: {current_time}\n"
|
63 |
-
|
64 |
-
# # # Demographics
|
65 |
-
# # formatted += f"\n--- PATIENT DEMOGRAPHICS ---\n"
|
66 |
-
# # formatted += f"Patient ID: {result.get('patientnumber', 'Unknown')}\n"
|
67 |
-
# # gender = result.get('gender', 'Unknown')
|
68 |
-
# # formatted += f"Age/Sex: {result.get('agey', 'Unknown')} {gender[0] if gender and gender != 'Unknown' else 'U'}\n"
|
69 |
-
# # formatted += f"Date of Birth: {result.get('dob', 'N/A')}\n"
|
70 |
-
# # formatted += f"Blood Group: {result.get('bloodgrp', 'N/A')}\n"
|
71 |
-
# # formatted += f"Last Visit Date: {result.get('lastvisitdt', 'N/A')}\n"
|
72 |
-
|
73 |
-
# # allergies = result.get('allergies') or ['None known']
|
74 |
-
# # formatted += f"Allergies: **{', '.join(allergies)}**\n"
|
75 |
-
# # formatted += f"Social History: {result.get('social_history', 'Not specified')}\n"
|
76 |
-
|
77 |
-
# # # Reason for visit
|
78 |
-
# # formatted += f"\n--- REASON FOR VISIT ---\n"
|
79 |
-
# # formatted += f"Chief Complaint: **{result.get('chief_complaint', 'Not specified')}**\n"
|
80 |
-
|
81 |
-
# # # Past medical history
|
82 |
-
# # formatted += f"\n--- PAST MEDICAL HISTORY ---\n"
|
83 |
-
# # past_history = result.get('past_medical_history') or ['None']
|
84 |
-
# # for item in past_history:
|
85 |
-
# # formatted += f"{item}\n"
|
86 |
-
|
87 |
-
# # # Vitals
|
88 |
-
# # formatted += f"\n--- CURRENT VITALS ---\n"
|
89 |
-
# # vitals = result.get('vitals', {})
|
90 |
-
# # formatted += f"BP: {vitals.get('BP', 'N/A')}\n"
|
91 |
-
# # formatted += f"Temp: {vitals.get('Temp', 'N/A')}\n"
|
92 |
-
# # formatted += f"SpO2: {vitals.get('SpO2', 'N/A')}\n"
|
93 |
-
# # formatted += f"Height: {vitals.get('Height', 'N/A')}\n"
|
94 |
-
# # formatted += f"Weight: {vitals.get('Weight', 'N/A')}\n"
|
95 |
-
# # formatted += f"BMI: {vitals.get('BMI', 'N/A')}\n"
|
96 |
-
|
97 |
-
# # # Lab & Imaging
|
98 |
-
# # formatted += f"\n--- LAB & IMAGING ---\n"
|
99 |
-
# # formatted += f"\n**Lab Tests Results:**\n"
|
100 |
-
# # lab_results = result.get('lab_results') or []
|
101 |
-
# # if lab_results:
|
102 |
-
# # for lab in lab_results:
|
103 |
-
# # value = lab.get('value', 'N/A')
|
104 |
-
# # test_name = lab.get('name', 'Unknown Test')
|
105 |
-
# # formatted += f"{test_name}: **{value}**\n"
|
106 |
-
# # else:
|
107 |
-
# # labtests = result.get('labtests') or ['None']
|
108 |
-
# # for test in labtests:
|
109 |
-
# # formatted += f"{test}\n"
|
110 |
-
|
111 |
-
# # formatted += f"\n**Radiology Orders:**\n"
|
112 |
-
# # radiology_orders = result.get('radiologyorders') or ['None']
|
113 |
-
# # for order in radiology_orders:
|
114 |
-
# # formatted += f"{order}\n"
|
115 |
-
|
116 |
-
# # # Medications
|
117 |
-
# # formatted += f"\n--- CURRENT MEDICATIONS ---\n"
|
118 |
-
# # medications = result.get('medications') or ['None']
|
119 |
-
# # for med in medications:
|
120 |
-
# # if med and str(med).lower() != 'null':
|
121 |
-
# # formatted += f"{med}\n"
|
122 |
-
|
123 |
-
# # # Diagnoses
|
124 |
-
# # formatted += f"\n--- ASSESSMENT & DIAGNOSES ---\n"
|
125 |
-
# # diagnoses = result.get('diagnosis') or ['None']
|
126 |
-
# # for dx in diagnoses:
|
127 |
-
# # formatted += f"{dx}\n"
|
128 |
-
|
129 |
-
# # # Plan
|
130 |
-
# # formatted += f"\n--- PLAN ---\n"
|
131 |
-
# # plan = result.get('assessment_plan', 'No plan specified')
|
132 |
-
# # plan_lines = [line.strip() for line in plan.split('\n') if line.strip()]
|
133 |
-
# # for line in plan_lines:
|
134 |
-
# # formatted += f"{line}\n"
|
135 |
-
|
136 |
-
# # # Follow-up
|
137 |
-
# # formatted += f"\n--- FOLLOW-UP RECOMMENDATIONS ---\n"
|
138 |
-
# # formatted += "Re-evaluate in 5-7 days if not improving\n"
|
139 |
-
# # formatted += "Return immediately for worsening dyspnea or new symptoms\n"
|
140 |
-
|
141 |
-
# # formatted += f"\n--- MODEL-GENERATED SUMMARY ---\n{raw_summary}\n"
|
142 |
-
# # return formatted
|
143 |
-
|
144 |
-
|
145 |
-
# # import datetime
|
146 |
-
# # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
147 |
-
# # import torch
|
148 |
-
|
149 |
-
# # from ai_med_extract.utils.patient_summary_utils import patient_chunk_text, flatten_to_string_list
|
150 |
-
|
151 |
-
# # class PatientSummarizerAgent:
|
152 |
-
# # def __init__(self):
|
153 |
-
# # self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
154 |
-
# # self.MODEL_NAME = "Falconsai/medical_summarization" # Or replace with Flan-T5
|
155 |
-
# # self.tokenizer, self.model = self.load_model(self.MODEL_NAME)
|
156 |
-
|
157 |
-
# # def load_model(self, model_name):
|
158 |
-
# # tokenizer = AutoTokenizer.from_pretrained(model_name)
|
159 |
-
# # model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
|
160 |
-
# # model.eval()
|
161 |
-
# # return tokenizer, model
|
162 |
-
|
163 |
-
# # def build_narrative_prompt(self, patient_data):
|
164 |
-
# # result = patient_data['result']
|
165 |
-
# # prompt_lines = [f"Past Medical History: {', '.join(result.get('past_medical_history', []))}.\n"]
|
166 |
-
|
167 |
-
# # for enc in result.get('encounters', []):
|
168 |
-
# # prompt_lines.append(
|
169 |
-
# # f"Encounter on {enc['visit_date']}:\n"
|
170 |
-
# # f"- Chief Complaint: {enc.get('chief_complaint')}\n"
|
171 |
-
# # f"- Symptoms: {enc.get('symptoms')}\n"
|
172 |
-
# # f"- Diagnoses: {', '.join(enc.get('diagnosis', []))}\n"
|
173 |
-
# # f"- Doctor's Notes: {enc.get('dr_notes')}\n"
|
174 |
-
# # f"- Investigations: {enc.get('investigations')}\n"
|
175 |
-
# # f"- Medications: {', '.join(enc.get('medications', []))}\n"
|
176 |
-
# # f"- Treatment: {enc.get('treatment')}\n"
|
177 |
-
# # )
|
178 |
-
|
179 |
-
# # return (
|
180 |
-
# # "Summarize the following clinical timeline with a narrative, assessment, plan, and possible "
|
181 |
-
# # "next steps.\n\nPATIENT HISTORY:\n" + "\n".join(prompt_lines)
|
182 |
-
# # )
|
183 |
-
|
184 |
-
# # def generate_summary(self, prompt: str):
|
185 |
-
# # inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
|
186 |
-
# # outputs = self.model.generate(**inputs, max_new_tokens=512, num_beams=4, early_stopping=True)
|
187 |
-
# # return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
188 |
-
|
189 |
-
# # def generate_clinical_summary(self, patient_data: dict) -> str:
|
190 |
-
# # try:
|
191 |
-
# # prompt = self.build_narrative_prompt(patient_data)
|
192 |
-
# # summary = self.generate_summary(prompt)
|
193 |
-
# # return self.format_clinical_output(summary, patient_data)
|
194 |
-
# # except Exception as e:
|
195 |
-
# # return f"❌ Error generating summary: {e}"
|
196 |
-
|
197 |
-
# # def format_clinical_output(self, raw_summary: str, patient_data: dict) -> str:
|
198 |
-
# # result = patient_data['result']
|
199 |
-
# # now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
200 |
-
# # report = "\n--- CLINICAL SUMMARY ---\n"
|
201 |
-
# # report += f"Generated On: {now}\n\n"
|
202 |
-
# # report += f"Patient ID: {result.get('patientnumber', 'N/A')}\n"
|
203 |
-
# # report += f"Age/Sex: {result.get('agey', 'N/A')} / {result.get('gender', 'N/A')}\n"
|
204 |
-
# # report += f"Allergies: {', '.join(result.get('allergies', ['None']))}\n"
|
205 |
-
# # report += f"\n--- MODEL-GENERATED SUMMARY ---\n"
|
206 |
-
# # report += raw_summary + "\n"
|
207 |
-
# # return report
|
208 |
-
|
209 |
-
|
210 |
-
# import datetime
|
211 |
-
# import torch
|
212 |
-
# import warnings
|
213 |
-
# import re
|
214 |
-
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
215 |
-
# from textwrap import fill
|
216 |
-
|
217 |
-
# # Suppress non-critical warnings
|
218 |
-
# warnings.filterwarnings("ignore", category=UserWarning)
|
219 |
-
|
220 |
-
# class PatientSummarizerAgent:
|
221 |
-
# def __init__(self):
|
222 |
-
# # Device configuration
|
223 |
-
# self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
224 |
-
# if self.device == 'cuda':
|
225 |
-
# torch.cuda.empty_cache()
|
226 |
-
# print(f"✅ Using device for tensors: {self.device}")
|
227 |
-
|
228 |
-
# # Model configuration
|
229 |
-
# self.MODEL_NAME = "Falconsai/medical_summarization"
|
230 |
-
# try:
|
231 |
-
# self.tokenizer, self.model = self.load_model(self.MODEL_NAME, self.device)
|
232 |
-
# print(f"✅ Model '{self.MODEL_NAME}' loaded successfully.")
|
233 |
-
# except RuntimeError as e:
|
234 |
-
# print(f"❌ Failed to load model: {e}")
|
235 |
-
# exit(1)
|
236 |
-
|
237 |
-
# def load_model(self, model_name: str, device: str):
|
238 |
-
# """
|
239 |
-
# Loads the medical summarization model and tokenizer.
|
240 |
-
# """
|
241 |
-
# try:
|
242 |
-
# print(f"🔄 Loading model: {model_name}...")
|
243 |
-
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
244 |
-
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
245 |
-
# if device == 'cuda':
|
246 |
-
# model = model.half() # FP16 for GPU
|
247 |
-
# model.to(device)
|
248 |
-
# model.eval()
|
249 |
-
# print(f"✅ Model '{model_name}' loaded and set to evaluation mode.")
|
250 |
-
# return tokenizer, model
|
251 |
-
# except Exception as e:
|
252 |
-
# raise RuntimeError(f"Model loading failed: {str(e)}")
|
253 |
-
|
254 |
-
# def summarize_chunk(self, text: str) -> str:
|
255 |
-
# """
|
256 |
-
# Summarizes a single text chunk using the model.
|
257 |
-
# """
|
258 |
-
# try:
|
259 |
-
# inputs = self.tokenizer(
|
260 |
-
# text,
|
261 |
-
# return_tensors="pt",
|
262 |
-
# truncation=True,
|
263 |
-
# max_length=1024
|
264 |
-
# ).to(self.device)
|
265 |
-
|
266 |
-
# outputs = self.model.generate(
|
267 |
-
# **inputs,
|
268 |
-
# max_new_tokens=400,
|
269 |
-
# num_beams=4,
|
270 |
-
# temperature=0.7,
|
271 |
-
# early_stopping=True
|
272 |
-
# )
|
273 |
-
# summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
274 |
-
# return summary
|
275 |
-
# except Exception as e:
|
276 |
-
# return f"Error summarizing chunk: {str(e)}"
|
277 |
-
|
278 |
-
# def generate_clinical_summary(self, patient_data: dict) -> str:
|
279 |
-
# """
|
280 |
-
# End-to-end method to generate a comprehensive clinical summary.
|
281 |
-
# Mimics the logic flow of the reference script: narrative, assessment, pathway, formatting, and evaluation.
|
282 |
-
# """
|
283 |
-
# print("✨ Generating clinical summary using Falconsai/medical_summarization...")
|
284 |
-
# try:
|
285 |
-
# # Step 1: Build a chronological narrative from all encounters
|
286 |
-
# narrative_history = self.build_chronological_narrative(patient_data)
|
287 |
-
# print(f"\n--- Prompt Sent to Model (truncated) ---\n{fill(narrative_history, width=80)[:1000]}...\n")
|
288 |
-
|
289 |
-
# # Step 2: Summarize in chunks if needed
|
290 |
-
# chunks = self.chunk_text(narrative_history, chunk_size=1500)
|
291 |
-
# chunk_summaries = [self.summarize_chunk(chunk) for chunk in chunks]
|
292 |
-
# raw_summary_text = " ".join(chunk_summaries)
|
293 |
-
|
294 |
-
# print(f"\n--- Raw Model Output ---\n{fill(raw_summary_text, width=80)}\n")
|
295 |
-
|
296 |
-
# # Step 3: Format into structured clinical report
|
297 |
-
# formatted_report = self.format_clinical_output(raw_summary_text, patient_data)
|
298 |
-
|
299 |
-
# # Step 4: Simulated guideline evaluation
|
300 |
-
# evaluation_report = self.evaluate_summary_against_guidelines(raw_summary_text, patient_data)
|
301 |
-
|
302 |
-
# # Step 5: Combine final output
|
303 |
-
# final_output = (
|
304 |
-
# f"\n{'='*80}\n"
|
305 |
-
# f" FINAL CLINICAL SUMMARY REPORT\n"
|
306 |
-
# f"{'='*80}\n"
|
307 |
-
# f"{formatted_report}\n\n"
|
308 |
-
# f"{'='*80}\n"
|
309 |
-
# f" SIMULATED EVALUATION REPORT\n"
|
310 |
-
# f"{'='*80}\n"
|
311 |
-
# f"{evaluation_report}"
|
312 |
-
# )
|
313 |
-
# return final_output
|
314 |
-
|
315 |
-
# except Exception as e:
|
316 |
-
# print(f"❌ Error during summary generation: {e}")
|
317 |
-
# import traceback
|
318 |
-
# traceback.print_exc()
|
319 |
-
# return f"Error generating summary: {str(e)}"
|
320 |
-
|
321 |
-
# def build_chronological_narrative(self, patient_data: dict) -> str:
|
322 |
-
# """
|
323 |
-
# Builds a chronological narrative from multi-encounter patient history.
|
324 |
-
# """
|
325 |
-
# result = patient_data["result"]
|
326 |
-
# narrative = []
|
327 |
-
|
328 |
-
# # Past Medical History
|
329 |
-
# narrative.append(f"Past Medical History: {', '.join(result.get('past_medical_history', []))}.")
|
330 |
-
|
331 |
-
# # Social History
|
332 |
-
# social = result.get('social_history', 'Not specified.')
|
333 |
-
# narrative.append(f"Social History: {social}.")
|
334 |
-
|
335 |
-
# # Allergies
|
336 |
-
# allergies = ', '.join(result.get('allergies', ['None']))
|
337 |
-
# narrative.append(f"Allergies: {allergies}.")
|
338 |
-
|
339 |
-
# # Loop through encounters chronologically
|
340 |
-
# for enc in result.get("encounters", []):
|
341 |
-
# encounter_str = (
|
342 |
-
# f"Encounter on {enc['visit_date']}: "
|
343 |
-
# f"Chief Complaint: '{enc['chief_complaint']}'. "
|
344 |
-
# f"Symptoms: {enc.get('symptoms', 'None reported')}. "
|
345 |
-
# f"Diagnosis: {', '.join(enc['diagnosis'])}. "
|
346 |
-
# f"Doctor's Notes: {enc['dr_notes']}. "
|
347 |
-
# )
|
348 |
-
# if enc.get('vitals'):
|
349 |
-
# encounter_str += f"Vitals: {', '.join([f'{k}: {v}' for k, v in enc['vitals'].items()])}. "
|
350 |
-
# if enc.get('lab_results'):
|
351 |
-
# encounter_str += f"Labs: {', '.join([f'{k}: {v}' for k, v in enc['lab_results'].items()])}. "
|
352 |
-
# if enc.get('medications'):
|
353 |
-
# encounter_str += f"Medications: {', '.join(enc['medications'])}. "
|
354 |
-
# if enc.get('treatment'):
|
355 |
-
# encounter_str += f"Treatment: {enc['treatment']}."
|
356 |
-
# narrative.append(encounter_str)
|
357 |
-
|
358 |
-
# return "\n".join(narrative)
|
359 |
-
|
360 |
-
# def chunk_text(self, text: str, chunk_size: int = 1500) -> list:
|
361 |
-
# """
|
362 |
-
# Splits a long text into overlapping chunks for processing.
|
363 |
-
# """
|
364 |
-
# words = text.split()
|
365 |
-
# chunks = []
|
366 |
-
# for i in range(0, len(words), chunk_size):
|
367 |
-
# chunk = " ".join(words[i:i + chunk_size])
|
368 |
-
# chunks.append(chunk)
|
369 |
-
# return chunks
|
370 |
-
|
371 |
-
# def format_clinical_output(self, raw_summary: str, patient_data: dict) -> str:
|
372 |
-
# """
|
373 |
-
# Formats the raw AI-generated summary into a structured, doctor-friendly report.
|
374 |
-
# """
|
375 |
-
# result = patient_data["result"]
|
376 |
-
# last_encounter = result["encounters"][-1] if result.get("encounters") else result
|
377 |
-
|
378 |
-
# # Consolidate active problems
|
379 |
-
# all_diagnoses_raw = set(result.get('past_medical_history', []))
|
380 |
-
# for enc in result.get('encounters', []):
|
381 |
-
# all_diagnoses_raw.update(enc.get('diagnosis', []))
|
382 |
-
# cleaned_diagnoses = sorted({
|
383 |
-
# re.sub(r'\s*\([^)]*\)', '', dx).strip() for dx in all_diagnoses_raw
|
384 |
-
# })
|
385 |
-
|
386 |
-
# # Consolidate current medications
|
387 |
-
# all_medications = set()
|
388 |
-
# for enc in result.get('encounters', []):
|
389 |
-
# all_medications.update(enc.get('medications', []))
|
390 |
-
# current_meds = sorted(all_medications)
|
391 |
-
|
392 |
-
# # Report Header
|
393 |
-
# report = "\n==============================================\n"
|
394 |
-
# report += " CLINICAL SUMMARY REPORT\n"
|
395 |
-
# report += "==============================================\n"
|
396 |
-
# report += f"Generated On: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
|
397 |
-
|
398 |
-
# # Patient Overview
|
399 |
-
# report += "\n--- PATIENT OVERVIEW ---\n"
|
400 |
-
# report += f"Name: {result.get('patientname', 'Unknown')}\n"
|
401 |
-
# report += f"Patient ID: {result.get('patientnumber', 'Unknown')}\n"
|
402 |
-
# gender = result.get('gender', 'Unknown')
|
403 |
-
# report += f"Age/Sex: {result.get('agey', 'Unknown')} {gender[0] if gender != 'Unknown' else 'U'}\n"
|
404 |
-
# report += f"Allergies: {', '.join(result.get('allergies', ['None']))}\n"
|
405 |
-
|
406 |
-
# # Social History
|
407 |
-
# report += "\n--- SOCIAL HISTORY ---\n"
|
408 |
-
# report += fill(result.get('social_history', 'Not specified.'), width=80) + "\n"
|
409 |
-
|
410 |
-
# # Immediate Attention
|
411 |
-
# report += "\n--- IMMEDIATE ATTENTION (Most Recent Encounter) ---\n"
|
412 |
-
# report += f"Date of Event: {last_encounter.get('visit_date', 'Unknown')}\n"
|
413 |
-
# report += f"Chief Complaint: {last_encounter.get('chief_complaint', 'Not specified')}\n"
|
414 |
-
# if last_encounter.get('vitals'):
|
415 |
-
# vitals_str = ', '.join([f'{k}: {v}' for k, v in last_encounter['vitals'].items()])
|
416 |
-
# report += f"Vitals: {vitals_str}\n"
|
417 |
-
# critical_diagnoses = [
|
418 |
-
# dx for dx in last_encounter.get('diagnosis', [])
|
419 |
-
# if any(kw in dx.lower() for kw in ['acute', 'new onset', 'fall', 'afib', 'kidney injury'])
|
420 |
-
# ]
|
421 |
-
# if critical_diagnoses:
|
422 |
-
# report += f"Critical New Diagnoses: {', '.join(critical_diagnoses)}\n"
|
423 |
-
# report += f"Doctor's Notes: {last_encounter.get('dr_notes', 'N/A')}\n"
|
424 |
-
|
425 |
-
# # Active Problem List
|
426 |
-
# report += "\n--- ACTIVE PROBLEM LIST (Consolidated) ---\n"
|
427 |
-
# report += "\n".join(f"- {dx}" for dx in cleaned_diagnoses) + "\n"
|
428 |
-
|
429 |
-
# # Current Medications
|
430 |
-
# report += "\n--- CURRENT MEDICATION LIST (Consolidated) ---\n"
|
431 |
-
# report += "\n".join(f"- {med}" for med in current_meds) + "\n"
|
432 |
-
|
433 |
-
# # Procedures
|
434 |
-
# procedures = set()
|
435 |
-
# for enc in result.get('encounters', []):
|
436 |
-
# if 'treatment' in enc and 'PCI' in enc['treatment']:
|
437 |
-
# procedures.add(enc['treatment'])
|
438 |
-
# if procedures:
|
439 |
-
# report += "\n--- PROCEDURES & SURGERIES ---\n"
|
440 |
-
# report += "\n".join(f"- {proc}" for proc in sorted(procedures)) + "\n"
|
441 |
-
|
442 |
-
# # AI-Generated Narrative
|
443 |
-
# report += "\n--- AI-GENERATED CLINICAL NARRATIVE ---\n"
|
444 |
-
# report += fill(raw_summary, width=80) + "\n"
|
445 |
-
|
446 |
-
# # Placeholder sections if not in model output
|
447 |
-
# if "Assessment and Plan" not in raw_summary:
|
448 |
-
# report += "\n--- ASSESSMENT, PLAN AND NEXT STEPS (AI-Generated) ---\n"
|
449 |
-
# report += "The model did not generate a structured assessment and plan. Please review clinical context.\n"
|
450 |
-
|
451 |
-
# if "Clinical Pathway" not in raw_summary:
|
452 |
-
# report += "\n--- CLINICAL PATHWAY (AI-Generated) ---\n"
|
453 |
-
# report += "No clinical pathway was generated. Consider next steps based on active issues.\n"
|
454 |
-
|
455 |
-
# return report
|
456 |
-
|
457 |
-
# def evaluate_summary_against_guidelines(self, summary_text: str, patient_data: dict) -> str:
|
458 |
-
# """
|
459 |
-
# Simulated evaluation of summary against clinical guidelines.
|
460 |
-
# """
|
461 |
-
# result = patient_data["result"]
|
462 |
-
# last_enc = result["encounters"][-1] if result.get("encounters") else {}
|
463 |
-
|
464 |
-
# summary_lower = summary_text.lower()
|
465 |
-
# evaluation = (
|
466 |
-
# "\n==============================================\n"
|
467 |
-
# " AI SUMMARY EVALUATION & GUIDELINE CHECK\n"
|
468 |
-
# "==============================================\n"
|
469 |
-
# )
|
470 |
-
|
471 |
-
# # Keyword-based accuracy
|
472 |
-
# critical_keywords = [
|
473 |
-
# "fall", "dizziness", "atrial fibrillation", "afib", "rvr", "kidney", "ckd",
|
474 |
-
# "diabetes", "anticoagulation", "warfarin", "aspirin", "statin", "metformin",
|
475 |
-
# "gout", "angina", "pci", "bph", "hypertension", "metoprolol", "clopidogrel"
|
476 |
-
# ]
|
477 |
-
# found = [kw for kw in critical_keywords if kw in summary_lower]
|
478 |
-
# score = (len(found) / len(critical_keywords)) * 10
|
479 |
-
# evaluation += f"\n1. KEYWORD ACCURACY SCORE: {score:.1f}/10\n"
|
480 |
-
# evaluation += f" - Found {len(found)} out of {len(critical_keywords)} critical concepts.\n"
|
481 |
-
|
482 |
-
# # Guideline checks
|
483 |
-
# evaluation += "\n2. CLINICAL GUIDELINE COMMENTARY (SIMULATED):\n"
|
484 |
-
|
485 |
-
# has_afib = any("atrial fibrillation" in dx.lower() for dx in last_enc.get('diagnosis', []))
|
486 |
-
# on_anticoag = any("warfarin" in med.lower() or "apixaban" in med.lower() for med in last_enc.get('medications', []))
|
487 |
-
# if has_afib:
|
488 |
-
# evaluation += " - ✅ Patient with Atrial Fibrillation is on anticoagulation.\n" if on_anticoag \
|
489 |
-
# else " - ❌ Atrial Fibrillation present but no anticoagulant prescribed.\n"
|
490 |
-
|
491 |
-
# has_mi = any("myocardial infarction" in hx.lower() for hx in result.get('past_medical_history', []))
|
492 |
-
# on_statin = any("atorvastatin" in med.lower() or "statin" in med.lower() for med in last_enc.get('medications', []))
|
493 |
-
# if has_mi:
|
494 |
-
# evaluation += " - ✅ Patient with MI history is on statin therapy.\n" if on_statin \
|
495 |
-
# else " - ❌ Patient with MI history is not on statin therapy.\n"
|
496 |
-
|
497 |
-
# has_aki = any("acute kidney injury" in dx.lower() for dx in last_enc.get('diagnosis', []))
|
498 |
-
# acei_held = "hold" in last_enc.get('dr_notes', '').lower() and "lisinopril" in last_enc.get('dr_notes', '')
|
499 |
-
# if has_aki:
|
500 |
-
# evaluation += " - ✅ AKI noted and ACE inhibitor was appropriately held.\n" if acei_held \
|
501 |
-
# else " - ⚠️ AKI present but ACE inhibitor not documented as held.\n"
|
502 |
-
|
503 |
-
# evaluation += (
|
504 |
-
# "\nDisclaimer: This is a simulated evaluation and not a substitute for clinical judgment.\n"
|
505 |
-
# )
|
506 |
-
# return evaluation
|
507 |
-
|
508 |
-
|
509 |
-
# import datetime
|
510 |
-
# import torch
|
511 |
-
# import warnings
|
512 |
-
# import re
|
513 |
-
# import json
|
514 |
-
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
515 |
-
# from textwrap import fill
|
516 |
-
|
517 |
-
# # Suppress non-critical warnings
|
518 |
-
# warnings.filterwarnings("ignore", category=UserWarning)
|
519 |
-
|
520 |
-
# class PatientSummarizerAgent:
|
521 |
-
# def __init__(self, model_name: str = "Falconsai/medical_summarization", model_type: str = "seq2seq", device: str = None):
|
522 |
-
# self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
523 |
-
# if self.device == 'cuda':
|
524 |
-
# torch.cuda.empty_cache()
|
525 |
-
# print(f"✅ Using device for tensors: {self.device}")
|
526 |
-
|
527 |
-
# self.model_name = model_name
|
528 |
-
# if model_type != "seq2seq":
|
529 |
-
# raise ValueError(f"Unsupported model_type: {model_type}. Only 'seq2seq' is supported.")
|
530 |
-
# try:
|
531 |
-
# self.tokenizer, self.model = self.load_model(model_name, self.device)
|
532 |
-
# print(f"✅ Model '{model_name}' loaded successfully.")
|
533 |
-
# except RuntimeError as e:
|
534 |
-
# print(f"❌ Failed to load model: {e}")
|
535 |
-
# raise
|
536 |
-
|
537 |
-
# def load_model(self, model_name: str, device: str):
|
538 |
-
# try:
|
539 |
-
# print(f"🔄 Loading model: {model_name}...")
|
540 |
-
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
541 |
-
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
542 |
-
# if device == 'cuda':
|
543 |
-
# model = model.half()
|
544 |
-
# model.to(device)
|
545 |
-
# model.eval()
|
546 |
-
# if tokenizer.pad_token is None:
|
547 |
-
# tokenizer.pad_token = tokenizer.eos_token
|
548 |
-
# return tokenizer, model
|
549 |
-
# except Exception as e:
|
550 |
-
# raise RuntimeError(f"Model loading failed: {str(e)}")
|
551 |
-
|
552 |
-
# def summarize_chunk(self, text: str) -> str:
|
553 |
-
# try:
|
554 |
-
# inputs = self.tokenizer(
|
555 |
-
# text,
|
556 |
-
# return_tensors="pt",
|
557 |
-
# truncation=True,
|
558 |
-
# max_length=1024,
|
559 |
-
# padding=True
|
560 |
-
# ).to(self.device)
|
561 |
-
# outputs = self.model.generate(
|
562 |
-
# **inputs,
|
563 |
-
# max_new_tokens=400,
|
564 |
-
# num_beams=4,
|
565 |
-
# temperature=0.7,
|
566 |
-
# early_stopping=True
|
567 |
-
# )
|
568 |
-
# summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
569 |
-
# return summary
|
570 |
-
# except Exception as e:
|
571 |
-
# return f"Error summarizing chunk: {str(e)}"
|
572 |
-
|
573 |
-
# def parse_vitals(self, vitals_list):
|
574 |
-
# vitals_dict = {"BP": "N/A", "HR": "N/A", "Temp": "N/A", "SpO2": "N/A", "Height": "N/A", "Weight": "N/A", "BMI": "N/A"}
|
575 |
-
# if not isinstance(vitals_list, list):
|
576 |
-
# return vitals_dict
|
577 |
-
# for item in vitals_list:
|
578 |
-
# if not isinstance(item, dict):
|
579 |
-
# continue
|
580 |
-
# name = item.get("name", "").lower()
|
581 |
-
# value = item.get("value", "N/A")
|
582 |
-
# if "bp(sys)" in name:
|
583 |
-
# dia = next((v["value"] for v in vitals_list if "bp(dia)" in v.get("name", "").lower()), "N/A")
|
584 |
-
# vitals_dict["BP"] = f"{value}/{dia}" if value != "N/A" and dia != "N/A" else "N/A"
|
585 |
-
# elif "pulse" in name or "hr" in name:
|
586 |
-
# vitals_dict["HR"] = value
|
587 |
-
# elif "temp" in name:
|
588 |
-
# vitals_dict["Temp"] = value
|
589 |
-
# elif "spo2" in name or "o2 sat" in name:
|
590 |
-
# vitals_dict["SpO2"] = value
|
591 |
-
# elif "height" in name:
|
592 |
-
# vitals_dict["Height"] = value
|
593 |
-
# elif "weight" in name:
|
594 |
-
# vitals_dict["Weight"] = value
|
595 |
-
# elif "bmi" in name:
|
596 |
-
# vitals_dict["BMI"] = value
|
597 |
-
# return vitals_dict
|
598 |
-
|
599 |
-
# def build_chronological_narrative(self, patient_data: dict) -> str:
|
600 |
-
# narrative = []
|
601 |
-
# result = patient_data.get("result", {})
|
602 |
-
# flattened = patient_data.get("flattened", [])
|
603 |
-
|
604 |
-
# # Extract basic patient info
|
605 |
-
# narrative.append(f"Patient ID: {result.get('patientnumber', 'Unknown')}")
|
606 |
-
# narrative.append(f"Age/Sex: {result.get('agey', 'Unknown')} {result.get('gender', 'Unknown')}")
|
607 |
-
# narrative.append(f"Allergies: {', '.join(result.get('allergies', ['None known']))}")
|
608 |
-
# narrative.append(f"Social History: {result.get('social_history', 'Not specified')}")
|
609 |
-
# narrative.append(f"Past Medical History: {', '.join(result.get('past_medical_history', ['None']))}")
|
610 |
-
|
611 |
-
# # Parse Chartsummarydtl from flattened
|
612 |
-
# encounters = []
|
613 |
-
# for item in flattened:
|
614 |
-
# if item.startswith("Chartsummarydtl:"):
|
615 |
-
# try:
|
616 |
-
# chart_data_str = item.split("Chartsummarydtl:")[1].strip()
|
617 |
-
# chart_data = json.loads(chart_data_str)
|
618 |
-
# if isinstance(chart_data, list):
|
619 |
-
# encounters.extend(chart_data)
|
620 |
-
# except (IndexError, json.JSONDecodeError, ValueError) as e:
|
621 |
-
# print(f"Failed to parse Chartsummarydtl: {e}")
|
622 |
-
# continue
|
623 |
-
|
624 |
-
# if not encounters:
|
625 |
-
# narrative.append("No encounter data available.")
|
626 |
-
# return "\n".join(narrative)
|
627 |
-
|
628 |
-
# # Sort encounters by date
|
629 |
-
# encounters = sorted(encounters, key=lambda x: x.get('chartdate', ''), reverse=True)
|
630 |
-
|
631 |
-
# for enc in encounters:
|
632 |
-
# vitals = self.parse_vitals(enc.get('vitals', []))
|
633 |
-
# encounter_str = f"Encounter on {enc.get('chartdate', 'Unknown')}: "
|
634 |
-
# encounter_str += f"Chief Complaint: {result.get('chief_complaint', 'Not specified')}. "
|
635 |
-
# encounter_str += f"Vitals: BP: {vitals['BP']}, HR: {vitals['HR']}, SpO2: {vitals['SpO2']}, Temp: {vitals['Temp']}, Height: {vitals['Height']}, Weight: {vitals['Weight']}, BMI: {vitals['BMI']}. "
|
636 |
-
# encounter_str += f"Diagnosis: {', '.join(enc.get('diagnosis', ['None']))}. "
|
637 |
-
# # Deduplicate medications
|
638 |
-
# medications = list(set(enc.get('medications', ['None'])))
|
639 |
-
# encounter_str += f"Medications: {', '.join(med.strip(' ||') for med in medications)}. "
|
640 |
-
# encounter_str += f"Lab Tests: {', '.join(enc.get('labtests', ['None']))}. "
|
641 |
-
# radiology = [r['name'] for r in enc.get('radiologyorders', [])]
|
642 |
-
# encounter_str += f"Radiology Orders: {', '.join(radiology) if radiology else 'None'}. "
|
643 |
-
# encounter_str += f"Allergies: {', '.join(enc.get('allergies', ['None']))}."
|
644 |
-
# narrative.append(encounter_str)
|
645 |
-
|
646 |
-
# return "\n".join(narrative)
|
647 |
-
|
648 |
-
# def chunk_text(self, text: str, chunk_size: int = 1500) -> list:
|
649 |
-
# words = text.split()
|
650 |
-
# chunks = []
|
651 |
-
# for i in range(0, len(words), chunk_size):
|
652 |
-
# chunk = " ".join(words[i:i + chunk_size])
|
653 |
-
# chunks.append(chunk)
|
654 |
-
# return chunks if chunks else [text]
|
655 |
-
|
656 |
-
# def generate_clinical_summary(self, patient_data: dict) -> str:
|
657 |
-
# print(f"✨ Generating clinical summary using model: {self.model_name}...")
|
658 |
-
# try:
|
659 |
-
# narrative_history = self.build_chronological_narrative(patient_data)
|
660 |
-
# print(f"\n--- Prompt Sent to Model (truncated) ---\n{fill(narrative_history, width=80)[:1000]}...")
|
661 |
-
|
662 |
-
# chunks = self.chunk_text(narrative_history, chunk_size=1500)
|
663 |
-
# chunk_summaries = [self.summarize_chunk(chunk) for chunk in chunks]
|
664 |
-
# raw_summary_text = " ".join(chunk_summaries)
|
665 |
-
# print(f"\n--- Raw Model Output ---\n{fill(raw_summary_text, width=80)}")
|
666 |
-
|
667 |
-
# formatted_report = self.format_clinical_output(raw_summary_text, patient_data)
|
668 |
-
# evaluation_report = self.evaluate_summary_against_guidelines(raw_summary_text, patient_data)
|
669 |
-
|
670 |
-
# final_output = (
|
671 |
-
# f"\n{'='*80}\n"
|
672 |
-
# f" FINAL CLINICAL SUMMARY REPORT\n"
|
673 |
-
# f"{'='*80}\n"
|
674 |
-
# f"{formatted_report}\n"
|
675 |
-
# f"{'='*80}\n"
|
676 |
-
# f" SIMULATED EVALUATION REPORT\n"
|
677 |
-
# f"{'='*80}\n"
|
678 |
-
# f"{evaluation_report}"
|
679 |
-
# )
|
680 |
-
# return final_output
|
681 |
-
# except Exception as e:
|
682 |
-
# print(f"❌ Error during summary generation: {e}")
|
683 |
-
# import traceback
|
684 |
-
# traceback.print_exc()
|
685 |
-
# return f"Error generating summary: {str(e)}"
|
686 |
-
|
687 |
-
# def format_clinical_output(self, raw_summary: str, patient_data: dict) -> str:
|
688 |
-
# result = patient_data.get("result", {})
|
689 |
-
# flattened = patient_data.get("flattened", [])
|
690 |
-
# encounters = []
|
691 |
-
# for item in flattened:
|
692 |
-
# if item.startswith("Chartsummarydtl:"):
|
693 |
-
# try:
|
694 |
-
# chart_data = json.loads(item.split("Chartsummarydtl:")[1].strip())
|
695 |
-
# if isinstance(chart_data, list):
|
696 |
-
# encounters.extend(chart_data)
|
697 |
-
# except (IndexError, json.JSONDecodeError, ValueError) as e:
|
698 |
-
# print(f"Failed to parse Chartsummarydtl: {e}")
|
699 |
-
# continue
|
700 |
-
# last_encounter = sorted(encounters, key=lambda x: x.get('chartdate', ''), reverse=True)[0] if encounters else {}
|
701 |
-
|
702 |
-
# all_diagnoses = set(result.get('past_medical_history', []))
|
703 |
-
# all_medications = set()
|
704 |
-
# for enc in encounters:
|
705 |
-
# all_diagnoses.update(enc.get('diagnosis', []))
|
706 |
-
# all_medications.update(med.strip(' ||') for med in enc.get('medications', []))
|
707 |
-
# cleaned_diagnoses = sorted({re.sub(r'\s*\([^)]*\)', '', dx).strip() for dx in all_diagnoses})
|
708 |
-
# current_meds = sorted(all_medications)
|
709 |
-
|
710 |
-
# report = (
|
711 |
-
# "\n==============================================\n"
|
712 |
-
# " CLINICAL SUMMARY REPORT\n"
|
713 |
-
# "==============================================\n"
|
714 |
-
# f"Generated On: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
|
715 |
-
# )
|
716 |
-
# report += f"Name: {result.get('patientname', 'Unknown')}\n"
|
717 |
-
# report += f"Patient ID: {result.get('patientnumber', 'Unknown')}\n"
|
718 |
-
# gender = result.get('gender', 'Unknown')
|
719 |
-
# report += f"Age/Sex: {result.get('agey', 'Unknown')} {gender[0] if gender != 'Unknown' else 'U'}\n"
|
720 |
-
# report += f"Allergies: {', '.join(result.get('allergies', ['None known']))}\n"
|
721 |
-
|
722 |
-
# report += "\n--- SOCIAL HISTORY ---\n"
|
723 |
-
# report += fill(result.get('social_history', 'Not specified.'), width=80) + "\n"
|
724 |
-
|
725 |
-
# report += "\n--- IMMEDIATE ATTENTION (Most Recent Encounter) ---\n"
|
726 |
-
# report += f"Date of Event: {last_encounter.get('chartdate', 'Unknown')}\n"
|
727 |
-
# report += f"Chief Complaint: {result.get('chief_complaint', 'Not specified')}\n"
|
728 |
-
# if last_encounter.get('vitals'):
|
729 |
-
# vitals = self.parse_vitals(last_encounter['vitals'])
|
730 |
-
# vitals_str = ', '.join([f"{k}: {v}" for k, v in vitals.items()])
|
731 |
-
# report += f"Vitals: {vitals_str}\n"
|
732 |
-
# critical_diagnoses = [
|
733 |
-
# dx for dx in last_encounter.get('diagnosis', [])
|
734 |
-
# if any(kw in dx.lower() for kw in ['acute', 'new onset', 'fall', 'afib', 'kidney injury'])
|
735 |
-
# ]
|
736 |
-
# if critical_diagnoses:
|
737 |
-
# report += f"Critical New Diagnoses: {', '.join(critical_diagnoses)}\n"
|
738 |
-
# report += f"Doctor's Notes: {last_encounter.get('dr_notes', 'N/A')}\n"
|
739 |
-
|
740 |
-
# report += "\n--- ACTIVE PROBLEM LIST (Consolidated) ---\n"
|
741 |
-
# report += "\n".join(f"- {dx}" for dx in cleaned_diagnoses) + "\n" if cleaned_diagnoses else "- None\n"
|
742 |
-
|
743 |
-
# report += "\n--- CURRENT MEDICATION LIST (Consolidated) ---\n"
|
744 |
-
# report += "\n".join(f"- {med}" for med in current_meds) + "\n" if current_meds else "- None\n"
|
745 |
-
|
746 |
-
# report += "\n--- AI-GENERATED CLINICAL NARRATIVE ---\n"
|
747 |
-
# report += fill(raw_summary, width=80) + "\n"
|
748 |
-
|
749 |
-
# report += "\n--- ASSESSMENT, PLAN AND NEXT STEPS (AI-Generated) ---\n"
|
750 |
-
# report += "The model did not generate a structured assessment and plan. Please review clinical context.\n"
|
751 |
-
|
752 |
-
# report += "\n--- CLINICAL PATHWAY (AI-Generated) ---\n"
|
753 |
-
# report += "No clinical pathway was generated. Consider next steps based on active issues.\n"
|
754 |
-
|
755 |
-
# return report
|
756 |
-
|
757 |
-
# def evaluate_summary_against_guidelines(self, summary_text: str, patient_data: dict) -> str:
|
758 |
-
# result = patient_data.get("result", {})
|
759 |
-
# flattened = patient_data.get("flattened", [])
|
760 |
-
# encounters = []
|
761 |
-
# for item in flattened:
|
762 |
-
# if item.startswith("Chartsummarydtl:"):
|
763 |
-
# try:
|
764 |
-
# chart_data = json.loads(item.split("Chartsummarydtl:")[1].strip())
|
765 |
-
# if isinstance(chart_data, list):
|
766 |
-
# encounters.extend(chart_data)
|
767 |
-
# except (IndexError, json.JSONDecodeError, ValueError):
|
768 |
-
# continue
|
769 |
-
# last_enc = sorted(encounters, key=lambda x: x.get('chartdate', ''), reverse=True)[0] if encounters else {}
|
770 |
-
|
771 |
-
# summary_lower = summary_text.lower()
|
772 |
-
# evaluation = (
|
773 |
-
# "\n==============================================\n"
|
774 |
-
# " AI SUMMARY EVALUATION & GUIDELINE CHECK\n"
|
775 |
-
# "==============================================\n"
|
776 |
-
# )
|
777 |
-
|
778 |
-
# critical_keywords = [
|
779 |
-
# "metrogyl", "rantac", "ultrasound", "egg allergy", "blood pressure", "pulse", "spo2",
|
780 |
-
# "bmi", "temperature", "pain", "height", "weight"
|
781 |
-
# ]
|
782 |
-
# found = [kw for kw in critical_keywords if kw in summary_lower]
|
783 |
-
# score = (len(found) / len(critical_keywords)) * 10
|
784 |
-
# evaluation += f"\n1. KEYWORD ACCURACY SCORE: {score:.1f}/10\n"
|
785 |
-
# evaluation += f" - Found {len(found)} out of {len(critical_keywords)} critical concepts.\n"
|
786 |
-
|
787 |
-
# evaluation += "\n2. CLINICAL GUIDELINE COMMENTARY (SIMULATED):\n"
|
788 |
-
# has_allergy = any("egg allergy" in a.lower() for a in last_enc.get('allergies', []))
|
789 |
-
# if has_allergy:
|
790 |
-
# evaluation += " - ✅ Egg allergy noted in the patient record.\n"
|
791 |
-
|
792 |
-
# has_medications = bool(last_enc.get('medications', []))
|
793 |
-
# if has_medications:
|
794 |
-
# medications = list(set(med.strip(' ||') for med in last_enc.get('medications', [])))
|
795 |
-
# evaluation += f" - ✅ Medications prescribed: {', '.join(medications)}.\n"
|
796 |
-
# else:
|
797 |
-
# evaluation += " - ⚠️ No medications prescribed in the latest encounter.\n"
|
798 |
-
|
799 |
-
# has_radiology = bool(last_enc.get('radiologyorders', []))
|
800 |
-
# if has_radiology:
|
801 |
-
# radiology = [r['name'] for r in last_enc['radiologyorders']]
|
802 |
-
# evaluation += f" - ✅ Radiology orders issued: {', '.join(radiology)}.\n"
|
803 |
-
|
804 |
-
# evaluation += "\nDisclaimer: This is a simulated evaluation and not a substitute for clinical judgment.\n"
|
805 |
-
# return evaluation
|
806 |
-
|
807 |
-
|
808 |
-
# import torch
|
809 |
-
# import warnings
|
810 |
-
# import logging
|
811 |
-
# import json
|
812 |
-
# import requests
|
813 |
-
# from typing import List, Dict, Union, Optional
|
814 |
-
# from flask import Flask, request, jsonify
|
815 |
-
# from transformers import (
|
816 |
-
# AutoTokenizer,
|
817 |
-
# AutoModelForSeq2SeqLM,
|
818 |
-
# AutoModelForCausalLM,
|
819 |
-
# AutoConfig
|
820 |
-
# )
|
821 |
-
|
822 |
-
# # -----------------------------
|
823 |
-
# # Setup
|
824 |
-
# # -----------------------------
|
825 |
-
# warnings.filterwarnings("ignore", category=UserWarning)
|
826 |
-
# logging.basicConfig(level=logging.INFO)
|
827 |
-
# logger = logging.getLogger(__name__)
|
828 |
-
|
829 |
-
# # -----------------------------
|
830 |
-
# # Enhanced Patient Summarizer Agent
|
831 |
-
# # -----------------------------
|
832 |
-
# class PatientSummarizerAgent:
|
833 |
-
# def __init__(
|
834 |
-
# self,
|
835 |
-
# model_name: str = "Falconsai/medical_summarization",
|
836 |
-
# model_type: Optional[str] = None,
|
837 |
-
# device: Optional[str] = None,
|
838 |
-
# max_input_tokens: int = 2048,
|
839 |
-
# max_output_tokens: int = 512
|
840 |
-
# ):
|
841 |
-
# self.model_name = model_name
|
842 |
-
# self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
843 |
-
# self.max_input_tokens = max_input_tokens
|
844 |
-
# self.max_output_tokens = max_output_tokens
|
845 |
-
# logger.info(f"Loading model '{model_name}' on {self.device}...")
|
846 |
-
|
847 |
-
# config = AutoConfig.from_pretrained(model_name)
|
848 |
-
# if config.model_type in ["t5", "bart", "mbart", "longt5", "led"]:
|
849 |
-
# self.model_type = "seq2seq"
|
850 |
-
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
|
851 |
-
# else:
|
852 |
-
# self.model_type = "causal"
|
853 |
-
# self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
|
854 |
-
|
855 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
856 |
-
# if self.tokenizer.pad_token is None:
|
857 |
-
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
858 |
-
|
859 |
-
# logger.info("Model loaded successfully.")
|
860 |
-
|
861 |
-
# def _parse_patient_data(self, data: Union[List[str], Dict]) -> Dict:
|
862 |
-
# """Convert flattened list or dict to key-value dict."""
|
863 |
-
# if isinstance(data, dict):
|
864 |
-
# return data
|
865 |
-
# elif isinstance(data, list):
|
866 |
-
# patient_dict = {}
|
867 |
-
# for entry in data:
|
868 |
-
# if ":" in entry:
|
869 |
-
# parts = entry.split(":", 1)
|
870 |
-
# key = parts[0].strip()
|
871 |
-
# value = parts[1].strip() if len(parts) > 1 else "N/A"
|
872 |
-
# patient_dict[key] = value
|
873 |
-
# return patient_dict
|
874 |
-
# else:
|
875 |
-
# raise ValueError("Patient data must be a dict or a list of 'key: value' strings.")
|
876 |
-
|
877 |
-
# def _build_prompt(self, patient_info: Dict) -> str:
|
878 |
-
# """Build a rich, instructive prompt for clinical reasoning."""
|
879 |
-
# patient_details = "\n".join([f"{k}: {v}" for k, v in patient_info.items() if v not in ["N/A", ""]])
|
880 |
-
|
881 |
-
# prompt = (
|
882 |
-
# "You are an AI clinical assistant. Analyze the patient data below and generate a structured, "
|
883 |
-
# "professional summary for use by physicians. Focus on:\n"
|
884 |
-
# "1. Patient Overview (age, gender, key identifiers)\n"
|
885 |
-
# "2. Key Medical History (PMH, allergies, medications)\n"
|
886 |
-
# "3. Vital Sign Trends (BP, HR, weight, SpO2) — highlight changes over time\n"
|
887 |
-
# "4. Assessment (possible conditions based on data)\n"
|
888 |
-
# "5. Recommendations (labs, imaging, referrals, medication review)\n\n"
|
889 |
-
|
890 |
-
# "Rules:\n"
|
891 |
-
# "- Only use information provided. Do not invent details.\n"
|
892 |
-
# "- If a value is increasing (e.g., BP), flag it as a concern.\n"
|
893 |
-
# "- If a medication is repeated across visits, assume chronic use.\n"
|
894 |
-
# "- If a test (e.g., ultrasound) is ordered repeatedly without result, recommend follow-up.\n"
|
895 |
-
# "- Use concise, professional language.\n\n"
|
896 |
-
|
897 |
-
# "--- PATIENT DATA ---\n"
|
898 |
-
# f"{patient_details}\n\n"
|
899 |
-
|
900 |
-
# "Provide the summary in this format:\n"
|
901 |
-
# "Patient Overview:\n"
|
902 |
-
# "Medical History:\n"
|
903 |
-
# "Vital Trends:\n"
|
904 |
-
# "Assessment:\n"
|
905 |
-
# "Recommendations:"
|
906 |
-
# )
|
907 |
-
# return prompt
|
908 |
-
|
909 |
-
# def generate_clinical_summary(self, patient_data: Union[List[str], Dict]) -> str:
|
910 |
-
# """Generate a clinical summary with error handling."""
|
911 |
-
# try:
|
912 |
-
# patient_info = self._parse_patient_data(patient_data)
|
913 |
-
# prompt = self._build_prompt(patient_info)
|
914 |
-
|
915 |
-
# inputs = self.tokenizer(
|
916 |
-
# prompt,
|
917 |
-
# return_tensors="pt",
|
918 |
-
# truncation=True,
|
919 |
-
# max_length=self.max_input_tokens,
|
920 |
-
# padding=True
|
921 |
-
# ).to(self.device)
|
922 |
-
|
923 |
-
# if self.model_type == "seq2seq":
|
924 |
-
# outputs = self.model.generate(
|
925 |
-
# **inputs,
|
926 |
-
# max_new_tokens=self.max_output_tokens,
|
927 |
-
# num_beams=4,
|
928 |
-
# temperature=0.7,
|
929 |
-
# top_p=0.9,
|
930 |
-
# do_sample=True
|
931 |
-
# )
|
932 |
-
# else:
|
933 |
-
# outputs = self.model.generate(
|
934 |
-
# **inputs,
|
935 |
-
# max_new_tokens=self.max_output_tokens,
|
936 |
-
# temperature=0.7,
|
937 |
-
# top_p=0.9,
|
938 |
-
# do_sample=True,
|
939 |
-
# pad_token_id=self.tokenizer.eos_token_id
|
940 |
-
# )
|
941 |
-
|
942 |
-
# summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
943 |
-
# return summary.strip()
|
944 |
-
|
945 |
-
# except Exception as e:
|
946 |
-
# logger.error(f"Error during summary generation: {str(e)}")
|
947 |
-
# return "Error: Failed to generate clinical summary due to model processing error."
|
948 |
-
|
949 |
-
# # agent.py partially working
|
950 |
-
# import torch
|
951 |
-
# import warnings
|
952 |
-
# import logging
|
953 |
-
# from typing import List, Dict, Union, Optional
|
954 |
-
# from transformers import (
|
955 |
-
# AutoTokenizer,
|
956 |
-
# AutoModelForSeq2SeqLM,
|
957 |
-
# AutoModelForCausalLM,
|
958 |
-
# AutoConfig
|
959 |
-
# )
|
960 |
-
|
961 |
-
# warnings.filterwarnings("ignore", category=UserWarning)
|
962 |
-
# logging.basicConfig(level=logging.INFO)
|
963 |
-
# logger = logging.getLogger(__name__)
|
964 |
-
|
965 |
-
# class PatientSummarizerAgent:
|
966 |
-
# def __init__(
|
967 |
-
# self,
|
968 |
-
# model_name: str = "Falconsai/medical_summarization",
|
969 |
-
# device: Optional[str] = None,
|
970 |
-
# max_input_tokens: int = 2048,
|
971 |
-
# max_output_tokens: int = 512
|
972 |
-
# ):
|
973 |
-
# self.model_name = model_name
|
974 |
-
# self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
975 |
-
# self.max_input_tokens = max_input_tokens
|
976 |
-
# self.max_output_tokens = max_output_tokens
|
977 |
-
|
978 |
-
# logger.info(f"Loading model '{model_name}' on {self.device}...")
|
979 |
-
|
980 |
-
# try:
|
981 |
-
# config = AutoConfig.from_pretrained(model_name)
|
982 |
-
# if config.model_type in ["t5", "bart", "mbart", "longt5", "led"]:
|
983 |
-
# self.model_type = "seq2seq"
|
984 |
-
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
|
985 |
-
# else:
|
986 |
-
# self.model_type = "causal"
|
987 |
-
# self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
|
988 |
-
|
989 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
990 |
-
# if self.tokenizer.pad_token is None:
|
991 |
-
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
992 |
-
# if self.tokenizer.sep_token is None:
|
993 |
-
# self.tokenizer.sep_token = self.tokenizer.eos_token
|
994 |
-
|
995 |
-
# logger.info(f"Model '{model_name}' loaded successfully as {config.model_type}.")
|
996 |
-
|
997 |
-
# except Exception as e:
|
998 |
-
# logger.critical(f"Model loading failed: {str(e)}", exc_info=True)
|
999 |
-
# raise RuntimeError(f"Model loading failed: {str(e)}")
|
1000 |
-
|
1001 |
-
# def _parse_patient_data(self, data: Union[List[str], Dict]) -> Dict:
|
1002 |
-
# """Safely parse flattened list into dict without overwriting or nesting issues."""
|
1003 |
-
# if isinstance(data, dict):
|
1004 |
-
# return data
|
1005 |
-
# elif isinstance(data, list):
|
1006 |
-
# patient_dict = {}
|
1007 |
-
# for entry in data:
|
1008 |
-
# if not isinstance(entry, str) or ":" not in entry:
|
1009 |
-
# continue
|
1010 |
-
# key, *value_parts = entry.split(":", 1)
|
1011 |
-
# value = value_parts[0].strip() if value_parts else "N/A"
|
1012 |
-
# key = key.strip()
|
1013 |
-
|
1014 |
-
# # Skip if value is a dict repr (like "{...}")
|
1015 |
-
# if value.startswith("{") and value.endswith("}"):
|
1016 |
-
# continue
|
1017 |
-
|
1018 |
-
# if key in patient_dict:
|
1019 |
-
# if isinstance(patient_dict[key], list):
|
1020 |
-
# patient_dict[key].append(value)
|
1021 |
-
# else:
|
1022 |
-
# patient_dict[key] = [patient_dict[key], value]
|
1023 |
-
# else:
|
1024 |
-
# patient_dict[key] = value
|
1025 |
-
|
1026 |
-
# # Deduplicate and clean
|
1027 |
-
# cleaned = {}
|
1028 |
-
# for k, v in patient_dict.items():
|
1029 |
-
# if isinstance(v, list):
|
1030 |
-
# unique_vals = list({x for x in v if x not in ["N/A", "Unknown", ""]})
|
1031 |
-
# cleaned[k] = ", ".join(unique_vals) if unique_vals else "N/A"
|
1032 |
-
# else:
|
1033 |
-
# cleaned[k] = v if v not in ["", "Unknown", "N/A"] else "N/A"
|
1034 |
-
# return cleaned
|
1035 |
-
# else:
|
1036 |
-
# raise ValueError("Unsupported data format")
|
1037 |
-
|
1038 |
-
# def _build_prompt(self, patient_info: Dict) -> str:
|
1039 |
-
# """Build a dynamic, instructive prompt for clinical reasoning."""
|
1040 |
-
# non_na_items = [
|
1041 |
-
# f"{k}: {v}" for k, v in patient_info.items()
|
1042 |
-
# if v not in ["N/A", "Unknown", "None known", "Stable", "Not specified", "", "None"]
|
1043 |
-
# and isinstance(v, str)
|
1044 |
-
# and len(v.strip()) > 1
|
1045 |
-
# ]
|
1046 |
-
# patient_details = "\n".join(non_na_items)
|
1047 |
-
|
1048 |
-
# prompt = (
|
1049 |
-
# "You are an expert AI clinical assistant. Analyze the following patient data and generate a structured, "
|
1050 |
-
# "concise, and actionable summary for physicians. Use only the provided information.\n\n"
|
1051 |
-
# "Include:\n"
|
1052 |
-
# "1. Patient Overview (age, gender, ID)\n"
|
1053 |
-
# "2. Medical History (allergies, medications, diagnosis)\n"
|
1054 |
-
# "3. Vital Trends (BP, HR, SpO2, weight) — highlight changes over last 3 visits\n"
|
1055 |
-
# "4. Test Trends (labs, imaging) — flag repeated orders without results\n"
|
1056 |
-
# "5. Assessment (possible conditions)\n"
|
1057 |
-
# "6. Recommendations (labs, imaging, referrals, med review)\n\n"
|
1058 |
-
# "Rules:\n"
|
1059 |
-
# "- Do not invent any information.\n"
|
1060 |
-
# "- If BP is rising (e.g., 132/85 → 135/95), flag it.\n"
|
1061 |
-
# "- If a medication appears in ≥2 visits, assume chronic use.\n"
|
1062 |
-
# "- If a test is repeated without result, recommend follow-up.\n"
|
1063 |
-
# "- Use professional, concise language.\n\n"
|
1064 |
-
# "--- PATIENT DATA ---\n"
|
1065 |
-
# f"{patient_details}\n\n"
|
1066 |
-
# "Provide the summary in this format:\n"
|
1067 |
-
# "Patient Overview:\n"
|
1068 |
-
# "Medical History:\n"
|
1069 |
-
# "Vital Trends:\n"
|
1070 |
-
# "Test Trends:\n"
|
1071 |
-
# "Assessment:\n"
|
1072 |
-
# "Recommendations:"
|
1073 |
-
# )
|
1074 |
-
# return prompt
|
1075 |
-
|
1076 |
-
# def generate_clinical_summary(self, patient_data: Union[List[str], Dict]) -> str:
|
1077 |
-
# """Generate a clinical summary with full error resilience."""
|
1078 |
-
# try:
|
1079 |
-
# patient_info = self._parse_patient_data(patient_data)
|
1080 |
-
# prompt = self._build_prompt(patient_info)
|
1081 |
-
|
1082 |
-
# inputs = self.tokenizer(
|
1083 |
-
# prompt,
|
1084 |
-
# return_tensors="pt",
|
1085 |
-
# truncation=True,
|
1086 |
-
# max_length=self.max_input_tokens,
|
1087 |
-
# padding=True
|
1088 |
-
# ).to(self.device)
|
1089 |
-
|
1090 |
-
# if self.model_type == "seq2seq":
|
1091 |
-
# outputs = self.model.generate(
|
1092 |
-
# **inputs,
|
1093 |
-
# max_new_tokens=self.max_output_tokens,
|
1094 |
-
# num_beams=4,
|
1095 |
-
# temperature=0.7,
|
1096 |
-
# top_p=0.9,
|
1097 |
-
# do_sample=True
|
1098 |
-
# )
|
1099 |
-
# else:
|
1100 |
-
# outputs = self.model.generate(
|
1101 |
-
# **inputs,
|
1102 |
-
# max_new_tokens=self.max_output_tokens,
|
1103 |
-
# temperature=0.7,
|
1104 |
-
# top_p=0.9,
|
1105 |
-
# do_sample=True,
|
1106 |
-
# pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
|
1107 |
-
# eos_token_id=self.tokenizer.eos_token_id
|
1108 |
-
# )
|
1109 |
-
|
1110 |
-
# summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
1111 |
-
# return summary.strip()
|
1112 |
-
|
1113 |
-
# except Exception as e:
|
1114 |
-
# logger.error(f"Summary generation failed: {str(e)}", exc_info=True)
|
1115 |
-
# return (
|
1116 |
-
# "Error: Failed to generate clinical summary. "
|
1117 |
-
# "Please check model and input data."
|
1118 |
-
# )
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
# # agent.py working partially
|
1123 |
-
# import torch
|
1124 |
-
# import warnings
|
1125 |
-
# import logging
|
1126 |
-
# from typing import List, Dict, Union
|
1127 |
-
# from transformers import (
|
1128 |
-
# AutoTokenizer,
|
1129 |
-
# AutoModelForSeq2SeqLM,
|
1130 |
-
# AutoModelForCausalLM,
|
1131 |
-
# AutoConfig
|
1132 |
-
# )
|
1133 |
-
|
1134 |
-
# warnings.filterwarnings("ignore", category=UserWarning)
|
1135 |
-
# logging.basicConfig(level=logging.INFO)
|
1136 |
-
# logger = logging.getLogger(__name__)
|
1137 |
-
|
1138 |
-
# class PatientSummarizerAgent:
|
1139 |
-
# def __init__(
|
1140 |
-
# self,
|
1141 |
-
# model_name: str = "Falconsai/medical_summarization",
|
1142 |
-
# device: str = None,
|
1143 |
-
# max_input_tokens: int = 2048,
|
1144 |
-
# max_output_tokens: int = 512
|
1145 |
-
# ):
|
1146 |
-
# self.model_name = model_name
|
1147 |
-
# self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
1148 |
-
# self.max_input_tokens = max_input_tokens
|
1149 |
-
# self.max_output_tokens = max_output_tokens
|
1150 |
-
|
1151 |
-
# logger.info(f"Loading model '{model_name}' on {self.device}...")
|
1152 |
-
|
1153 |
-
# try:
|
1154 |
-
# config = AutoConfig.from_pretrained(model_name)
|
1155 |
-
# if config.model_type in ["t5", "bart", "mbart"]:
|
1156 |
-
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
|
1157 |
-
# self.model_type = "seq2seq"
|
1158 |
-
# else:
|
1159 |
-
# self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
|
1160 |
-
# self.model_type = "causal"
|
1161 |
-
|
1162 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
1163 |
-
# if self.tokenizer.pad_token is None:
|
1164 |
-
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
1165 |
-
|
1166 |
-
# logger.info(f"Model '{model_name}' loaded successfully.")
|
1167 |
-
# except Exception as e:
|
1168 |
-
# logger.critical(f"Model loading failed: {str(e)}", exc_info=True)
|
1169 |
-
# raise RuntimeError(f"Model loading failed: {str(e)}")
|
1170 |
-
|
1171 |
-
# def generate_clinical_summary(self, patient_data: Union[List[str], str]) -> str:
|
1172 |
-
# """
|
1173 |
-
# Generate clinical summary directly from flattened list.
|
1174 |
-
# No parsing back to dict — just join into clean text.
|
1175 |
-
# """
|
1176 |
-
# try:
|
1177 |
-
# # Convert to single string
|
1178 |
-
# if isinstance(patient_data, list):
|
1179 |
-
# # Join all lines into one clean string
|
1180 |
-
# patient_text = "\n".join(
|
1181 |
-
# line.strip() for line in patient_data if line.strip()
|
1182 |
-
# )
|
1183 |
-
# elif isinstance(patient_data, str):
|
1184 |
-
# patient_text = patient_data
|
1185 |
-
# else:
|
1186 |
-
# return "Error: Invalid input type."
|
1187 |
-
|
1188 |
-
# # Build prompt
|
1189 |
-
# prompt = f"""
|
1190 |
-
# You are an expert AI clinical assistant. Analyze the patient data below and generate a structured,
|
1191 |
-
# concise, and actionable summary for physicians. Use only the provided information.
|
1192 |
-
|
1193 |
-
# Include:
|
1194 |
-
# 1. Patient Overview (age, gender, ID)
|
1195 |
-
# 2. Medical History (allergies, medications, diagnosis)
|
1196 |
-
# 3. Vital Trends (BP, HR, SpO2, weight) — highlight changes over last 3 visits
|
1197 |
-
# 4. Test Trends (labs, imaging) — flag repeated orders without results
|
1198 |
-
# 5. Assessment (possible conditions)
|
1199 |
-
# 6. Recommendations (labs, imaging, referrals, medication review)
|
1200 |
-
|
1201 |
-
# Rules:
|
1202 |
-
# - Do not invent any information.
|
1203 |
-
# - If BP is rising (e.g., 132/85 → 135/95), flag it.
|
1204 |
-
# - If a medication appears in multiple visits, assume chronic use.
|
1205 |
-
# - If a test is repeated, recommend follow-up.
|
1206 |
-
# - Use professional, concise language.
|
1207 |
-
|
1208 |
-
# --- PATIENT DATA ---
|
1209 |
-
# {patient_text}
|
1210 |
-
|
1211 |
-
# --- SUMMARY ---
|
1212 |
-
# Patient Overview:
|
1213 |
-
# Medical History:
|
1214 |
-
# Vital Trends:
|
1215 |
-
# Test Trends:
|
1216 |
-
# Assessment:
|
1217 |
-
# Recommendations:""".strip()
|
1218 |
-
|
1219 |
-
# # Tokenize
|
1220 |
-
# inputs = self.tokenizer(
|
1221 |
-
# prompt,
|
1222 |
-
# return_tensors="pt",
|
1223 |
-
# truncation=True,
|
1224 |
-
# max_length=self.max_input_tokens,
|
1225 |
-
# padding=True
|
1226 |
-
# ).to(self.device)
|
1227 |
-
|
1228 |
-
# # Generate
|
1229 |
-
# if self.model_type == "seq2seq":
|
1230 |
-
# outputs = self.model.generate(
|
1231 |
-
# **inputs,
|
1232 |
-
# max_new_tokens=self.max_output_tokens,
|
1233 |
-
# num_beams=4,
|
1234 |
-
# temperature=0.7,
|
1235 |
-
# top_p=0.9,
|
1236 |
-
# do_sample=True
|
1237 |
-
# )
|
1238 |
-
# else:
|
1239 |
-
# outputs = self.model.generate(
|
1240 |
-
# **inputs,
|
1241 |
-
# max_new_tokens=self.max_output_tokens,
|
1242 |
-
# temperature=0.7,
|
1243 |
-
# top_p=0.9,
|
1244 |
-
# do_sample=True,
|
1245 |
-
# pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
|
1246 |
-
# eos_token_id=self.tokenizer.eos_token_id
|
1247 |
-
# )
|
1248 |
-
|
1249 |
-
# # Decode
|
1250 |
-
# summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
1251 |
-
|
1252 |
-
# # Extract only the part after "--- SUMMARY ---"
|
1253 |
-
# if "--- SUMMARY ---" in summary:
|
1254 |
-
# summary = summary.split("--- SUMMARY ---")[-1].strip()
|
1255 |
-
|
1256 |
-
# return summary
|
1257 |
-
|
1258 |
-
# except Exception as e:
|
1259 |
-
# logger.error(f"Summary generation failed: {str(e)}", exc_info=True)
|
1260 |
-
# return "Error: Failed to generate clinical summary."
|
1261 |
-
|
1262 |
-
|
1263 |
-
# agent.py
|
1264 |
import torch
|
1265 |
-
import
|
1266 |
-
|
1267 |
-
|
1268 |
-
from
|
|
|
1269 |
|
1270 |
-
|
1271 |
-
|
1272 |
|
1273 |
class PatientSummarizerAgent:
|
1274 |
def __init__(
|
1275 |
self,
|
1276 |
-
model_name: str = "
|
1277 |
-
|
|
|
|
|
|
|
1278 |
):
|
|
|
|
|
1279 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
1280 |
-
|
1281 |
-
|
1282 |
-
|
1283 |
-
|
|
|
|
|
|
|
|
|
1284 |
|
|
|
|
|
1285 |
try:
|
1286 |
-
|
1287 |
-
|
1288 |
-
if
|
1289 |
-
|
1290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1291 |
else:
|
1292 |
-
|
1293 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
1294 |
except Exception as e:
|
1295 |
-
|
1296 |
-
|
1297 |
-
self.
|
1298 |
-
|
1299 |
-
|
1300 |
-
|
1301 |
-
|
1302 |
-
|
1303 |
-
self.
|
1304 |
-
self.model_type =
|
1305 |
-
|
1306 |
-
|
1307 |
-
self
|
1308 |
-
|
1309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1310 |
|
1311 |
def generate_clinical_summary(self, patient_data: Union[List[str], Dict]) -> str:
|
|
|
|
|
|
|
1312 |
try:
|
1313 |
-
#
|
1314 |
-
|
1315 |
-
|
1316 |
-
|
1317 |
-
|
1318 |
-
|
1319 |
-
|
1320 |
-
|
1321 |
-
|
1322 |
-
|
1323 |
-
|
1324 |
-
|
1325 |
-
# Build rich prompt with timeline and analysis. Explicitly instruct to return only sections.
|
1326 |
-
prompt = f"""
|
1327 |
-
You are an expert AI clinical assistant. Analyze the patient's complete visit history and generate a structured, actionable clinical summary history. Use only the provided data. Do not echo any instructions. Return only the sections requested.
|
1328 |
-
|
1329 |
-
--- PATIENT TIMELINE (Narrative across visits) ---
|
1330 |
-
{timeline}
|
1331 |
-
|
1332 |
-
--- CLINICAL INSIGHTS (Computed trends) ---
|
1333 |
-
- Total visits: {insights.get('total_visits', 0)}
|
1334 |
-
- Blood Pressure Trend: {insights.get('bp_trend', 'No data')}
|
1335 |
-
- Weight Trend: {insights.get('weight_trend', 'No data')}
|
1336 |
-
- Chronic Medications: {', '.join(insights.get('chronic_meds', [])) or 'None'}
|
1337 |
-
- Repeated Imaging: {', '.join(insights.get('repeated_imaging', [])) or 'None'}
|
1338 |
-
|
1339 |
-
Provide the clinical summary using exactly these headings, in this order, with concise content under each:
|
1340 |
-
Patient Overview:
|
1341 |
-
Visit History:
|
1342 |
-
Trend Analysis:
|
1343 |
-
Assessment:
|
1344 |
-
Recommendations:
|
1345 |
-
"""
|
1346 |
-
|
1347 |
-
# Tokenize
|
1348 |
-
inputs = self.tokenizer(
|
1349 |
-
prompt,
|
1350 |
-
return_tensors="pt",
|
1351 |
-
truncation=True,
|
1352 |
-
max_length=2048,
|
1353 |
-
padding=True
|
1354 |
-
).to(self.device)
|
1355 |
-
|
1356 |
-
# Generate with model-type aware settings
|
1357 |
-
if self.model_type == "seq2seq":
|
1358 |
-
outputs = self.model.generate(
|
1359 |
-
**inputs,
|
1360 |
-
max_new_tokens=512,
|
1361 |
-
num_beams=4,
|
1362 |
-
temperature=0.7,
|
1363 |
-
top_p=0.9,
|
1364 |
-
do_sample=True,
|
1365 |
-
pad_token_id=self.tokenizer.pad_token_id
|
1366 |
)
|
1367 |
else:
|
1368 |
-
|
1369 |
-
|
1370 |
-
|
|
|
1371 |
temperature=0.7,
|
1372 |
-
top_p=0.9
|
1373 |
-
do_sample=True,
|
1374 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
1375 |
-
eos_token_id=self.tokenizer.eos_token_id
|
1376 |
)
|
1377 |
|
1378 |
-
|
1379 |
-
|
1380 |
-
|
1381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1382 |
|
1383 |
except Exception as e:
|
1384 |
-
|
1385 |
-
|
1386 |
-
|
1387 |
-
|
1388 |
-
|
1389 |
-
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
|
1394 |
-
|
1395 |
-
|
1396 |
-
|
1397 |
-
#
|
1398 |
-
|
1399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1400 |
]
|
1401 |
-
|
1402 |
-
|
1403 |
-
|
1404 |
-
|
1405 |
-
|
1406 |
-
|
1407 |
-
|
1408 |
-
|
1409 |
-
|
1410 |
-
|
1411 |
-
|
1412 |
-
|
1413 |
-
|
1414 |
-
|
1415 |
-
|
1416 |
-
|
1417 |
-
|
1418 |
-
|
1419 |
-
|
1420 |
-
|
1421 |
-
|
1422 |
-
|
1423 |
-
|
1424 |
-
|
1425 |
-
|
1426 |
-
|
1427 |
-
|
1428 |
-
|
1429 |
-
|
1430 |
-
|
1431 |
-
|
1432 |
-
|
1433 |
-
|
1434 |
-
|
1435 |
-
|
1436 |
-
|
1437 |
-
|
1438 |
-
|
1439 |
-
|
1440 |
-
|
1441 |
-
|
1442 |
-
|
1443 |
-
|
1444 |
-
|
1445 |
-
|
1446 |
-
|
1447 |
-
|
1448 |
-
|
1449 |
-
|
1450 |
-
|
1451 |
-
|
1452 |
-
|
1453 |
-
|
1454 |
-
name = data.get("Patient Name", "Anonymous")
|
1455 |
-
num = data.get("Patient Number", "Unknown")
|
1456 |
-
age = data.get("Age", "Unknown")
|
1457 |
-
gender = data.get("Gender", "Unknown")
|
1458 |
-
dob = data.get("DOB", "N/A")
|
1459 |
-
last = data.get("Last Visit", "N/A")
|
1460 |
-
overview = f"Name: {name}\nPatient ID: {num}\nAge/Sex: {age} / {gender}\nDOB: {dob}\nLast Visit: {last}"
|
1461 |
-
|
1462 |
-
insights = data.get("Insights", {})
|
1463 |
-
# Build a structured visit history from normalized encounters if available
|
1464 |
-
visit_lines: List[str] = []
|
1465 |
-
charts = data.get("chartsummarydtl") or []
|
1466 |
-
if isinstance(charts, list) and charts:
|
1467 |
-
# sort by date ascending
|
1468 |
-
charts_sorted = sorted(charts, key=lambda x: (x.get("chartdate") or x.get("date") or ""))
|
1469 |
-
for ch in charts_sorted:
|
1470 |
-
date = (ch.get("chartdate") or ch.get("date") or "Unknown")[:10]
|
1471 |
-
vitals_str = parse_vitals(ch.get("vitals", []))
|
1472 |
-
diag = ", ".join(ch.get("diagnosis", [])) if isinstance(ch.get("diagnosis", []), list) else ""
|
1473 |
-
meds_list = ch.get("medications", [])
|
1474 |
-
meds_list = meds_list if isinstance(meds_list, list) else []
|
1475 |
-
meds = ", ".join(sorted({m.split("||")[0].strip() if isinstance(m, str) else str(m) for m in meds_list if str(m).strip()}))
|
1476 |
-
labs_list = ch.get("labtests", [])
|
1477 |
-
labs = ", ".join([t.get("name", str(t)) if isinstance(t, dict) else str(t) for t in labs_list if str(t).strip()])
|
1478 |
-
radio_list = ch.get("radiologyorders", [])
|
1479 |
-
radio = ", ".join([r.get("name", str(r)) if isinstance(r, dict) else str(r) for r in radio_list if str(r).strip()])
|
1480 |
-
entry = f"{date}: Vitals: {vitals_str}."
|
1481 |
-
if diag:
|
1482 |
-
entry += f" Diagnosis: {diag}."
|
1483 |
-
if meds:
|
1484 |
-
entry += f" Medications: {meds}."
|
1485 |
-
if labs:
|
1486 |
-
entry += f" Labs: {labs}."
|
1487 |
-
if radio:
|
1488 |
-
entry += f" Imaging: {radio}."
|
1489 |
-
visit_lines.append(entry)
|
1490 |
-
else:
|
1491 |
-
# Fallback to narrative timeline text
|
1492 |
-
timeline = data.get("Timeline", "No visit data available.")
|
1493 |
-
visit_lines = [timeline]
|
1494 |
-
|
1495 |
-
trend_lines: List[str] = [
|
1496 |
-
f"BP Trend: {insights.get('bp_trend', 'No data')}",
|
1497 |
-
f"Weight Trend: {insights.get('weight_trend', 'No data')}",
|
1498 |
]
|
1499 |
-
if
|
1500 |
-
|
1501 |
-
|
1502 |
-
|
1503 |
-
|
1504 |
-
|
1505 |
-
|
1506 |
-
|
1507 |
-
|
1508 |
-
|
1509 |
-
|
1510 |
-
|
1511 |
-
|
1512 |
-
|
1513 |
-
|
1514 |
-
|
1515 |
-
if
|
1516 |
-
|
1517 |
-
|
1518 |
-
|
1519 |
-
|
1520 |
-
|
1521 |
-
|
1522 |
-
|
1523 |
-
|
1524 |
-
|
1525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1526 |
return {
|
1527 |
-
"
|
1528 |
-
"
|
1529 |
-
"
|
1530 |
-
"
|
1531 |
-
|
1532 |
-
}
|
1533 |
-
|
1534 |
-
def _extract_from_timeline(self, timeline: str):
|
1535 |
-
import re
|
1536 |
-
diags = set()
|
1537 |
-
meds = set()
|
1538 |
-
algs = set()
|
1539 |
-
imaging = set()
|
1540 |
-
if not isinstance(timeline, str) or not timeline:
|
1541 |
-
return diags, meds, algs, imaging
|
1542 |
-
# capture multiple occurrences lazily until period
|
1543 |
-
for m in re.finditer(r"Diagnosis:\s*([^\.]+)\.", timeline, flags=re.IGNORECASE):
|
1544 |
-
part = m.group(1).strip()
|
1545 |
-
for item in [x.strip() for x in part.split(",") if x.strip()]:
|
1546 |
-
diags.add(item)
|
1547 |
-
for m in re.finditer(r"Medications prescribed:\s*([^\.]+)\.", timeline, flags=re.IGNORECASE):
|
1548 |
-
part = m.group(1).strip()
|
1549 |
-
for item in [x.strip() for x in part.split(",") if x.strip()]:
|
1550 |
-
meds.add(item)
|
1551 |
-
for m in re.finditer(r"Allergies noted:\s*([^\.]+)\.", timeline, flags=re.IGNORECASE):
|
1552 |
-
part = m.group(1).strip()
|
1553 |
-
for item in [x.strip() for x in part.split(",") if x.strip()]:
|
1554 |
-
algs.add(item)
|
1555 |
-
for m in re.finditer(r"Imaging ordered:\s*([^\.]+)\.", timeline, flags=re.IGNORECASE):
|
1556 |
-
part = m.group(1).strip()
|
1557 |
-
for item in [x.strip() for x in part.split(",") if x.strip()]:
|
1558 |
-
imaging.add(item)
|
1559 |
-
return diags, meds, algs, imaging
|
|
|
1 |
+
import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
+
import warnings
|
4 |
+
import re
|
5 |
+
import json
|
6 |
+
from typing import List, Dict, Union, Optional
|
7 |
+
from textwrap import fill
|
8 |
|
9 |
+
# Suppress non-critical warnings
|
10 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
11 |
|
12 |
class PatientSummarizerAgent:
|
13 |
def __init__(
|
14 |
self,
|
15 |
+
model_name: str = "falconsai/medical_summarization",
|
16 |
+
model_type: str = "summarization",
|
17 |
+
device: Optional[str] = None,
|
18 |
+
max_input_tokens: int = 2048,
|
19 |
+
max_output_tokens: int = 512
|
20 |
):
|
21 |
+
self.model_name = model_name
|
22 |
+
self.model_type = model_type
|
23 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
+
self.max_input_tokens = max_input_tokens
|
25 |
+
self.max_output_tokens = max_output_tokens
|
26 |
+
|
27 |
+
# Initialize model loader through unified model manager
|
28 |
+
self.model_loader = None
|
29 |
+
self._initialize_model_loader()
|
30 |
+
|
31 |
+
print(f"✅ PatientSummarizerAgent initialized with {model_name} ({model_type}) on {self.device}")
|
32 |
|
33 |
+
def _initialize_model_loader(self):
|
34 |
+
"""Initialize the model loader using the unified model manager"""
|
35 |
try:
|
36 |
+
from ..utils.model_manager import model_manager
|
37 |
+
|
38 |
+
# Determine if this is a GGUF model
|
39 |
+
if self.model_type == "gguf" or self.model_name.endswith('.gguf'):
|
40 |
+
# Extract filename if model_name contains path
|
41 |
+
if '/' in self.model_name and not self.model_name.startswith('http'):
|
42 |
+
if self.model_name.endswith('.gguf'):
|
43 |
+
# Full path to .gguf file
|
44 |
+
filename = None
|
45 |
+
else:
|
46 |
+
# HuggingFace repo with filename
|
47 |
+
parts = self.model_name.split('/')
|
48 |
+
if len(parts) >= 2:
|
49 |
+
filename = parts[-1] if parts[-1].endswith('.gguf') else None
|
50 |
+
model_name = '/'.join(parts[:-1]) if filename else self.model_name
|
51 |
+
else:
|
52 |
+
filename = None
|
53 |
+
model_name = self.model_name
|
54 |
+
else:
|
55 |
+
filename = None
|
56 |
+
model_name = self.model_name
|
57 |
+
|
58 |
+
self.model_loader = model_manager.get_model_loader(
|
59 |
+
model_name,
|
60 |
+
"gguf",
|
61 |
+
filename=filename
|
62 |
+
)
|
63 |
else:
|
64 |
+
# Use the specified model type
|
65 |
+
self.model_loader = model_manager.get_model_loader(
|
66 |
+
self.model_name,
|
67 |
+
self.model_type
|
68 |
+
)
|
69 |
+
|
70 |
+
print(f"✅ Model loader initialized: {self.model_name} ({self.model_type})")
|
71 |
+
|
72 |
except Exception as e:
|
73 |
+
print(f"❌ Failed to initialize model loader: {e}")
|
74 |
+
# Create a fallback loader
|
75 |
+
self._create_fallback_loader()
|
76 |
+
|
77 |
+
def _create_fallback_loader(self):
|
78 |
+
"""Create a fallback text-based loader when model loading fails"""
|
79 |
+
class FallbackLoader:
|
80 |
+
def __init__(self, model_name: str, model_type: str):
|
81 |
+
self.model_name = model_name
|
82 |
+
self.model_type = model_type
|
83 |
+
self.name = "fallback_text"
|
84 |
+
|
85 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
86 |
+
# Simple template-based response
|
87 |
+
sections = [
|
88 |
+
"## Clinical Assessment\nBased on the provided information, this appears to be a medical case requiring clinical review.",
|
89 |
+
"## Key Trends & Changes\nPlease review the patient data for any significant changes or trends.",
|
90 |
+
"## Plan & Suggested Actions\nConsider consulting with a healthcare provider for proper medical assessment.",
|
91 |
+
"## Direct Guidance for Physician\nThis summary was generated using a fallback method. Please review all patient data thoroughly."
|
92 |
+
]
|
93 |
+
return "\n\n".join(sections)
|
94 |
+
|
95 |
+
def generate_full_summary(self, prompt: str, **kwargs) -> str:
|
96 |
+
return self.generate(prompt, **kwargs)
|
97 |
+
|
98 |
+
self.model_loader = FallbackLoader(self.model_name, self.model_type)
|
99 |
+
print(f"⚠️ Using fallback loader for {self.model_name}")
|
100 |
|
101 |
def generate_clinical_summary(self, patient_data: Union[List[str], Dict]) -> str:
|
102 |
+
"""Generate a comprehensive clinical summary using the unified model manager"""
|
103 |
+
print(f"✨ Generating clinical summary using model: {self.model_name} ({self.model_type})...")
|
104 |
+
|
105 |
try:
|
106 |
+
# Build the narrative prompt
|
107 |
+
narrative_history = self.build_chronological_narrative(patient_data)
|
108 |
+
print(f"\n--- Prompt Sent to Model (truncated) ---\n{fill(narrative_history, width=80)[:1000]}...")
|
109 |
+
|
110 |
+
# Generate summary using the model loader
|
111 |
+
if hasattr(self.model_loader, 'generate_full_summary'):
|
112 |
+
# GGUF models support full summary generation
|
113 |
+
raw_summary_text = self.model_loader.generate_full_summary(
|
114 |
+
narrative_history,
|
115 |
+
max_tokens=self.max_output_tokens,
|
116 |
+
max_loops=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
)
|
118 |
else:
|
119 |
+
# Other models use standard generation
|
120 |
+
raw_summary_text = self.model_loader.generate(
|
121 |
+
narrative_history,
|
122 |
+
max_new_tokens=self.max_output_tokens,
|
123 |
temperature=0.7,
|
124 |
+
top_p=0.9
|
|
|
|
|
|
|
125 |
)
|
126 |
|
127 |
+
print(f"\n--- Raw Model Output ---\n{fill(raw_summary_text, width=80)}")
|
128 |
+
|
129 |
+
# Format the output
|
130 |
+
formatted_report = self.format_clinical_output(raw_summary_text, patient_data)
|
131 |
+
evaluation_report = self.evaluate_summary_against_guidelines(raw_summary_text, patient_data)
|
132 |
+
|
133 |
+
# Combine final output
|
134 |
+
final_output = (
|
135 |
+
f"\n{'='*80}\n"
|
136 |
+
f" FINAL CLINICAL SUMMARY REPORT\n"
|
137 |
+
f"{'='*80}\n"
|
138 |
+
f"{formatted_report}\n\n"
|
139 |
+
f"{'='*80}\n"
|
140 |
+
f" SIMULATED EVALUATION REPORT\n"
|
141 |
+
f"{'='*80}\n"
|
142 |
+
f"{evaluation_report}"
|
143 |
+
)
|
144 |
+
return final_output
|
145 |
|
146 |
except Exception as e:
|
147 |
+
print(f"❌ Error during summary generation: {e}")
|
148 |
+
import traceback
|
149 |
+
traceback.print_exc()
|
150 |
+
return f"Error generating summary: {str(e)}"
|
151 |
+
|
152 |
+
def build_chronological_narrative(self, patient_data: dict) -> str:
|
153 |
+
"""Builds a chronological narrative from multi-encounter patient history."""
|
154 |
+
result = patient_data.get("result", {})
|
155 |
+
narrative = []
|
156 |
+
|
157 |
+
# Past Medical History
|
158 |
+
narrative.append(f"Past Medical History: {', '.join(result.get('past_medical_history', []))}.")
|
159 |
+
|
160 |
+
# Social History
|
161 |
+
social = result.get('social_history', 'Not specified.')
|
162 |
+
narrative.append(f"Social History: {social}.")
|
163 |
+
|
164 |
+
# Allergies
|
165 |
+
allergies = ', '.join(result.get('allergies', ['None']))
|
166 |
+
narrative.append(f"Allergies: {allergies}.")
|
167 |
+
|
168 |
+
# Loop through encounters chronologically
|
169 |
+
for enc in result.get("encounters", []):
|
170 |
+
encounter_str = (
|
171 |
+
f"Encounter on {enc['visit_date']}: "
|
172 |
+
f"Chief Complaint: '{enc['chief_complaint']}'. "
|
173 |
+
f"Symptoms: {enc.get('symptoms', 'None reported')}. "
|
174 |
+
f"Diagnosis: {', '.join(enc['diagnosis'])}. "
|
175 |
+
f"Doctor's Notes: {enc['dr_notes']}. "
|
176 |
+
)
|
177 |
+
if enc.get('vitals'):
|
178 |
+
encounter_str += f"Vitals: {', '.join([f'{k}: {v}' for k, v in enc['vitals'].items()])}. "
|
179 |
+
if enc.get('lab_results'):
|
180 |
+
encounter_str += f"Labs: {', '.join([f'{k}: {v}' for k, v in enc['lab_results'].items()])}. "
|
181 |
+
if enc.get('medications'):
|
182 |
+
encounter_str += f"Medications: {', '.join(enc['medications'])}. "
|
183 |
+
if enc.get('treatment'):
|
184 |
+
encounter_str += f"Treatment: {enc['treatment']}."
|
185 |
+
narrative.append(encounter_str)
|
186 |
+
|
187 |
+
return "\n".join(narrative)
|
188 |
+
|
189 |
+
def format_clinical_output(self, raw_summary: str, patient_data: dict) -> str:
|
190 |
+
"""Formats the raw AI-generated summary into a structured, doctor-friendly report."""
|
191 |
+
result = patient_data.get("result", {})
|
192 |
+
last_encounter = result.get("encounters", [{}])[-1] if result.get("encounters") else result
|
193 |
+
|
194 |
+
# Consolidate active problems
|
195 |
+
all_diagnoses_raw = set(result.get('past_medical_history', []))
|
196 |
+
for enc in result.get('encounters', []):
|
197 |
+
all_diagnoses_raw.update(enc.get('diagnosis', []))
|
198 |
+
cleaned_diagnoses = sorted({
|
199 |
+
re.sub(r'\s*\([^)]*\)', '', dx).strip() for dx in all_diagnoses_raw
|
200 |
+
})
|
201 |
+
|
202 |
+
# Consolidate current medications
|
203 |
+
all_medications = set()
|
204 |
+
for enc in result.get('encounters', []):
|
205 |
+
all_medications.update(enc.get('medications', []))
|
206 |
+
current_meds = sorted(all_medications)
|
207 |
+
|
208 |
+
# Report Header
|
209 |
+
report = "\n==============================================\n"
|
210 |
+
report += " CLINICAL SUMMARY REPORT\n"
|
211 |
+
report += "==============================================\n"
|
212 |
+
report += f"Generated On: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
|
213 |
+
|
214 |
+
# Patient Overview
|
215 |
+
report += "\n--- PATIENT OVERVIEW ---\n"
|
216 |
+
report += f"Name: {result.get('patientname', 'Unknown')}\n"
|
217 |
+
report += f"Patient ID: {result.get('patientnumber', 'Unknown')}\n"
|
218 |
+
gender = result.get('gender', 'Unknown')
|
219 |
+
report += f"Age/Sex: {result.get('agey', 'Unknown')} {gender[0] if gender != 'Unknown' else 'U'}\n"
|
220 |
+
report += f"Allergies: {', '.join(result.get('allergies', ['None']))}\n"
|
221 |
+
|
222 |
+
# Social History
|
223 |
+
report += "\n--- SOCIAL HISTORY ---\n"
|
224 |
+
report += fill(result.get('social_history', 'Not specified.'), width=80) + "\n"
|
225 |
+
|
226 |
+
# Immediate Attention
|
227 |
+
report += "\n--- IMMEDIATE ATTENTION (Most Recent Encounter) ---\n"
|
228 |
+
report += f"Date of Event: {last_encounter.get('visit_date', 'Unknown')}\n"
|
229 |
+
report += f"Chief Complaint: {last_encounter.get('chief_complaint', 'Not specified')}\n"
|
230 |
+
if last_encounter.get('vitals'):
|
231 |
+
vitals_str = ', '.join([f'{k}: {v}' for k, v in last_encounter['vitals'].items()])
|
232 |
+
report += f"Vitals: {vitals_str}\n"
|
233 |
+
critical_diagnoses = [
|
234 |
+
dx for dx in last_encounter.get('diagnosis', [])
|
235 |
+
if any(kw in dx.lower() for kw in ['acute', 'new onset', 'fall', 'afib', 'kidney injury'])
|
236 |
]
|
237 |
+
if critical_diagnoses:
|
238 |
+
report += f"Critical New Diagnoses: {', '.join(critical_diagnoses)}\n"
|
239 |
+
report += f"Doctor's Notes: {last_encounter.get('dr_notes', 'N/A')}\n"
|
240 |
+
|
241 |
+
# Active Problem List
|
242 |
+
report += "\n--- ACTIVE PROBLEM LIST (Consolidated) ---\n"
|
243 |
+
report += "\n".join(f"- {dx}" for dx in cleaned_diagnoses) + "\n"
|
244 |
+
|
245 |
+
# Current Medications
|
246 |
+
report += "\n--- CURRENT MEDICATION LIST (Consolidated) ---\n"
|
247 |
+
report += "\n".join(f"- {med}" for med in current_meds) + "\n"
|
248 |
+
|
249 |
+
# Procedures
|
250 |
+
procedures = set()
|
251 |
+
for enc in result.get('encounters', []):
|
252 |
+
if 'treatment' in enc and 'PCI' in enc['treatment']:
|
253 |
+
procedures.add(enc['treatment'])
|
254 |
+
if procedures:
|
255 |
+
report += "\n--- PROCEDURES & SURGERIES ---\n"
|
256 |
+
report += "\n".join(f"- {proc}" for proc in sorted(procedures)) + "\n"
|
257 |
+
|
258 |
+
# AI-Generated Narrative
|
259 |
+
report += "\n--- AI-GENERATED CLINICAL NARRATIVE ---\n"
|
260 |
+
report += fill(raw_summary, width=80) + "\n"
|
261 |
+
|
262 |
+
# Placeholder sections if not in model output
|
263 |
+
if "Assessment and Plan" not in raw_summary:
|
264 |
+
report += "\n--- ASSESSMENT, PLAN AND NEXT STEPS (AI-Generated) ---\n"
|
265 |
+
report += "The model did not generate a structured assessment and plan. Please review clinical context.\n"
|
266 |
+
|
267 |
+
if "Clinical Pathway" not in raw_summary:
|
268 |
+
report += "\n--- CLINICAL PATHWAY (AI-Generated) ---\n"
|
269 |
+
report += "No clinical pathway was generated. Consider next steps based on active issues.\n"
|
270 |
+
|
271 |
+
return report
|
272 |
+
|
273 |
+
def evaluate_summary_against_guidelines(self, summary_text: str, patient_data: dict) -> str:
|
274 |
+
"""Simulated evaluation of summary against clinical guidelines."""
|
275 |
+
result = patient_data.get("result", {})
|
276 |
+
last_enc = result.get("encounters", [{}])[-1] if result.get("encounters") else {}
|
277 |
+
|
278 |
+
summary_lower = summary_text.lower()
|
279 |
+
evaluation = (
|
280 |
+
"\n==============================================\n"
|
281 |
+
" AI SUMMARY EVALUATION & GUIDELINE CHECK\n"
|
282 |
+
"==============================================\n"
|
283 |
+
)
|
284 |
+
|
285 |
+
# Keyword-based accuracy
|
286 |
+
critical_keywords = [
|
287 |
+
"fall", "dizziness", "atrial fibrillation", "afib", "rvr", "kidney", "ckd",
|
288 |
+
"diabetes", "anticoagulation", "warfarin", "aspirin", "statin", "metformin",
|
289 |
+
"gout", "angina", "pci", "bph", "hypertension", "metoprolol", "clopidogrel"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
]
|
291 |
+
found = [kw for kw in critical_keywords if kw in summary_lower]
|
292 |
+
score = (len(found) / len(critical_keywords)) * 10
|
293 |
+
evaluation += f"\n1. KEYWORD ACCURACY SCORE: {score:.1f}/10\n"
|
294 |
+
evaluation += f" - Found {len(found)} out of {len(critical_keywords)} critical concepts.\n"
|
295 |
+
|
296 |
+
# Guideline checks
|
297 |
+
evaluation += "\n2. CLINICAL GUIDELINE COMMENTARY (SIMULATED):\n"
|
298 |
+
|
299 |
+
has_afib = any("atrial fibrillation" in dx.lower() for dx in last_enc.get('diagnosis', []))
|
300 |
+
on_anticoag = any("warfarin" in med.lower() or "apixaban" in med.lower() for med in last_enc.get('medications', []))
|
301 |
+
if has_afib:
|
302 |
+
evaluation += " - ✅ Patient with Atrial Fibrillation is on anticoagulation.\n" if on_anticoag \
|
303 |
+
else " - ❌ Atrial Fibrillation present but no anticoagulant prescribed.\n"
|
304 |
+
|
305 |
+
has_mi = any("myocardial infarction" in hx.lower() for hx in result.get('past_medical_history', []))
|
306 |
+
on_statin = any("atorvastatin" in med.lower() or "statin" in med.lower() for med in last_enc.get('medications', []))
|
307 |
+
if has_mi:
|
308 |
+
evaluation += " - ✅ Patient with MI history is on statin therapy.\n" if on_statin \
|
309 |
+
else " - ❌ Patient with MI history is not on statin therapy.\n"
|
310 |
+
|
311 |
+
has_aki = any("acute kidney injury" in dx.lower() for dx in last_enc.get('diagnosis', []))
|
312 |
+
acei_held = "hold" in last_enc.get('dr_notes', '').lower() and "lisinopril" in last_enc.get('dr_notes', '')
|
313 |
+
if has_aki:
|
314 |
+
evaluation += " - ✅ AKI noted and ACE inhibitor was appropriately held.\n" if acei_held \
|
315 |
+
else " - ⚠️ AKI present but ACE inhibitor not documented as held.\n"
|
316 |
+
|
317 |
+
evaluation += (
|
318 |
+
"\nDisclaimer: This is a simulated evaluation and not a substitute for clinical judgment.\n"
|
319 |
+
)
|
320 |
+
return evaluation
|
321 |
+
|
322 |
+
def update_model(self, model_name: str, model_type: str):
|
323 |
+
"""Update the model used by this agent"""
|
324 |
+
self.model_name = model_name
|
325 |
+
self.model_type = model_type
|
326 |
+
self._initialize_model_loader()
|
327 |
+
print(f"✅ Model updated to: {model_name} ({model_type})")
|
328 |
+
|
329 |
+
def get_model_info(self) -> dict:
|
330 |
+
"""Get information about the current model"""
|
331 |
+
if self.model_loader:
|
332 |
+
return self.model_loader.get_model_info()
|
333 |
return {
|
334 |
+
"type": "unknown",
|
335 |
+
"model_name": self.model_name,
|
336 |
+
"model_type": self.model_type,
|
337 |
+
"loaded": False
|
338 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ai_med_extract/api/model_management.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dynamic Model Management API
|
3 |
+
Allows runtime loading, switching, and management of different model types
|
4 |
+
"""
|
5 |
+
|
6 |
+
from flask import Blueprint, request, jsonify
|
7 |
+
import logging
|
8 |
+
from typing import Dict, Any, Optional
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from ..utils.model_manager import model_manager
|
12 |
+
from ..utils.model_config import (
|
13 |
+
get_default_model,
|
14 |
+
get_fallback_model,
|
15 |
+
detect_model_type,
|
16 |
+
validate_model_config,
|
17 |
+
get_model_info
|
18 |
+
)
|
19 |
+
|
20 |
+
# Configure logging
|
21 |
+
logging.basicConfig(level=logging.INFO)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
# Create Blueprint
|
25 |
+
model_management_bp = Blueprint('model_management', __name__, url_prefix='/api/models')
|
26 |
+
|
27 |
+
@model_management_bp.route('/load', methods=['POST'])
|
28 |
+
def load_model():
|
29 |
+
"""
|
30 |
+
Load a new model with specified name and type
|
31 |
+
|
32 |
+
Request body:
|
33 |
+
{
|
34 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
35 |
+
"model_type": "gguf",
|
36 |
+
"filename": "Phi-3-mini-4k-instruct-q4.gguf", # Optional for GGUF
|
37 |
+
"force_reload": false # Optional, force reload even if cached
|
38 |
+
}
|
39 |
+
"""
|
40 |
+
try:
|
41 |
+
data = request.get_json()
|
42 |
+
if not data:
|
43 |
+
return jsonify({"error": "No data provided"}), 400
|
44 |
+
|
45 |
+
model_name = data.get("model_name")
|
46 |
+
model_type = data.get("model_type")
|
47 |
+
filename = data.get("filename")
|
48 |
+
force_reload = data.get("force_reload", False)
|
49 |
+
|
50 |
+
if not model_name:
|
51 |
+
return jsonify({"error": "model_name is required"}), 400
|
52 |
+
|
53 |
+
# Auto-detect model type if not provided
|
54 |
+
if not model_type:
|
55 |
+
model_type = detect_model_type(model_name)
|
56 |
+
logger.info(f"Auto-detected model type: {model_type} for {model_name}")
|
57 |
+
|
58 |
+
# Validate model configuration
|
59 |
+
validation = validate_model_config(model_name, model_type)
|
60 |
+
if not validation["valid"]:
|
61 |
+
return jsonify({
|
62 |
+
"error": "Invalid model configuration",
|
63 |
+
"validation": validation
|
64 |
+
}), 400
|
65 |
+
|
66 |
+
# Load the model
|
67 |
+
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
68 |
+
end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
69 |
+
|
70 |
+
if start_time:
|
71 |
+
start_time.record()
|
72 |
+
|
73 |
+
loader = model_manager.get_model_loader(model_name, model_type, filename, force_reload)
|
74 |
+
|
75 |
+
if end_time:
|
76 |
+
end_time.record()
|
77 |
+
torch.cuda.synchronize()
|
78 |
+
load_time = start_time.elapsed_time(end_time) / 1000.0 # Convert to seconds
|
79 |
+
else:
|
80 |
+
load_time = None
|
81 |
+
|
82 |
+
# Get model information
|
83 |
+
model_info = loader.get_model_info()
|
84 |
+
model_info["load_time_seconds"] = load_time
|
85 |
+
|
86 |
+
return jsonify({
|
87 |
+
"success": True,
|
88 |
+
"message": f"Model {model_name} ({model_type}) loaded successfully",
|
89 |
+
"model_info": model_info,
|
90 |
+
"validation": validation
|
91 |
+
}), 200
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
logger.error(f"Failed to load model: {str(e)}", exc_info=True)
|
95 |
+
return jsonify({
|
96 |
+
"success": False,
|
97 |
+
"error": f"Model loading failed: {str(e)}"
|
98 |
+
}), 500
|
99 |
+
|
100 |
+
@model_management_bp.route('/generate', methods=['POST'])
|
101 |
+
def generate_text():
|
102 |
+
"""
|
103 |
+
Generate text using a specific model
|
104 |
+
|
105 |
+
Request body:
|
106 |
+
{
|
107 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
108 |
+
"model_type": "gguf",
|
109 |
+
"filename": "Phi-3-mini-4k-instruct-q4.gguf", # Optional for GGUF
|
110 |
+
"prompt": "Generate a medical summary for...",
|
111 |
+
"max_tokens": 512,
|
112 |
+
"temperature": 0.7,
|
113 |
+
"top_p": 0.95
|
114 |
+
}
|
115 |
+
"""
|
116 |
+
try:
|
117 |
+
data = request.get_json()
|
118 |
+
if not data:
|
119 |
+
return jsonify({"error": "No data provided"}), 400
|
120 |
+
|
121 |
+
model_name = data.get("model_name")
|
122 |
+
model_type = data.get("model_type")
|
123 |
+
filename = data.get("filename")
|
124 |
+
prompt = data.get("prompt")
|
125 |
+
|
126 |
+
if not all([model_name, prompt]):
|
127 |
+
return jsonify({"error": "model_name and prompt are required"}), 400
|
128 |
+
|
129 |
+
# Auto-detect model type if not provided
|
130 |
+
if not model_type:
|
131 |
+
model_type = detect_model_type(model_name)
|
132 |
+
|
133 |
+
# Generate text
|
134 |
+
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
135 |
+
end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
136 |
+
|
137 |
+
if start_time:
|
138 |
+
start_time.record()
|
139 |
+
|
140 |
+
generated_text = model_manager.generate_text(
|
141 |
+
model_name,
|
142 |
+
model_type,
|
143 |
+
prompt,
|
144 |
+
filename,
|
145 |
+
**{k: v for k, v in data.items() if k not in ["model_name", "model_type", "filename", "prompt"]}
|
146 |
+
)
|
147 |
+
|
148 |
+
if end_time:
|
149 |
+
end_time.record()
|
150 |
+
torch.cuda.synchronize()
|
151 |
+
generation_time = start_time.elapsed_time(end_time) / 1000.0
|
152 |
+
else:
|
153 |
+
generation_time = None
|
154 |
+
|
155 |
+
return jsonify({
|
156 |
+
"success": True,
|
157 |
+
"generated_text": generated_text,
|
158 |
+
"model_name": model_name,
|
159 |
+
"model_type": model_type,
|
160 |
+
"generation_time_seconds": generation_time,
|
161 |
+
"text_length": len(generated_text)
|
162 |
+
}), 200
|
163 |
+
|
164 |
+
except Exception as e:
|
165 |
+
logger.error(f"Text generation failed: {str(e)}", exc_info=True)
|
166 |
+
return jsonify({
|
167 |
+
"success": False,
|
168 |
+
"error": f"Text generation failed: {str(e)}"
|
169 |
+
}), 500
|
170 |
+
|
171 |
+
@model_management_bp.route('/info', methods=['GET'])
|
172 |
+
def get_model_information():
|
173 |
+
"""
|
174 |
+
Get information about a specific model or all loaded models
|
175 |
+
|
176 |
+
Query parameters:
|
177 |
+
- model_name: Optional, specific model to get info for
|
178 |
+
- model_type: Optional, filter by model type
|
179 |
+
"""
|
180 |
+
try:
|
181 |
+
model_name = request.args.get("model_name")
|
182 |
+
model_type = request.args.get("model_type")
|
183 |
+
|
184 |
+
if model_name:
|
185 |
+
# Get info for specific model
|
186 |
+
if not model_type:
|
187 |
+
model_type = detect_model_type(model_name)
|
188 |
+
|
189 |
+
validation = validate_model_config(model_name, model_type)
|
190 |
+
model_info = get_model_info(model_name, model_type)
|
191 |
+
|
192 |
+
return jsonify({
|
193 |
+
"success": True,
|
194 |
+
"model_info": model_info,
|
195 |
+
"validation": validation
|
196 |
+
}), 200
|
197 |
+
else:
|
198 |
+
# Get info for all loaded models
|
199 |
+
loaded_models = model_manager.list_loaded_models()
|
200 |
+
|
201 |
+
# Filter by type if specified
|
202 |
+
if model_type:
|
203 |
+
loaded_models = {
|
204 |
+
k: v for k, v in loaded_models.items()
|
205 |
+
if v.get("model_type") == model_type
|
206 |
+
}
|
207 |
+
|
208 |
+
return jsonify({
|
209 |
+
"success": True,
|
210 |
+
"loaded_models": loaded_models,
|
211 |
+
"total_models": len(loaded_models)
|
212 |
+
}), 200
|
213 |
+
|
214 |
+
except Exception as e:
|
215 |
+
logger.error(f"Failed to get model information: {str(e)}", exc_info=True)
|
216 |
+
return jsonify({
|
217 |
+
"success": False,
|
218 |
+
"error": f"Failed to get model information: {str(e)}"
|
219 |
+
}), 500
|
220 |
+
|
221 |
+
@model_management_bp.route('/defaults', methods=['GET'])
|
222 |
+
def get_default_models():
|
223 |
+
"""
|
224 |
+
Get default models for different model types
|
225 |
+
"""
|
226 |
+
try:
|
227 |
+
from ..utils.model_config import DEFAULT_MODELS, SPACES_OPTIMIZED_MODELS
|
228 |
+
|
229 |
+
return jsonify({
|
230 |
+
"success": True,
|
231 |
+
"default_models": DEFAULT_MODELS,
|
232 |
+
"spaces_optimized_models": SPACES_OPTIMIZED_MODELS
|
233 |
+
}), 200
|
234 |
+
|
235 |
+
except Exception as e:
|
236 |
+
logger.error(f"Failed to get default models: {str(e)}", exc_info=True)
|
237 |
+
return jsonify({
|
238 |
+
"success": False,
|
239 |
+
"error": f"Failed to get default models: {str(e)}"
|
240 |
+
}), 500
|
241 |
+
|
242 |
+
@model_management_bp.route('/clear_cache', methods=['POST'])
|
243 |
+
def clear_model_cache():
|
244 |
+
"""
|
245 |
+
Clear the model cache and free memory
|
246 |
+
"""
|
247 |
+
try:
|
248 |
+
# Get cache info before clearing
|
249 |
+
loaded_models = model_manager.list_loaded_models()
|
250 |
+
cache_size = len(loaded_models)
|
251 |
+
|
252 |
+
# Clear cache
|
253 |
+
model_manager.clear_cache()
|
254 |
+
|
255 |
+
return jsonify({
|
256 |
+
"success": True,
|
257 |
+
"message": f"Model cache cleared successfully",
|
258 |
+
"cleared_models": cache_size,
|
259 |
+
"memory_freed": "GPU and CPU memory cleared"
|
260 |
+
}), 200
|
261 |
+
|
262 |
+
except Exception as e:
|
263 |
+
logger.error(f"Failed to clear cache: {str(e)}", exc_info=True)
|
264 |
+
return jsonify({
|
265 |
+
"success": False,
|
266 |
+
"error": f"Failed to clear cache: {str(e)}"
|
267 |
+
}), 500
|
268 |
+
|
269 |
+
@model_management_bp.route('/switch', methods=['POST'])
|
270 |
+
def switch_model():
|
271 |
+
"""
|
272 |
+
Switch the model used by a specific agent
|
273 |
+
|
274 |
+
Request body:
|
275 |
+
{
|
276 |
+
"agent_name": "patient_summarizer",
|
277 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
278 |
+
"model_type": "gguf",
|
279 |
+
"filename": "Phi-3-mini-4k-instruct-q4.gguf" # Optional for GGUF
|
280 |
+
}
|
281 |
+
"""
|
282 |
+
try:
|
283 |
+
data = request.get_json()
|
284 |
+
if not data:
|
285 |
+
return jsonify({"error": "No data provided"}), 400
|
286 |
+
|
287 |
+
agent_name = data.get("agent_name")
|
288 |
+
model_name = data.get("model_name")
|
289 |
+
model_type = data.get("model_type")
|
290 |
+
filename = data.get("filename")
|
291 |
+
|
292 |
+
if not all([agent_name, model_name]):
|
293 |
+
return jsonify({"error": "agent_name and model_name are required"}), 400
|
294 |
+
|
295 |
+
# Auto-detect model type if not provided
|
296 |
+
if not model_type:
|
297 |
+
model_type = detect_model_type(model_name)
|
298 |
+
|
299 |
+
# Validate model configuration
|
300 |
+
validation = validate_model_config(model_name, model_type)
|
301 |
+
if not validation["valid"]:
|
302 |
+
return jsonify({
|
303 |
+
"error": "Invalid model configuration",
|
304 |
+
"validation": validation
|
305 |
+
}), 400
|
306 |
+
|
307 |
+
# Get the agent from the current app context
|
308 |
+
from flask import current_app
|
309 |
+
agents = getattr(current_app, 'agents', {})
|
310 |
+
|
311 |
+
if agent_name not in agents:
|
312 |
+
return jsonify({
|
313 |
+
"error": f"Agent '{agent_name}' not found",
|
314 |
+
"available_agents": list(agents.keys())
|
315 |
+
}), 404
|
316 |
+
|
317 |
+
agent = agents[agent_name]
|
318 |
+
|
319 |
+
# Update the agent's model if it supports it
|
320 |
+
if hasattr(agent, 'update_model'):
|
321 |
+
agent.update_model(model_name, model_type)
|
322 |
+
message = f"Agent '{agent_name}' model updated to {model_name} ({model_type})"
|
323 |
+
elif hasattr(agent, 'model_loader'):
|
324 |
+
# Try to update the model loader
|
325 |
+
try:
|
326 |
+
from ..utils.model_manager import model_manager
|
327 |
+
agent.model_loader = model_manager.get_model_loader(model_name, model_type, filename)
|
328 |
+
message = f"Agent '{agent_name}' model loader updated to {model_name} ({model_type})"
|
329 |
+
except Exception as e:
|
330 |
+
return jsonify({
|
331 |
+
"error": f"Failed to update agent model loader: {str(e)}"
|
332 |
+
}), 500
|
333 |
+
else:
|
334 |
+
return jsonify({
|
335 |
+
"error": f"Agent '{agent_name}' does not support model switching"
|
336 |
+
}), 400
|
337 |
+
|
338 |
+
return jsonify({
|
339 |
+
"success": True,
|
340 |
+
"message": message,
|
341 |
+
"agent_name": agent_name,
|
342 |
+
"model_name": model_name,
|
343 |
+
"model_type": model_type,
|
344 |
+
"validation": validation
|
345 |
+
}), 200
|
346 |
+
|
347 |
+
except Exception as e:
|
348 |
+
logger.error(f"Failed to switch model: {str(e)}", exc_info=True)
|
349 |
+
return jsonify({
|
350 |
+
"success": False,
|
351 |
+
"error": f"Failed to switch model: {str(e)}"
|
352 |
+
}), 500
|
353 |
+
|
354 |
+
@model_management_bp.route('/health', methods=['GET'])
|
355 |
+
def model_health_check():
|
356 |
+
"""
|
357 |
+
Health check for the model management system
|
358 |
+
"""
|
359 |
+
try:
|
360 |
+
# Check if model manager is accessible
|
361 |
+
loaded_models = model_manager.list_loaded_models()
|
362 |
+
|
363 |
+
# Check GPU memory if available
|
364 |
+
gpu_info = {}
|
365 |
+
if torch.cuda.is_available():
|
366 |
+
gpu_info = {
|
367 |
+
"available": True,
|
368 |
+
"device_count": torch.cuda.device_count(),
|
369 |
+
"current_device": torch.cuda.current_device(),
|
370 |
+
"memory_allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f} GB",
|
371 |
+
"memory_reserved": f"{torch.cuda.memory_reserved() / 1024**3:.2f} GB"
|
372 |
+
}
|
373 |
+
else:
|
374 |
+
gpu_info = {"available": False}
|
375 |
+
|
376 |
+
return jsonify({
|
377 |
+
"success": True,
|
378 |
+
"status": "healthy",
|
379 |
+
"model_manager": "operational",
|
380 |
+
"loaded_models_count": len(loaded_models),
|
381 |
+
"gpu_info": gpu_info,
|
382 |
+
"timestamp": torch.cuda.Event(enable_timing=True).elapsed_time(torch.cuda.Event(enable_timing=True)) if torch.cuda.is_available() else None
|
383 |
+
}), 200
|
384 |
+
|
385 |
+
except Exception as e:
|
386 |
+
logger.error(f"Health check failed: {str(e)}", exc_info=True)
|
387 |
+
return jsonify({
|
388 |
+
"success": False,
|
389 |
+
"status": "unhealthy",
|
390 |
+
"error": f"Health check failed: {str(e)}"
|
391 |
+
}), 500
|
392 |
+
|
393 |
+
# Register the blueprint
|
394 |
+
def register_model_management_routes(app):
|
395 |
+
"""Register model management routes with the Flask app"""
|
396 |
+
app.register_blueprint(model_management_bp)
|
397 |
+
logger.info("Model management routes registered successfully")
|
ai_med_extract/api/routes.py
CHANGED
@@ -14,7 +14,6 @@ from transformers import (
|
|
14 |
pipeline as transformers_pipeline
|
15 |
)
|
16 |
from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
|
17 |
-
agent = PatientSummarizerAgent(model_name="falconsai/medical_summarization")
|
18 |
from ai_med_extract.agents.summarizer import SummarizerAgent
|
19 |
from ai_med_extract.utils.file_utils import (
|
20 |
allowed_file,
|
@@ -23,278 +22,89 @@ from ai_med_extract.utils.file_utils import (
|
|
23 |
get_data_from_storage,
|
24 |
)
|
25 |
from ..utils.validation import clean_result, validate_patient_name
|
26 |
-
|
27 |
-
|
28 |
-
from ai_med_extract.utils.patient_summary_utils import clean_patient_data, flatten_to_string_list
|
29 |
import time
|
30 |
|
31 |
-
#
|
32 |
-
|
33 |
-
|
34 |
-
def get_gguf_pipeline(model_name, filename=None):
|
35 |
-
key = (model_name, filename)
|
36 |
-
if key not in GGUF_MODEL_CACHE:
|
37 |
-
try:
|
38 |
-
from ai_med_extract.utils.model_loader_gguf import GGUFModelPipeline, create_fallback_pipeline
|
39 |
-
import time
|
40 |
-
|
41 |
-
# Add timeout for model loading
|
42 |
-
start_time = time.time()
|
43 |
-
timeout = 300 # 5 minutes timeout
|
44 |
-
|
45 |
-
# Try to load the GGUF model
|
46 |
-
try:
|
47 |
-
GGUF_MODEL_CACHE[key] = GGUFModelPipeline(model_name, filename, timeout=timeout)
|
48 |
-
load_time = time.time() - start_time
|
49 |
-
print(f"[GGUF] Model loaded successfully in {load_time:.2f}s: {model_name}")
|
50 |
-
except Exception as e:
|
51 |
-
load_time = time.time() - start_time
|
52 |
-
print(f"[GGUF] Failed to load model {model_name} after {load_time:.2f}s: {e}")
|
53 |
-
|
54 |
-
# If model loading fails, use fallback
|
55 |
-
print("[GGUF] Using fallback pipeline")
|
56 |
-
GGUF_MODEL_CACHE[key] = create_fallback_pipeline()
|
57 |
-
|
58 |
-
except Exception as e:
|
59 |
-
print(f"[GGUF] Critical error in model loading: {e}")
|
60 |
-
# Create a basic fallback
|
61 |
-
from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline
|
62 |
-
GGUF_MODEL_CACHE[key] = create_fallback_pipeline()
|
63 |
-
|
64 |
-
return GGUF_MODEL_CACHE[key]
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
def get_qa_pipeline(qa_model_type, qa_model_name):
|
|
|
68 |
if not qa_model_type or not qa_model_name:
|
69 |
raise ValueError("Both qa_model_type and qa_model_name must be provided")
|
70 |
-
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
key = (qa_model_type, qa_model_name)
|
85 |
-
if key in get_qa_pipeline.cache:
|
86 |
-
return get_qa_pipeline.cache[key]
|
87 |
-
|
88 |
try:
|
89 |
-
|
90 |
-
if "Qwen/Qwen-7B-Chat" in qa_model_name:
|
91 |
-
qa_model_name = "Qwen/Qwen-1_8B-Chat"
|
92 |
-
elif "Llama" in qa_model_name:
|
93 |
-
qa_model_name = "facebook/opt-125m"
|
94 |
-
|
95 |
-
# Load tokenizer with trust_remote_code=True for custom tokenizers
|
96 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
97 |
-
qa_model_name,
|
98 |
-
trust_remote_code=True,
|
99 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
100 |
-
)
|
101 |
-
|
102 |
-
# Load model with memory optimizations
|
103 |
-
try:
|
104 |
-
model = AutoModelForCausalLM.from_pretrained(
|
105 |
-
qa_model_name,
|
106 |
-
device_map="auto",
|
107 |
-
torch_dtype=torch.float32, # Use float32 for better compatibility
|
108 |
-
trust_remote_code=True,
|
109 |
-
low_cpu_mem_usage=True,
|
110 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
111 |
-
)
|
112 |
-
except Exception as e:
|
113 |
-
# Try loading with a simpler model
|
114 |
-
fallback_model = "facebook/bart-base"
|
115 |
-
model = AutoModelForCausalLM.from_pretrained(
|
116 |
-
fallback_model,
|
117 |
-
device_map="auto",
|
118 |
-
torch_dtype=torch.float32,
|
119 |
-
trust_remote_code=True,
|
120 |
-
low_cpu_mem_usage=True,
|
121 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
122 |
-
)
|
123 |
-
|
124 |
-
# Create pipeline with memory optimizations
|
125 |
-
pipeline = transformers_pipeline(
|
126 |
-
task=qa_model_type,
|
127 |
-
model=model,
|
128 |
-
tokenizer=tokenizer,
|
129 |
-
device_map="auto",
|
130 |
-
torch_dtype=torch.float32
|
131 |
-
)
|
132 |
-
|
133 |
-
get_qa_pipeline.cache[key] = pipeline
|
134 |
-
return pipeline
|
135 |
-
|
136 |
except Exception as e:
|
|
|
137 |
raise
|
138 |
|
139 |
-
def run_qa_pipeline(qa_pipeline, question, context):
|
140 |
"""
|
141 |
-
Run QA pipeline for
|
142 |
"""
|
143 |
if not qa_pipeline or not question or not context:
|
144 |
raise ValueError("Pipeline, question and context are required")
|
145 |
-
|
146 |
-
qa_model_type = getattr(qa_pipeline, '_qa_model_type', None)
|
147 |
|
148 |
try:
|
149 |
-
|
|
|
|
|
150 |
prompt = f"Question: {question}\nContext: {context}\nAnswer:"
|
151 |
-
result = qa_pipeline(prompt, max_new_tokens=128
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
return {'answer': answer}
|
156 |
-
return {'answer': str(result)}
|
157 |
-
else:
|
158 |
result = qa_pipeline(question=question, context=context)
|
|
|
|
|
159 |
return result
|
|
|
|
|
160 |
except Exception as e:
|
|
|
161 |
raise
|
162 |
|
163 |
-
def get_ner_pipeline(ner_model_type, ner_model_name):
|
164 |
-
if not ner_model_type or not ner_model_name:
|
165 |
-
raise ValueError("Both ner_model_type and ner_model_name must be provided")
|
166 |
-
|
167 |
-
if not hasattr(get_ner_pipeline, "cache"):
|
168 |
-
get_ner_pipeline.cache = {}
|
169 |
-
|
170 |
-
# For Hugging Face Spaces, we need to be memory efficient
|
171 |
-
import torch
|
172 |
-
torch.cuda.empty_cache() # Clear GPU memory before loading model
|
173 |
-
|
174 |
-
# Set default tensor type
|
175 |
-
torch.set_default_tensor_type(torch.FloatTensor)
|
176 |
-
if torch.cuda.is_available():
|
177 |
-
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
178 |
-
|
179 |
-
key = (ner_model_type, ner_model_name)
|
180 |
-
if key in get_ner_pipeline.cache:
|
181 |
-
return get_ner_pipeline.cache[key]
|
182 |
-
|
183 |
-
try:
|
184 |
-
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
|
185 |
-
|
186 |
-
# Clear any existing models from memory
|
187 |
-
if torch.cuda.is_available():
|
188 |
-
torch.cuda.empty_cache()
|
189 |
-
|
190 |
-
# Load tokenizer
|
191 |
-
try:
|
192 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
193 |
-
ner_model_name,
|
194 |
-
trust_remote_code=True,
|
195 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
196 |
-
)
|
197 |
-
except Exception as e:
|
198 |
-
# Try loading with a simpler model
|
199 |
-
fallback_model = "dslim/bert-base-NER"
|
200 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
201 |
-
fallback_model,
|
202 |
-
trust_remote_code=True,
|
203 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
204 |
-
)
|
205 |
-
|
206 |
-
# Load model with memory optimizations
|
207 |
-
try:
|
208 |
-
# For NER models, we'll use CPU if device_map='auto' is not supported
|
209 |
-
try:
|
210 |
-
model = AutoModelForTokenClassification.from_pretrained(
|
211 |
-
ner_model_name,
|
212 |
-
trust_remote_code=True,
|
213 |
-
device_map="auto",
|
214 |
-
low_cpu_mem_usage=True,
|
215 |
-
torch_dtype=torch.float32,
|
216 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
217 |
-
)
|
218 |
-
except ValueError as e:
|
219 |
-
if "device_map='auto'" in str(e):
|
220 |
-
model = AutoModelForTokenClassification.from_pretrained(
|
221 |
-
ner_model_name,
|
222 |
-
trust_remote_code=True,
|
223 |
-
low_cpu_mem_usage=True,
|
224 |
-
torch_dtype=torch.float32,
|
225 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
226 |
-
)
|
227 |
-
else:
|
228 |
-
raise
|
229 |
-
except Exception as e:
|
230 |
-
# Try loading with a simpler model
|
231 |
-
fallback_model = "dslim/bert-base-NER"
|
232 |
-
model = AutoModelForTokenClassification.from_pretrained(
|
233 |
-
fallback_model,
|
234 |
-
trust_remote_code=True,
|
235 |
-
low_cpu_mem_usage=True,
|
236 |
-
torch_dtype=torch.float32,
|
237 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
238 |
-
)
|
239 |
-
|
240 |
-
# Create pipeline with appropriate device configuration
|
241 |
-
try:
|
242 |
-
qa_pipeline = pipeline(
|
243 |
-
task=ner_model_type,
|
244 |
-
model=model,
|
245 |
-
tokenizer=tokenizer,
|
246 |
-
device_map="auto",
|
247 |
-
torch_dtype=torch.float32
|
248 |
-
)
|
249 |
-
except ValueError as e:
|
250 |
-
if "device_map='auto'" in str(e):
|
251 |
-
qa_pipeline = pipeline(
|
252 |
-
task=ner_model_type,
|
253 |
-
model=model,
|
254 |
-
tokenizer=tokenizer,
|
255 |
-
device=-1, # Use CPU
|
256 |
-
torch_dtype=torch.float32
|
257 |
-
)
|
258 |
-
else:
|
259 |
-
raise
|
260 |
-
|
261 |
-
# Cache the pipeline
|
262 |
-
get_ner_pipeline.cache[key] = qa_pipeline
|
263 |
-
return qa_pipeline
|
264 |
-
|
265 |
-
except Exception as e:
|
266 |
-
raise
|
267 |
-
|
268 |
-
|
269 |
-
def get_summarizer_pipeline(summarizer_model_type, summarizer_model_name):
|
270 |
-
if not hasattr(get_summarizer_pipeline, "cache"):
|
271 |
-
get_summarizer_pipeline.cache = {}
|
272 |
-
key = (summarizer_model_type, summarizer_model_name)
|
273 |
-
if key not in get_summarizer_pipeline.cache:
|
274 |
-
import torch
|
275 |
-
from transformers import pipeline
|
276 |
-
|
277 |
-
# Use float16 only if CUDA is available, else use float32
|
278 |
-
if torch.cuda.is_available():
|
279 |
-
dtype = torch.float16
|
280 |
-
device = 0
|
281 |
-
device_map = "auto"
|
282 |
-
else:
|
283 |
-
dtype = torch.float32
|
284 |
-
device = -1
|
285 |
-
device_map = None
|
286 |
-
|
287 |
-
get_summarizer_pipeline.cache[key] = pipeline(
|
288 |
-
task=summarizer_model_type,
|
289 |
-
model=summarizer_model_name,
|
290 |
-
trust_remote_code=True,
|
291 |
-
device=device,
|
292 |
-
torch_dtype=dtype,
|
293 |
-
**({"device_map": device_map} if device_map else {})
|
294 |
-
)
|
295 |
-
return get_summarizer_pipeline.cache[key]
|
296 |
-
|
297 |
-
|
298 |
def register_routes(app, agents):
|
299 |
from ai_med_extract.utils.openvino_summarizer_utils import (
|
300 |
parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt, validate_and_compare_summaries
|
@@ -336,22 +146,31 @@ def register_routes(app, agents):
|
|
336 |
# Model selection logic (model_name, model_type)
|
337 |
model_name = data.get("model_name") or "microsoft/Phi-3-mini-4k-instruct"
|
338 |
model_type = data.get("model_type") or "text-generation"
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
if not pipeline:
|
346 |
-
return jsonify({"error": "Model pipeline not available"}), 500
|
347 |
|
348 |
# Run inference
|
349 |
import torch
|
350 |
torch.set_num_threads(2)
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
# Update state
|
357 |
with state_lock:
|
@@ -369,6 +188,7 @@ def register_routes(app, agents):
|
|
369 |
}), 200
|
370 |
except Exception as e:
|
371 |
return jsonify({"error": f"Failed to generate summary: {str(e)}"}), 500
|
|
|
372 |
# Configure upload directory based on environment
|
373 |
import os
|
374 |
|
@@ -391,7 +211,8 @@ def register_routes(app, agents):
|
|
391 |
PHIScrubberAgent = agents["phi_scrubber"]
|
392 |
Summarizer_Agent = agents["summarizer"]
|
393 |
MedicalDataExtractorAgent = agents["medical_data_extractor"]
|
394 |
-
whisper_model = agents["whisper_model"]
|
|
|
395 |
|
396 |
@app.route("/upload", methods=["POST"])
|
397 |
def upload_file():
|
@@ -619,7 +440,6 @@ def register_routes(app, agents):
|
|
619 |
os.remove(temp_path)
|
620 |
return jsonify({"error": str(e)}), 500
|
621 |
|
622 |
-
|
623 |
def group_by_category(data):
|
624 |
grouped = defaultdict(list)
|
625 |
for item in data:
|
@@ -649,7 +469,7 @@ def register_routes(app, agents):
|
|
649 |
return list(reversed(reversed_unique))
|
650 |
|
651 |
def chunk_text(text, tokenizer, max_tokens=256, overlap=100):
|
652 |
-
|
653 |
input_ids = tokenizer.encode(
|
654 |
text,
|
655 |
add_special_tokens=False
|
@@ -714,7 +534,6 @@ def register_routes(app, agents):
|
|
714 |
|
715 |
return extracted
|
716 |
|
717 |
-
|
718 |
def process_chunk(generator, chunk, idx):
|
719 |
prompt = f"""
|
720 |
[INST] <<SYS>>
|
@@ -767,12 +586,20 @@ def register_routes(app, agents):
|
|
767 |
torch.cuda.empty_cache()
|
768 |
|
769 |
# Process with memory optimizations
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
776 |
|
777 |
return idx, output
|
778 |
except Exception as e:
|
@@ -792,27 +619,19 @@ def register_routes(app, agents):
|
|
792 |
return jsonify({"error": "Missing 'extracted_data' in request"}), 400
|
793 |
|
794 |
try:
|
795 |
-
|
796 |
-
|
797 |
-
trust_remote_code=True,
|
798 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
799 |
-
)
|
800 |
-
|
801 |
-
model = AutoModelForCausalLM.from_pretrained(
|
802 |
-
qa_model_name,
|
803 |
-
device_map="auto",
|
804 |
-
torch_dtype=torch.float32,
|
805 |
-
trust_remote_code=True,
|
806 |
-
low_cpu_mem_usage=True,
|
807 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
808 |
-
)
|
809 |
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
|
|
|
|
|
|
|
|
816 |
|
817 |
except Exception as e:
|
818 |
return jsonify({"error": f"Could not load model: {str(e)}"}), 500
|
@@ -856,7 +675,6 @@ def register_routes(app, agents):
|
|
856 |
# Clean and group results for this file
|
857 |
if all_extracted:
|
858 |
deduped = deduplicate_extractions(all_extracted)
|
859 |
-
# cleaned_json = clean_result()
|
860 |
grouped_data = group_by_category(deduped)
|
861 |
else:
|
862 |
grouped_data = {"error": "No valid data extracted"}
|
@@ -873,8 +691,6 @@ def register_routes(app, agents):
|
|
873 |
print("✅ Extraction complete.")
|
874 |
return jsonify(structured_response)
|
875 |
|
876 |
-
|
877 |
-
|
878 |
@app.route("/api/generate_summary", methods=["POST"])
|
879 |
def generate_summary():
|
880 |
data = request.json
|
@@ -886,7 +702,7 @@ def register_routes(app, agents):
|
|
886 |
except Exception:
|
887 |
clean_text = context
|
888 |
try:
|
889 |
-
summary = SummarizerAgent.generate_summary(Summarizer_Agent,clean_text)
|
890 |
return jsonify({"summary": summary}), 200
|
891 |
except Exception as e:
|
892 |
return jsonify({"error": f"Summary generation failed: {str(e)}"}), 500
|
@@ -1005,14 +821,10 @@ def register_routes(app, agents):
|
|
1005 |
"error": f"Request handling failed: {str(e)}"
|
1006 |
}), 500
|
1007 |
|
1008 |
-
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
1012 |
@app.route('/generate_patient_summary', methods=['POST'])
|
1013 |
def generate_patient_summary():
|
1014 |
"""
|
1015 |
-
Enhanced: Uses
|
1016 |
"""
|
1017 |
from ai_med_extract.utils.openvino_summarizer_utils import (
|
1018 |
parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt, validate_and_compare_summaries
|
@@ -1084,93 +896,68 @@ def register_routes(app, agents):
|
|
1084 |
delta_text = delta_to_text(delta)
|
1085 |
prompt = build_main_prompt(old_summary, baseline, delta_text)
|
1086 |
t_model_load_start = time.time()
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
1091 |
-
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
1095 |
-
|
1096 |
-
repo_id, filename = model_name.rsplit('/', 1)
|
1097 |
-
pipeline = get_gguf_pipeline(repo_id, filename)
|
1098 |
else:
|
1099 |
-
|
1100 |
-
|
1101 |
-
|
1102 |
-
|
1103 |
-
|
1104 |
-
|
1105 |
-
|
1106 |
-
|
1107 |
-
|
1108 |
-
|
1109 |
-
|
1110 |
-
|
1111 |
-
|
1112 |
-
|
1113 |
-
|
1114 |
-
|
1115 |
-
|
1116 |
-
|
1117 |
-
|
1118 |
-
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
-
|
1133 |
-
|
1134 |
-
|
1135 |
-
|
1136 |
-
|
1137 |
-
|
1138 |
-
|
1139 |
-
|
1140 |
-
|
1141 |
-
|
1142 |
-
|
1143 |
-
|
1144 |
-
|
1145 |
-
|
1146 |
-
|
1147 |
-
|
1148 |
-
|
1149 |
-
|
1150 |
-
|
1151 |
-
summary = pipeline.generate(prompt)
|
1152 |
-
return jsonify({"summary": summary})
|
1153 |
-
except Exception as e:
|
1154 |
-
return jsonify({"error": f"GGUF model generation failed: {str(e)}"}), 500
|
1155 |
-
inputs = pipeline.tokenizer([prompt], return_tensors="pt")
|
1156 |
-
outputs = pipeline.model.generate(**inputs, max_new_tokens=500, do_sample=False, pad_token_id=pipeline.tokenizer.eos_token_id or 32000)
|
1157 |
-
text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
1158 |
-
new_summary = text.split("Now generate the complete, updated clinical summary with all four sections in a markdown format:")[-1].strip()
|
1159 |
-
# For other models, after extracting new_summary:
|
1160 |
-
markdown_summary = summary_to_markdown(new_summary)
|
1161 |
-
with state_lock:
|
1162 |
-
patient_state["visits"] = all_visits
|
1163 |
-
patient_state["last_summary"] = markdown_summary
|
1164 |
-
validation_report = validate_and_compare_summaries(old_summary, markdown_summary, "Update")
|
1165 |
-
# Remove undefined timing variables and only log steps that are actually measured
|
1166 |
-
total_time = time.time() - start_total
|
1167 |
-
print(f"[TIMING] API call: {t_api_end-t_api_start:.2f}s, TOTAL: {total_time:.2f}s")
|
1168 |
-
return jsonify({
|
1169 |
-
"summary": markdown_summary,
|
1170 |
-
"validation": validation_report,
|
1171 |
-
"baseline": baseline,
|
1172 |
-
"delta": delta_text
|
1173 |
-
}), 200
|
1174 |
except requests.exceptions.Timeout:
|
1175 |
return jsonify({"error": "Request to EHR API timed out"}), 504
|
1176 |
except requests.exceptions.RequestException as e:
|
@@ -1183,7 +970,6 @@ def register_routes(app, agents):
|
|
1183 |
def home():
|
1184 |
return "Medical Data Extraction API is running!", 200
|
1185 |
|
1186 |
-
|
1187 |
def summary_to_markdown(summary):
|
1188 |
import re
|
1189 |
# Remove '- answer:' and similar artifacts
|
|
|
14 |
pipeline as transformers_pipeline
|
15 |
)
|
16 |
from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
|
|
|
17 |
from ai_med_extract.agents.summarizer import SummarizerAgent
|
18 |
from ai_med_extract.utils.file_utils import (
|
19 |
allowed_file,
|
|
|
22 |
get_data_from_storage,
|
23 |
)
|
24 |
from ..utils.validation import clean_result, validate_patient_name
|
25 |
+
from ai_med_extract.utils.patient_summary_utils import clean_patient_data, flatten_to_string_list
|
|
|
|
|
26 |
import time
|
27 |
|
28 |
+
# Configure logging
|
29 |
+
logging.basicConfig(level=logging.INFO)
|
30 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
def get_model_pipeline(model_name: str, model_type: str, filename: str = None):
|
33 |
+
"""
|
34 |
+
Unified function to get any model pipeline using the unified model manager
|
35 |
+
"""
|
36 |
+
try:
|
37 |
+
from ..utils.model_manager import model_manager
|
38 |
+
|
39 |
+
# Get the model loader
|
40 |
+
loader = model_manager.get_model_loader(model_name, model_type, filename)
|
41 |
+
|
42 |
+
# Return the loaded pipeline
|
43 |
+
return loader.load()
|
44 |
+
|
45 |
+
except Exception as e:
|
46 |
+
logger.error(f"Failed to get model pipeline for {model_name} ({model_type}): {e}")
|
47 |
+
raise RuntimeError(f"Model pipeline creation failed: {str(e)}")
|
48 |
|
49 |
+
def get_qa_pipeline(qa_model_type: str, qa_model_name: str):
|
50 |
+
"""Get QA pipeline using unified model manager"""
|
51 |
if not qa_model_type or not qa_model_name:
|
52 |
raise ValueError("Both qa_model_type and qa_model_name must be provided")
|
|
|
53 |
|
54 |
+
try:
|
55 |
+
return get_model_pipeline(qa_model_name, qa_model_type)
|
56 |
+
except Exception as e:
|
57 |
+
logger.error(f"QA pipeline creation failed: {e}")
|
58 |
+
raise
|
59 |
+
|
60 |
+
def get_ner_pipeline(ner_model_type: str, ner_model_name: str):
|
61 |
+
"""Get NER pipeline using unified model manager"""
|
62 |
+
if not ner_model_type or not ner_model_name:
|
63 |
+
raise ValueError("Both ner_model_type and ner_model_name must be provided")
|
64 |
|
65 |
+
try:
|
66 |
+
return get_model_pipeline(ner_model_name, ner_model_type)
|
67 |
+
except Exception as e:
|
68 |
+
logger.error(f"NER pipeline creation failed: {e}")
|
69 |
+
raise
|
70 |
+
|
71 |
+
def get_summarizer_pipeline(summarizer_model_type: str, summarizer_model_name: str):
|
72 |
+
"""Get summarizer pipeline using unified model manager"""
|
73 |
+
if not summarizer_model_type or not summarizer_model_name:
|
74 |
+
raise ValueError("Both summarizer_model_type and summarizer_model_name must be provided")
|
75 |
|
|
|
|
|
|
|
|
|
76 |
try:
|
77 |
+
return get_model_pipeline(summarizer_model_name, summarizer_model_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
except Exception as e:
|
79 |
+
logger.error(f"Summarizer pipeline creation failed: {e}")
|
80 |
raise
|
81 |
|
82 |
+
def run_qa_pipeline(qa_pipeline, question: str, context: str):
|
83 |
"""
|
84 |
+
Run QA pipeline for any model type
|
85 |
"""
|
86 |
if not qa_pipeline or not question or not context:
|
87 |
raise ValueError("Pipeline, question and context are required")
|
|
|
|
|
88 |
|
89 |
try:
|
90 |
+
# Handle different pipeline types
|
91 |
+
if hasattr(qa_pipeline, 'generate'):
|
92 |
+
# Custom pipeline with generate method
|
93 |
prompt = f"Question: {question}\nContext: {context}\nAnswer:"
|
94 |
+
result = qa_pipeline.generate(prompt, max_new_tokens=128)
|
95 |
+
return {'answer': result}
|
96 |
+
elif hasattr(qa_pipeline, '__call__'):
|
97 |
+
# Standard transformers pipeline
|
|
|
|
|
|
|
98 |
result = qa_pipeline(question=question, context=context)
|
99 |
+
if isinstance(result, list) and result:
|
100 |
+
return result[0]
|
101 |
return result
|
102 |
+
else:
|
103 |
+
raise ValueError("Unsupported pipeline type")
|
104 |
except Exception as e:
|
105 |
+
logger.error(f"QA pipeline execution failed: {e}")
|
106 |
raise
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
def register_routes(app, agents):
|
109 |
from ai_med_extract.utils.openvino_summarizer_utils import (
|
110 |
parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt, validate_and_compare_summaries
|
|
|
146 |
# Model selection logic (model_name, model_type)
|
147 |
model_name = data.get("model_name") or "microsoft/Phi-3-mini-4k-instruct"
|
148 |
model_type = data.get("model_type") or "text-generation"
|
149 |
+
|
150 |
+
# Use unified model manager
|
151 |
+
try:
|
152 |
+
pipeline = get_model_pipeline(model_name, model_type)
|
153 |
+
except Exception as e:
|
154 |
+
return jsonify({"error": f"Model loading failed: {str(e)}"}), 500
|
|
|
|
|
155 |
|
156 |
# Run inference
|
157 |
import torch
|
158 |
torch.set_num_threads(2)
|
159 |
+
|
160 |
+
if hasattr(pipeline, 'generate'):
|
161 |
+
# Custom pipeline with generate method
|
162 |
+
new_summary = pipeline.generate(prompt, max_new_tokens=400)
|
163 |
+
else:
|
164 |
+
# Standard transformers pipeline
|
165 |
+
inputs = pipeline.tokenizer([prompt], return_tensors="pt")
|
166 |
+
outputs = pipeline.model.generate(
|
167 |
+
**inputs,
|
168 |
+
max_new_tokens=400,
|
169 |
+
do_sample=False,
|
170 |
+
pad_token_id=pipeline.tokenizer.eos_token_id or 32000
|
171 |
+
)
|
172 |
+
text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
173 |
+
new_summary = text.split("Now generate the complete, updated clinical summary with all four sections in a markdown format:")[-1].strip()
|
174 |
|
175 |
# Update state
|
176 |
with state_lock:
|
|
|
188 |
}), 200
|
189 |
except Exception as e:
|
190 |
return jsonify({"error": f"Failed to generate summary: {str(e)}"}), 500
|
191 |
+
|
192 |
# Configure upload directory based on environment
|
193 |
import os
|
194 |
|
|
|
211 |
PHIScrubberAgent = agents["phi_scrubber"]
|
212 |
Summarizer_Agent = agents["summarizer"]
|
213 |
MedicalDataExtractorAgent = agents["medical_data_extractor"]
|
214 |
+
whisper_model = agents["whisper_model"]
|
215 |
+
model_manager = agents.get("model_manager")
|
216 |
|
217 |
@app.route("/upload", methods=["POST"])
|
218 |
def upload_file():
|
|
|
440 |
os.remove(temp_path)
|
441 |
return jsonify({"error": str(e)}), 500
|
442 |
|
|
|
443 |
def group_by_category(data):
|
444 |
grouped = defaultdict(list)
|
445 |
for item in data:
|
|
|
469 |
return list(reversed(reversed_unique))
|
470 |
|
471 |
def chunk_text(text, tokenizer, max_tokens=256, overlap=100):
|
472 |
+
# Tokenize with memory optimizations
|
473 |
input_ids = tokenizer.encode(
|
474 |
text,
|
475 |
add_special_tokens=False
|
|
|
534 |
|
535 |
return extracted
|
536 |
|
|
|
537 |
def process_chunk(generator, chunk, idx):
|
538 |
prompt = f"""
|
539 |
[INST] <<SYS>>
|
|
|
586 |
torch.cuda.empty_cache()
|
587 |
|
588 |
# Process with memory optimizations
|
589 |
+
if hasattr(generator, 'generate'):
|
590 |
+
output = generator.generate(
|
591 |
+
prompt,
|
592 |
+
max_new_tokens=1024,
|
593 |
+
do_sample=False,
|
594 |
+
temperature=0.3,
|
595 |
+
)
|
596 |
+
else:
|
597 |
+
output = generator(
|
598 |
+
prompt,
|
599 |
+
max_new_tokens=1024,
|
600 |
+
do_sample=False,
|
601 |
+
temperature=0.3,
|
602 |
+
)[0]["generated_text"]
|
603 |
|
604 |
return idx, output
|
605 |
except Exception as e:
|
|
|
619 |
return jsonify({"error": "Missing 'extracted_data' in request"}), 400
|
620 |
|
621 |
try:
|
622 |
+
# Use unified model manager
|
623 |
+
generator = get_model_pipeline(qa_model_name, qa_model_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
|
625 |
+
# Get tokenizer for chunking
|
626 |
+
if hasattr(generator, 'tokenizer'):
|
627 |
+
tokenizer = generator.tokenizer
|
628 |
+
else:
|
629 |
+
# Load tokenizer separately if needed
|
630 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
631 |
+
qa_model_name,
|
632 |
+
trust_remote_code=True,
|
633 |
+
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
634 |
+
)
|
635 |
|
636 |
except Exception as e:
|
637 |
return jsonify({"error": f"Could not load model: {str(e)}"}), 500
|
|
|
675 |
# Clean and group results for this file
|
676 |
if all_extracted:
|
677 |
deduped = deduplicate_extractions(all_extracted)
|
|
|
678 |
grouped_data = group_by_category(deduped)
|
679 |
else:
|
680 |
grouped_data = {"error": "No valid data extracted"}
|
|
|
691 |
print("✅ Extraction complete.")
|
692 |
return jsonify(structured_response)
|
693 |
|
|
|
|
|
694 |
@app.route("/api/generate_summary", methods=["POST"])
|
695 |
def generate_summary():
|
696 |
data = request.json
|
|
|
702 |
except Exception:
|
703 |
clean_text = context
|
704 |
try:
|
705 |
+
summary = SummarizerAgent.generate_summary(Summarizer_Agent, clean_text)
|
706 |
return jsonify({"summary": summary}), 200
|
707 |
except Exception as e:
|
708 |
return jsonify({"error": f"Summary generation failed: {str(e)}"}), 500
|
|
|
821 |
"error": f"Request handling failed: {str(e)}"
|
822 |
}), 500
|
823 |
|
|
|
|
|
|
|
|
|
824 |
@app.route('/generate_patient_summary', methods=['POST'])
|
825 |
def generate_patient_summary():
|
826 |
"""
|
827 |
+
Enhanced: Uses unified model manager for any model type including GGUF for patient summary generation.
|
828 |
"""
|
829 |
from ai_med_extract.utils.openvino_summarizer_utils import (
|
830 |
parse_ehr_chartsummarydtl, visits_sorted, compute_deltas, build_compact_baseline, delta_to_text, build_main_prompt, validate_and_compare_summaries
|
|
|
896 |
delta_text = delta_to_text(delta)
|
897 |
prompt = build_main_prompt(old_summary, baseline, delta_text)
|
898 |
t_model_load_start = time.time()
|
899 |
+
|
900 |
+
# Use unified model manager for any model type
|
901 |
+
try:
|
902 |
+
# Handle GGUF models with filename extraction
|
903 |
+
filename = None
|
904 |
+
if model_type == "gguf" and '/' in model_name:
|
905 |
+
if model_name.endswith('.gguf'):
|
906 |
+
# Full path to .gguf file
|
907 |
+
pass
|
|
|
|
|
908 |
else:
|
909 |
+
# HuggingFace repo with filename
|
910 |
+
parts = model_name.split('/')
|
911 |
+
if len(parts) >= 2 and parts[-1].endswith('.gguf'):
|
912 |
+
filename = parts[-1]
|
913 |
+
model_name = '/'.join(parts[:-1])
|
914 |
+
|
915 |
+
pipeline = get_model_pipeline(model_name, model_type, filename)
|
916 |
+
|
917 |
+
# Generate summary based on model type
|
918 |
+
if model_type == "gguf" and hasattr(pipeline, 'generate_full_summary'):
|
919 |
+
summary_raw = pipeline.generate_full_summary(prompt, max_tokens=512, max_loops=1)
|
920 |
+
new_summary = summary_raw.split("Now generate the complete, updated clinical summary with all four sections in a markdown format:")[-1].strip()
|
921 |
+
if not new_summary.strip():
|
922 |
+
new_summary = summary_raw
|
923 |
+
else:
|
924 |
+
# Standard generation for other model types
|
925 |
+
if hasattr(pipeline, 'generate'):
|
926 |
+
new_summary = pipeline.generate(prompt, max_new_tokens=500)
|
927 |
+
else:
|
928 |
+
# Transformers pipeline
|
929 |
+
inputs = pipeline.tokenizer([prompt], return_tensors="pt")
|
930 |
+
outputs = pipeline.model.generate(
|
931 |
+
**inputs,
|
932 |
+
max_new_tokens=500,
|
933 |
+
do_sample=False,
|
934 |
+
pad_token_id=pipeline.tokenizer.eos_token_id or 32000
|
935 |
+
)
|
936 |
+
text = pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
937 |
+
new_summary = text.split("Now generate the complete, updated clinical summary with all four sections in a markdown format:")[-1].strip()
|
938 |
+
|
939 |
+
# Convert to markdown and update state
|
940 |
+
markdown_summary = summary_to_markdown(new_summary)
|
941 |
+
with state_lock:
|
942 |
+
patient_state["visits"] = all_visits
|
943 |
+
patient_state["last_summary"] = markdown_summary
|
944 |
+
|
945 |
+
validation_report = validate_and_compare_summaries(old_summary, markdown_summary, "Update")
|
946 |
+
|
947 |
+
total_time = time.time() - start_total
|
948 |
+
print(f"[TIMING] API call: {t_api_end-t_api_start:.2f}s, TOTAL: {total_time:.2f}s")
|
949 |
+
|
950 |
+
return jsonify({
|
951 |
+
"summary": markdown_summary,
|
952 |
+
"validation": validation_report,
|
953 |
+
"baseline": baseline,
|
954 |
+
"delta": delta_text
|
955 |
+
}), 200
|
956 |
+
|
957 |
+
except Exception as e:
|
958 |
+
logger.error(f"Model processing failed: {str(e)}", exc_info=True)
|
959 |
+
return jsonify({"error": f"Model processing failed: {str(e)}"}), 500
|
960 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
961 |
except requests.exceptions.Timeout:
|
962 |
return jsonify({"error": "Request to EHR API timed out"}), 504
|
963 |
except requests.exceptions.RequestException as e:
|
|
|
970 |
def home():
|
971 |
return "Medical Data Extraction API is running!", 200
|
972 |
|
|
|
973 |
def summary_to_markdown(summary):
|
974 |
import re
|
975 |
# Remove '- answer:' and similar artifacts
|
ai_med_extract/app.py
CHANGED
@@ -11,9 +11,9 @@ from .agents.summarizer import SummarizerAgent
|
|
11 |
from .agents.medical_data_extractor import MedicalDataExtractorAgent
|
12 |
from .agents.medical_data_extractor import MedicalDocDataExtractorAgent
|
13 |
from .agents.patient_summary_agent import PatientSummarizerAgent
|
|
|
14 |
import torch
|
15 |
|
16 |
-
|
17 |
# Load environment variables
|
18 |
load_dotenv()
|
19 |
|
@@ -50,7 +50,6 @@ app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100 MB max file size
|
|
50 |
|
51 |
# Set cache directories
|
52 |
CACHE_DIRS = {
|
53 |
-
'HF_HOME': '/tmp/huggingface',
|
54 |
'HF_HOME': '/tmp/huggingface',
|
55 |
'XDG_CACHE_HOME': '/tmp',
|
56 |
'TORCH_HOME': '/tmp/torch',
|
@@ -61,79 +60,7 @@ for env_var, path in CACHE_DIRS.items():
|
|
61 |
os.environ[env_var] = path
|
62 |
os.makedirs(path, exist_ok=True)
|
63 |
|
64 |
-
#
|
65 |
-
class LazyModelLoader:
|
66 |
-
def __init__(self, model_name, model_type, fallback_model=None, max_retries=2):
|
67 |
-
self.model_name = model_name
|
68 |
-
self.model_type = model_type
|
69 |
-
self.fallback_model = fallback_model
|
70 |
-
self._model = None
|
71 |
-
self._tokenizer = None
|
72 |
-
self._pipeline = None
|
73 |
-
self._retries = 0
|
74 |
-
self.max_retries = max_retries
|
75 |
-
|
76 |
-
def load(self):
|
77 |
-
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
78 |
-
import torch
|
79 |
-
|
80 |
-
if self._pipeline is None:
|
81 |
-
try:
|
82 |
-
logging.info(f"Loading model: {self.model_name} (attempt {self._retries + 1})")
|
83 |
-
torch.cuda.empty_cache()
|
84 |
-
|
85 |
-
self._tokenizer = AutoTokenizer.from_pretrained(
|
86 |
-
self.model_name,
|
87 |
-
trust_remote_code=True,
|
88 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
89 |
-
)
|
90 |
-
|
91 |
-
if self.model_type == "text-generation":
|
92 |
-
self._model = AutoModelForCausalLM.from_pretrained(
|
93 |
-
self.model_name,
|
94 |
-
trust_remote_code=True,
|
95 |
-
device_map="auto",
|
96 |
-
low_cpu_mem_usage=True,
|
97 |
-
torch_dtype=torch.float16,
|
98 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
99 |
-
)
|
100 |
-
else:
|
101 |
-
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
102 |
-
self._model = AutoModelForSeq2SeqLM.from_pretrained(
|
103 |
-
self.model_name,
|
104 |
-
trust_remote_code=True,
|
105 |
-
device_map="auto",
|
106 |
-
low_cpu_mem_usage=True,
|
107 |
-
torch_dtype=dtype,
|
108 |
-
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
109 |
-
)
|
110 |
-
|
111 |
-
device = 0 if torch.cuda.is_available() else -1
|
112 |
-
self._pipeline = pipeline(
|
113 |
-
task=self.model_type,
|
114 |
-
model=self._model,
|
115 |
-
tokenizer=self._tokenizer,
|
116 |
-
)
|
117 |
-
logging.info(f"Model loaded successfully: {self.model_name}")
|
118 |
-
return self._pipeline
|
119 |
-
|
120 |
-
except Exception as e:
|
121 |
-
logging.error(f"Error loading model '{self.model_name}': {e}", exc_info=True)
|
122 |
-
self._retries += 1
|
123 |
-
|
124 |
-
if self._retries >= self.max_retries:
|
125 |
-
raise RuntimeError(f"Exceeded retry limit for model: {self.model_name}")
|
126 |
-
|
127 |
-
# Attempt fallback if it's different from current
|
128 |
-
if self.fallback_model and self.fallback_model != self.model_name:
|
129 |
-
logging.warning(f"Falling back to model: {self.fallback_model}")
|
130 |
-
self.model_name = self.fallback_model
|
131 |
-
return self.load()
|
132 |
-
else:
|
133 |
-
raise RuntimeError(f"Fallback failed or not set for model: {self.model_name}")
|
134 |
-
return self._pipeline
|
135 |
-
|
136 |
-
|
137 |
class WhisperModelLoader:
|
138 |
_instance = None
|
139 |
|
@@ -164,25 +91,22 @@ class WhisperModelLoader:
|
|
164 |
model = self.load()
|
165 |
return model.transcribe(audio_path)
|
166 |
|
167 |
-
# Initialize agents
|
168 |
try:
|
169 |
-
#
|
170 |
-
medical_data_extractor_model_loader = LazyModelLoader(
|
171 |
-
"facebook/bart-base", # Start with a smaller model
|
172 |
-
"text-generation",
|
173 |
-
fallback_model="facebook/bart-large-cnn"
|
174 |
-
)
|
175 |
-
summarization_model_loader = LazyModelLoader(
|
176 |
-
"Falconsai/medical_summarization", # ✅ Known working
|
177 |
-
"summarization",
|
178 |
-
fallback_model="Falconsai/medical_summarization"
|
179 |
-
)
|
180 |
-
|
181 |
-
# Initialize agents with lazy loading
|
182 |
text_extractor_agent = TextExtractorAgent()
|
183 |
phi_scrubber_agent = PHIScrubberAgent()
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
# Pass all agents and models to routes
|
188 |
agents = {
|
@@ -191,12 +115,15 @@ try:
|
|
191 |
"summarizer": summarizer_agent,
|
192 |
"medical_data_extractor": medical_data_extractor_agent,
|
193 |
"whisper_model": WhisperModelLoader.get_instance(),
|
194 |
-
"patient_summarizer":
|
|
|
195 |
}
|
196 |
|
197 |
from .api.routes import register_routes
|
198 |
register_routes(app, agents)
|
199 |
|
|
|
|
|
200 |
except Exception as e:
|
201 |
logging.error(f"Failed to initialize application: {str(e)}", exc_info=True)
|
202 |
raise
|
|
|
11 |
from .agents.medical_data_extractor import MedicalDataExtractorAgent
|
12 |
from .agents.medical_data_extractor import MedicalDocDataExtractorAgent
|
13 |
from .agents.patient_summary_agent import PatientSummarizerAgent
|
14 |
+
from .utils.model_manager import model_manager
|
15 |
import torch
|
16 |
|
|
|
17 |
# Load environment variables
|
18 |
load_dotenv()
|
19 |
|
|
|
50 |
|
51 |
# Set cache directories
|
52 |
CACHE_DIRS = {
|
|
|
53 |
'HF_HOME': '/tmp/huggingface',
|
54 |
'XDG_CACHE_HOME': '/tmp',
|
55 |
'TORCH_HOME': '/tmp/torch',
|
|
|
60 |
os.environ[env_var] = path
|
61 |
os.makedirs(path, exist_ok=True)
|
62 |
|
63 |
+
# WhisperModelLoader for audio transcription
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
class WhisperModelLoader:
|
65 |
_instance = None
|
66 |
|
|
|
91 |
model = self.load()
|
92 |
return model.transcribe(audio_path)
|
93 |
|
94 |
+
# Initialize agents with unified model manager
|
95 |
try:
|
96 |
+
# Initialize basic agents that don't require specific models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
text_extractor_agent = TextExtractorAgent()
|
98 |
phi_scrubber_agent = PHIScrubberAgent()
|
99 |
+
|
100 |
+
# Initialize model-dependent agents with unified model manager
|
101 |
+
# These will be loaded dynamically when needed
|
102 |
+
medical_data_extractor_agent = MedicalDataExtractorAgent(None) # Will be set dynamically
|
103 |
+
summarizer_agent = SummarizerAgent(None) # Will be set dynamically
|
104 |
+
|
105 |
+
# Initialize patient summarizer with unified model manager support
|
106 |
+
patient_summarizer_agent = PatientSummarizerAgent(
|
107 |
+
model_name="falconsai/medical_summarization",
|
108 |
+
model_type="summarization"
|
109 |
+
)
|
110 |
|
111 |
# Pass all agents and models to routes
|
112 |
agents = {
|
|
|
115 |
"summarizer": summarizer_agent,
|
116 |
"medical_data_extractor": medical_data_extractor_agent,
|
117 |
"whisper_model": WhisperModelLoader.get_instance(),
|
118 |
+
"patient_summarizer": patient_summarizer_agent,
|
119 |
+
"model_manager": model_manager, # Add unified model manager
|
120 |
}
|
121 |
|
122 |
from .api.routes import register_routes
|
123 |
register_routes(app, agents)
|
124 |
|
125 |
+
logging.info("Application initialized successfully with unified model manager")
|
126 |
+
|
127 |
except Exception as e:
|
128 |
logging.error(f"Failed to initialize application: {str(e)}", exc_info=True)
|
129 |
raise
|
ai_med_extract/utils/model_config.py
CHANGED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Model configuration for the unified model manager
|
3 |
+
Defines default models, fallback options, and model type mappings
|
4 |
+
"""
|
5 |
+
|
6 |
+
# Default models for different tasks
|
7 |
+
DEFAULT_MODELS = {
|
8 |
+
"text-generation": {
|
9 |
+
"primary": "facebook/bart-base",
|
10 |
+
"fallback": "facebook/bart-large-cnn",
|
11 |
+
"description": "Text generation models for QA and medical data extraction"
|
12 |
+
},
|
13 |
+
"summarization": {
|
14 |
+
"primary": "Falconsai/medical_summarization",
|
15 |
+
"fallback": "facebook/bart-large-cnn",
|
16 |
+
"description": "Text summarization models for medical reports"
|
17 |
+
},
|
18 |
+
"ner": {
|
19 |
+
"primary": "dslim/bert-base-NER",
|
20 |
+
"fallback": "dslim/bert-base-NER",
|
21 |
+
"description": "Named Entity Recognition for medical entities"
|
22 |
+
},
|
23 |
+
"gguf": {
|
24 |
+
"primary": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
25 |
+
"fallback": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
26 |
+
"description": "GGUF models for patient summaries and medical tasks"
|
27 |
+
},
|
28 |
+
"openvino": {
|
29 |
+
"primary": "microsoft/Phi-3-mini-4k-instruct",
|
30 |
+
"fallback": "microsoft/Phi-3-mini-4k-instruct",
|
31 |
+
"description": "OpenVINO optimized models"
|
32 |
+
}
|
33 |
+
}
|
34 |
+
|
35 |
+
# Model type mappings for automatic detection
|
36 |
+
MODEL_TYPE_MAPPINGS = {
|
37 |
+
# GGUF models
|
38 |
+
".gguf": "gguf",
|
39 |
+
"gguf": "gguf",
|
40 |
+
|
41 |
+
# OpenVINO models
|
42 |
+
"openvino": "openvino",
|
43 |
+
"ov": "openvino",
|
44 |
+
|
45 |
+
# Transformers models
|
46 |
+
"text-generation": "text-generation",
|
47 |
+
"summarization": "summarization",
|
48 |
+
"ner": "ner",
|
49 |
+
"question-answering": "text-generation",
|
50 |
+
"translation": "text-generation"
|
51 |
+
}
|
52 |
+
|
53 |
+
# Memory-optimized models for Hugging Face Spaces
|
54 |
+
SPACES_OPTIMIZED_MODELS = {
|
55 |
+
"text-generation": "facebook/bart-base",
|
56 |
+
"summarization": "Falconsai/medical_summarization",
|
57 |
+
"ner": "dslim/bert-base-NER",
|
58 |
+
"gguf": "microsoft/Phi-3-mini-4k-instruct-gguf"
|
59 |
+
}
|
60 |
+
|
61 |
+
# Model validation rules
|
62 |
+
MODEL_VALIDATION_RULES = {
|
63 |
+
"text-generation": {
|
64 |
+
"min_tokens": 100,
|
65 |
+
"max_tokens": 2048,
|
66 |
+
"supported_formats": ["huggingface", "local"]
|
67 |
+
},
|
68 |
+
"summarization": {
|
69 |
+
"min_tokens": 50,
|
70 |
+
"max_tokens": 1024,
|
71 |
+
"supported_formats": ["huggingface", "local"]
|
72 |
+
},
|
73 |
+
"ner": {
|
74 |
+
"min_tokens": 50,
|
75 |
+
"max_tokens": 512,
|
76 |
+
"supported_formats": ["huggingface", "local"]
|
77 |
+
},
|
78 |
+
"gguf": {
|
79 |
+
"min_tokens": 100,
|
80 |
+
"max_tokens": 4096,
|
81 |
+
"supported_formats": ["huggingface", "local", "remote"]
|
82 |
+
},
|
83 |
+
"openvino": {
|
84 |
+
"min_tokens": 100,
|
85 |
+
"max_tokens": 2048,
|
86 |
+
"supported_formats": ["huggingface", "local"]
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
def get_default_model(model_type: str, use_spaces_optimized: bool = False) -> str:
|
91 |
+
"""Get the default model for a given type"""
|
92 |
+
if use_spaces_optimized and model_type in SPACES_OPTIMIZED_MODELS:
|
93 |
+
return SPACES_OPTIMIZED_MODELS[model_type]
|
94 |
+
|
95 |
+
if model_type in DEFAULT_MODELS:
|
96 |
+
return DEFAULT_MODELS[model_type]["primary"]
|
97 |
+
|
98 |
+
# Fallback to text-generation if type not found
|
99 |
+
return DEFAULT_MODELS["text-generation"]["primary"]
|
100 |
+
|
101 |
+
def get_fallback_model(model_type: str) -> str:
|
102 |
+
"""Get the fallback model for a given type"""
|
103 |
+
if model_type in DEFAULT_MODELS:
|
104 |
+
return DEFAULT_MODELS[model_type]["fallback"]
|
105 |
+
|
106 |
+
return DEFAULT_MODELS["text-generation"]["fallback"]
|
107 |
+
|
108 |
+
def detect_model_type(model_name: str) -> str:
|
109 |
+
"""Automatically detect model type from model name"""
|
110 |
+
model_name_lower = model_name.lower()
|
111 |
+
|
112 |
+
# Check for explicit type indicators
|
113 |
+
for indicator, model_type in MODEL_TYPE_MAPPINGS.items():
|
114 |
+
if indicator in model_name_lower:
|
115 |
+
return model_type
|
116 |
+
|
117 |
+
# Check file extensions
|
118 |
+
if model_name.endswith('.gguf'):
|
119 |
+
return "gguf"
|
120 |
+
|
121 |
+
# Default to text-generation for unknown types
|
122 |
+
return "text-generation"
|
123 |
+
|
124 |
+
def validate_model_config(model_name: str, model_type: str) -> dict:
|
125 |
+
"""Validate model configuration and return validation result"""
|
126 |
+
result = {
|
127 |
+
"valid": True,
|
128 |
+
"warnings": [],
|
129 |
+
"errors": [],
|
130 |
+
"recommendations": []
|
131 |
+
}
|
132 |
+
|
133 |
+
# Check if model type is supported
|
134 |
+
if model_type not in MODEL_VALIDATION_RULES:
|
135 |
+
result["valid"] = False
|
136 |
+
result["errors"].append(f"Unsupported model type: {model_type}")
|
137 |
+
return result
|
138 |
+
|
139 |
+
# Check model name format
|
140 |
+
if model_type == "gguf":
|
141 |
+
if not (model_name.endswith('.gguf') or '/' in model_name):
|
142 |
+
result["warnings"].append("GGUF model should have .gguf extension or be in repo/filename format")
|
143 |
+
|
144 |
+
# Check for memory optimization recommendations
|
145 |
+
if model_type in ["text-generation", "summarization"]:
|
146 |
+
if "large" in model_name.lower() or "xl" in model_name.lower():
|
147 |
+
result["warnings"].append("Large models may cause memory issues on limited resources")
|
148 |
+
result["recommendations"].append("Consider using a smaller model for better performance")
|
149 |
+
|
150 |
+
return result
|
151 |
+
|
152 |
+
def get_model_info(model_name: str, model_type: str) -> dict:
|
153 |
+
"""Get comprehensive information about a model configuration"""
|
154 |
+
validation = validate_model_config(model_name, model_type)
|
155 |
+
|
156 |
+
return {
|
157 |
+
"model_name": model_name,
|
158 |
+
"model_type": model_type,
|
159 |
+
"detected_type": detect_model_type(model_name),
|
160 |
+
"default_model": get_default_model(model_type),
|
161 |
+
"fallback_model": get_fallback_model(model_type),
|
162 |
+
"validation": validation,
|
163 |
+
"supported_formats": MODEL_VALIDATION_RULES.get(model_type, {}).get("supported_formats", []),
|
164 |
+
"description": DEFAULT_MODELS.get(model_type, {}).get("description", "Unknown model type")
|
165 |
+
}
|
ai_med_extract/utils/model_manager.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
from typing import Dict, Any, Optional, Union, Tuple
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
import time
|
7 |
+
|
8 |
+
# Configure logging
|
9 |
+
logging.basicConfig(level=logging.INFO)
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
class BaseModelLoader(ABC):
|
13 |
+
"""Abstract base class for model loaders"""
|
14 |
+
|
15 |
+
@abstractmethod
|
16 |
+
def load(self) -> Any:
|
17 |
+
"""Load and return the model"""
|
18 |
+
pass
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
22 |
+
"""Generate text from prompt"""
|
23 |
+
pass
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def get_model_info(self) -> Dict[str, Any]:
|
27 |
+
"""Get model information"""
|
28 |
+
pass
|
29 |
+
|
30 |
+
class TransformersModelLoader(BaseModelLoader):
|
31 |
+
"""Loader for Hugging Face Transformers models"""
|
32 |
+
|
33 |
+
def __init__(self, model_name: str, model_type: str, device: Optional[str] = None):
|
34 |
+
self.model_name = model_name
|
35 |
+
self.model_type = model_type
|
36 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
+
self._model = None
|
38 |
+
self._tokenizer = None
|
39 |
+
self._pipeline = None
|
40 |
+
|
41 |
+
def load(self):
|
42 |
+
if self._pipeline is None:
|
43 |
+
try:
|
44 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
45 |
+
|
46 |
+
logger.info(f"Loading Transformers model: {self.model_name} ({self.model_type})")
|
47 |
+
torch.cuda.empty_cache()
|
48 |
+
|
49 |
+
# Load tokenizer
|
50 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
51 |
+
self.model_name,
|
52 |
+
trust_remote_code=True,
|
53 |
+
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
54 |
+
)
|
55 |
+
|
56 |
+
if self.tokenizer.pad_token is None:
|
57 |
+
self._tokenizer.pad_token = self._tokenizer.eos_token
|
58 |
+
|
59 |
+
# Load model based on type
|
60 |
+
if self.model_type == "text-generation":
|
61 |
+
self._model = AutoModelForCausalLM.from_pretrained(
|
62 |
+
self.model_name,
|
63 |
+
trust_remote_code=True,
|
64 |
+
device_map="auto" if self.device == "cuda" else None,
|
65 |
+
low_cpu_mem_usage=True,
|
66 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
67 |
+
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
68 |
+
)
|
69 |
+
else:
|
70 |
+
self._model = AutoModelForSeq2SeqLM.from_pretrained(
|
71 |
+
self.model_name,
|
72 |
+
trust_remote_code=True,
|
73 |
+
device_map="auto" if self.device == "cuda" else None,
|
74 |
+
low_cpu_mem_usage=True,
|
75 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
76 |
+
cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
|
77 |
+
)
|
78 |
+
|
79 |
+
# Create pipeline
|
80 |
+
device_id = 0 if self.device == "cuda" else -1
|
81 |
+
self._pipeline = pipeline(
|
82 |
+
task=self.model_type,
|
83 |
+
model=self._model,
|
84 |
+
tokenizer=self._tokenizer,
|
85 |
+
device=device_id
|
86 |
+
)
|
87 |
+
|
88 |
+
logger.info(f"Transformers model loaded successfully: {self.model_name}")
|
89 |
+
|
90 |
+
except Exception as e:
|
91 |
+
logger.error(f"Failed to load Transformers model: {e}")
|
92 |
+
raise RuntimeError(f"Transformers model loading failed: {str(e)}")
|
93 |
+
|
94 |
+
return self._pipeline
|
95 |
+
|
96 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
97 |
+
pipeline = self.load()
|
98 |
+
|
99 |
+
try:
|
100 |
+
if self.model_type == "text-generation":
|
101 |
+
result = pipeline(
|
102 |
+
prompt,
|
103 |
+
max_new_tokens=kwargs.get('max_new_tokens', 512),
|
104 |
+
do_sample=kwargs.get('do_sample', False),
|
105 |
+
temperature=kwargs.get('temperature', 0.7),
|
106 |
+
pad_token_id=self._tokenizer.eos_token_id
|
107 |
+
)
|
108 |
+
if isinstance(result, list) and result:
|
109 |
+
return result[0].get('generated_text', '').replace(prompt, '').strip()
|
110 |
+
return str(result)
|
111 |
+
else:
|
112 |
+
result = pipeline(
|
113 |
+
prompt,
|
114 |
+
max_length=kwargs.get('max_length', 512),
|
115 |
+
min_length=kwargs.get('min_length', 50),
|
116 |
+
do_sample=kwargs.get('do_sample', False)
|
117 |
+
)
|
118 |
+
if isinstance(result, list) and result:
|
119 |
+
return result[0].get('summary_text', str(result[0]))
|
120 |
+
return str(result)
|
121 |
+
except Exception as e:
|
122 |
+
logger.error(f"Generation failed: {e}")
|
123 |
+
raise RuntimeError(f"Text generation failed: {str(e)}")
|
124 |
+
|
125 |
+
def get_model_info(self) -> Dict[str, Any]:
|
126 |
+
return {
|
127 |
+
"type": "transformers",
|
128 |
+
"model_name": self.model_name,
|
129 |
+
"model_type": self.model_type,
|
130 |
+
"device": self.device,
|
131 |
+
"loaded": self._pipeline is not None
|
132 |
+
}
|
133 |
+
|
134 |
+
@property
|
135 |
+
def tokenizer(self):
|
136 |
+
if self._tokenizer is None:
|
137 |
+
self.load()
|
138 |
+
return self._tokenizer
|
139 |
+
|
140 |
+
@property
|
141 |
+
def model(self):
|
142 |
+
if self._model is None:
|
143 |
+
self.load()
|
144 |
+
return self._model
|
145 |
+
|
146 |
+
class GGUFModelLoader(BaseModelLoader):
|
147 |
+
"""Loader for GGUF models using llama.cpp"""
|
148 |
+
|
149 |
+
def __init__(self, model_name: str, filename: Optional[str] = None, device: Optional[str] = None):
|
150 |
+
self.model_name = model_name
|
151 |
+
self.filename = filename
|
152 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
153 |
+
self._pipeline = None
|
154 |
+
|
155 |
+
def load(self):
|
156 |
+
if self._pipeline is None:
|
157 |
+
try:
|
158 |
+
from .model_loader_gguf import GGUFModelPipeline
|
159 |
+
|
160 |
+
logger.info(f"Loading GGUF model: {self.model_name}")
|
161 |
+
|
162 |
+
if self.filename:
|
163 |
+
self._pipeline = GGUFModelPipeline(self.model_name, self.filename)
|
164 |
+
else:
|
165 |
+
self._pipeline = GGUFModelPipeline(self.model_name)
|
166 |
+
|
167 |
+
logger.info(f"GGUF model loaded successfully: {self.model_name}")
|
168 |
+
|
169 |
+
except Exception as e:
|
170 |
+
logger.error(f"Failed to load GGUF model: {e}")
|
171 |
+
# Fallback to text-based response
|
172 |
+
from .model_loader_gguf import create_fallback_pipeline
|
173 |
+
self._pipeline = create_fallback_pipeline()
|
174 |
+
logger.warning(f"Using fallback pipeline for {self.model_name}")
|
175 |
+
|
176 |
+
return self._pipeline
|
177 |
+
|
178 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
179 |
+
pipeline = self.load()
|
180 |
+
|
181 |
+
try:
|
182 |
+
max_tokens = kwargs.get('max_tokens', 512)
|
183 |
+
temperature = kwargs.get('temperature', 0.7)
|
184 |
+
top_p = kwargs.get('top_p', 0.95)
|
185 |
+
|
186 |
+
if hasattr(pipeline, 'generate_full_summary'):
|
187 |
+
return pipeline.generate_full_summary(
|
188 |
+
prompt,
|
189 |
+
max_tokens=max_tokens,
|
190 |
+
max_loops=kwargs.get('max_loops', 1)
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
return pipeline.generate(
|
194 |
+
prompt,
|
195 |
+
max_tokens=max_tokens,
|
196 |
+
temperature=temperature,
|
197 |
+
top_p=top_p
|
198 |
+
)
|
199 |
+
except Exception as e:
|
200 |
+
logger.error(f"GGUF generation failed: {e}")
|
201 |
+
raise RuntimeError(f"GGUF generation failed: {str(e)}")
|
202 |
+
|
203 |
+
def get_model_info(self) -> Dict[str, Any]:
|
204 |
+
return {
|
205 |
+
"type": "gguf",
|
206 |
+
"model_name": self.model_name,
|
207 |
+
"filename": self.filename,
|
208 |
+
"device": self.device,
|
209 |
+
"loaded": self._pipeline is not None
|
210 |
+
}
|
211 |
+
|
212 |
+
class OpenVINOModelLoader(BaseModelLoader):
|
213 |
+
"""Loader for OpenVINO models"""
|
214 |
+
|
215 |
+
def __init__(self, model_name: str, device: Optional[str] = None):
|
216 |
+
self.model_name = model_name
|
217 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
218 |
+
self._pipeline = None
|
219 |
+
|
220 |
+
def load(self):
|
221 |
+
if self._pipeline is None:
|
222 |
+
try:
|
223 |
+
from .model_loader_spaces import get_openvino_pipeline
|
224 |
+
|
225 |
+
logger.info(f"Loading OpenVINO model: {self.model_name}")
|
226 |
+
self._pipeline = get_openvino_pipeline(self.model_name)
|
227 |
+
logger.info(f"OpenVINO model loaded successfully: {self.model_name}")
|
228 |
+
|
229 |
+
except Exception as e:
|
230 |
+
logger.error(f"Failed to load OpenVINO model: {e}")
|
231 |
+
raise RuntimeError(f"OpenVINO model loading failed: {str(e)}")
|
232 |
+
|
233 |
+
return self._pipeline
|
234 |
+
|
235 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
236 |
+
pipeline = self.load()
|
237 |
+
|
238 |
+
try:
|
239 |
+
# OpenVINO models typically use the same interface as transformers
|
240 |
+
inputs = pipeline.tokenizer([prompt], return_tensors="pt")
|
241 |
+
outputs = pipeline.model.generate(
|
242 |
+
**inputs,
|
243 |
+
max_new_tokens=kwargs.get('max_new_tokens', 500),
|
244 |
+
do_sample=False,
|
245 |
+
pad_token_id=pipeline.tokenizer.eos_token_id or 32000
|
246 |
+
)
|
247 |
+
return pipeline.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
248 |
+
except Exception as e:
|
249 |
+
logger.error(f"OpenVINO generation failed: {e}")
|
250 |
+
raise RuntimeError(f"OpenVINO generation failed: {str(e)}")
|
251 |
+
|
252 |
+
def get_model_info(self) -> Dict[str, Any]:
|
253 |
+
return {
|
254 |
+
"type": "openvino",
|
255 |
+
"model_name": self.model_name,
|
256 |
+
"device": self.device,
|
257 |
+
"loaded": self._pipeline is not None
|
258 |
+
}
|
259 |
+
|
260 |
+
class UnifiedModelManager:
|
261 |
+
"""Unified model manager that can handle any model type"""
|
262 |
+
|
263 |
+
def __init__(self):
|
264 |
+
self._model_cache: Dict[str, BaseModelLoader] = {}
|
265 |
+
self._fallback_models = {
|
266 |
+
"text-generation": "facebook/bart-base",
|
267 |
+
"summarization": "Falconsai/medical_summarization",
|
268 |
+
"ner": "dslim/bert-base-NER",
|
269 |
+
"gguf": "microsoft/Phi-3-mini-4k-instruct-gguf"
|
270 |
+
}
|
271 |
+
|
272 |
+
def get_model_loader(
|
273 |
+
self,
|
274 |
+
model_name: str,
|
275 |
+
model_type: str,
|
276 |
+
filename: Optional[str] = None,
|
277 |
+
force_reload: bool = False
|
278 |
+
) -> BaseModelLoader:
|
279 |
+
"""
|
280 |
+
Get a model loader for the specified model and type
|
281 |
+
|
282 |
+
Args:
|
283 |
+
model_name: Name or path of the model
|
284 |
+
model_type: Type of model (text-generation, summarization, ner, gguf, openvino)
|
285 |
+
filename: Optional filename for GGUF models
|
286 |
+
force_reload: Force reload the model even if cached
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
BaseModelLoader instance
|
290 |
+
"""
|
291 |
+
cache_key = f"{model_name}:{model_type}:{filename or ''}"
|
292 |
+
|
293 |
+
if not force_reload and cache_key in self._model_cache:
|
294 |
+
return self._model_cache[cache_key]
|
295 |
+
|
296 |
+
try:
|
297 |
+
# Determine loader type and create appropriate loader
|
298 |
+
if model_type == "gguf":
|
299 |
+
loader = GGUFModelLoader(model_name, filename)
|
300 |
+
elif model_type == "openvino":
|
301 |
+
loader = OpenVINOModelLoader(model_name)
|
302 |
+
else:
|
303 |
+
# Default to transformers for text-generation, summarization, ner, etc.
|
304 |
+
loader = TransformersModelLoader(model_name, model_type)
|
305 |
+
|
306 |
+
# Test load the model
|
307 |
+
loader.load()
|
308 |
+
|
309 |
+
# Cache the loader
|
310 |
+
self._model_cache[cache_key] = loader
|
311 |
+
|
312 |
+
logger.info(f"Model loader created successfully: {model_name} ({model_type})")
|
313 |
+
return loader
|
314 |
+
|
315 |
+
except Exception as e:
|
316 |
+
logger.error(f"Failed to create model loader for {model_name} ({model_type}): {e}")
|
317 |
+
|
318 |
+
# Try fallback model
|
319 |
+
fallback_name = self._fallback_models.get(model_type)
|
320 |
+
if fallback_name and fallback_name != model_name:
|
321 |
+
logger.warning(f"Trying fallback model: {fallback_name}")
|
322 |
+
try:
|
323 |
+
if model_type == "gguf":
|
324 |
+
loader = GGUFModelLoader(fallback_name)
|
325 |
+
elif model_type == "openvino":
|
326 |
+
loader = OpenVINOModelLoader(fallback_name)
|
327 |
+
else:
|
328 |
+
loader = TransformersModelLoader(fallback_name, model_type)
|
329 |
+
|
330 |
+
loader.load()
|
331 |
+
self._model_cache[cache_key] = loader
|
332 |
+
logger.info(f"Fallback model loaded successfully: {fallback_name}")
|
333 |
+
return loader
|
334 |
+
|
335 |
+
except Exception as fallback_error:
|
336 |
+
logger.error(f"Fallback model also failed: {fallback_error}")
|
337 |
+
|
338 |
+
# Create a basic fallback
|
339 |
+
from .model_loader_gguf import create_fallback_pipeline
|
340 |
+
|
341 |
+
class FallbackLoader(BaseModelLoader):
|
342 |
+
def __init__(self, model_name: str, model_type: str):
|
343 |
+
self.model_name = model_name
|
344 |
+
self.model_type = model_type
|
345 |
+
self._pipeline = create_fallback_pipeline()
|
346 |
+
|
347 |
+
def load(self):
|
348 |
+
return self._pipeline
|
349 |
+
|
350 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
351 |
+
return self._pipeline.generate(prompt, **kwargs)
|
352 |
+
|
353 |
+
def get_model_info(self) -> Dict[str, Any]:
|
354 |
+
return {
|
355 |
+
"type": "fallback",
|
356 |
+
"model_name": self.model_name,
|
357 |
+
"model_type": self.model_type,
|
358 |
+
"loaded": True
|
359 |
+
}
|
360 |
+
|
361 |
+
fallback_loader = FallbackLoader(model_name, model_type)
|
362 |
+
self._model_cache[cache_key] = fallback_loader
|
363 |
+
return fallback_loader
|
364 |
+
|
365 |
+
def generate_text(
|
366 |
+
self,
|
367 |
+
model_name: str,
|
368 |
+
model_type: str,
|
369 |
+
prompt: str,
|
370 |
+
filename: Optional[str] = None,
|
371 |
+
**kwargs
|
372 |
+
) -> str:
|
373 |
+
"""
|
374 |
+
Generate text using the specified model
|
375 |
+
|
376 |
+
Args:
|
377 |
+
model_name: Name or path of the model
|
378 |
+
model_type: Type of model
|
379 |
+
prompt: Input prompt
|
380 |
+
filename: Optional filename for GGUF models
|
381 |
+
**kwargs: Additional generation parameters
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
Generated text
|
385 |
+
"""
|
386 |
+
loader = self.get_model_loader(model_name, model_type, filename)
|
387 |
+
return loader.generate(prompt, **kwargs)
|
388 |
+
|
389 |
+
def get_model_info(self, model_name: str, model_type: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
390 |
+
"""Get information about a specific model"""
|
391 |
+
loader = self.get_model_loader(model_name, model_type, filename)
|
392 |
+
return loader.get_model_info()
|
393 |
+
|
394 |
+
def clear_cache(self):
|
395 |
+
"""Clear the model cache"""
|
396 |
+
self._model_cache.clear()
|
397 |
+
torch.cuda.empty_cache()
|
398 |
+
logger.info("Model cache cleared")
|
399 |
+
|
400 |
+
def list_loaded_models(self) -> Dict[str, Dict[str, Any]]:
|
401 |
+
"""List all loaded models and their information"""
|
402 |
+
return {
|
403 |
+
cache_key: loader.get_model_info()
|
404 |
+
for cache_key, loader in self._model_cache.items()
|
405 |
+
}
|
406 |
+
|
407 |
+
# Global instance
|
408 |
+
model_manager = UnifiedModelManager()
|
test_refactored_system.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script for the refactored HNTAI system
|
4 |
+
Demonstrates the new unified model manager and dynamic model loading capabilities
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import time
|
10 |
+
import logging
|
11 |
+
import requests
|
12 |
+
import json
|
13 |
+
|
14 |
+
# Configure logging
|
15 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
# Set environment variables for testing
|
19 |
+
os.environ['HF_HOME'] = '/tmp/huggingface'
|
20 |
+
os.environ['GGUF_N_THREADS'] = '2'
|
21 |
+
os.environ['GGUF_N_BATCH'] = '64'
|
22 |
+
|
23 |
+
def test_model_manager():
|
24 |
+
"""Test the unified model manager"""
|
25 |
+
logger.info("Testing Unified Model Manager...")
|
26 |
+
|
27 |
+
try:
|
28 |
+
from ai_med_extract.utils.model_manager import model_manager
|
29 |
+
|
30 |
+
# Test 1: Load a transformers model
|
31 |
+
logger.info("Test 1: Loading Transformers model...")
|
32 |
+
loader = model_manager.get_model_loader("facebook/bart-base", "text-generation")
|
33 |
+
result = loader.generate("Hello, how are you?", max_new_tokens=50)
|
34 |
+
logger.info(f"✅ Transformers model test passed: {len(result)} characters generated")
|
35 |
+
|
36 |
+
# Test 2: Load a GGUF model
|
37 |
+
logger.info("Test 2: Loading GGUF model...")
|
38 |
+
try:
|
39 |
+
gguf_loader = model_manager.get_model_loader(
|
40 |
+
"microsoft/Phi-3-mini-4k-instruct-gguf",
|
41 |
+
"gguf"
|
42 |
+
)
|
43 |
+
result = gguf_loader.generate("Generate a brief medical summary: Patient has fever and cough.", max_tokens=100)
|
44 |
+
logger.info(f"✅ GGUF model test passed: {len(result)} characters generated")
|
45 |
+
except Exception as e:
|
46 |
+
logger.warning(f"⚠️ GGUF model test failed (this is expected if model not available): {e}")
|
47 |
+
|
48 |
+
# Test 3: Test fallback mechanism
|
49 |
+
logger.info("Test 3: Testing fallback mechanism...")
|
50 |
+
try:
|
51 |
+
fallback_loader = model_manager.get_model_loader("invalid/model", "text-generation")
|
52 |
+
result = fallback_loader.generate("Test prompt")
|
53 |
+
logger.info(f"✅ Fallback mechanism test passed: {len(result)} characters generated")
|
54 |
+
except Exception as e:
|
55 |
+
logger.error(f"❌ Fallback mechanism test failed: {e}")
|
56 |
+
return False
|
57 |
+
|
58 |
+
logger.info("🎉 All model manager tests passed!")
|
59 |
+
return True
|
60 |
+
|
61 |
+
except Exception as e:
|
62 |
+
logger.error(f"❌ Model manager test failed: {e}")
|
63 |
+
return False
|
64 |
+
|
65 |
+
def test_patient_summarizer():
|
66 |
+
"""Test the refactored patient summarizer agent"""
|
67 |
+
logger.info("Testing Patient Summarizer Agent...")
|
68 |
+
|
69 |
+
try:
|
70 |
+
from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
|
71 |
+
|
72 |
+
# Test with different model types
|
73 |
+
test_cases = [
|
74 |
+
{
|
75 |
+
"name": "Transformers Summarization",
|
76 |
+
"model_name": "Falconsai/medical_summarization",
|
77 |
+
"model_type": "summarization"
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"name": "GGUF Model",
|
81 |
+
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
|
82 |
+
"model_type": "gguf"
|
83 |
+
}
|
84 |
+
]
|
85 |
+
|
86 |
+
for test_case in test_cases:
|
87 |
+
logger.info(f"Testing: {test_case['name']}")
|
88 |
+
try:
|
89 |
+
agent = PatientSummarizerAgent(
|
90 |
+
model_name=test_case["model_name"],
|
91 |
+
model_type=test_case["model_type"]
|
92 |
+
)
|
93 |
+
|
94 |
+
# Test with sample patient data
|
95 |
+
sample_data = {
|
96 |
+
"result": {
|
97 |
+
"patientname": "John Doe",
|
98 |
+
"patientnumber": "12345",
|
99 |
+
"agey": "45",
|
100 |
+
"gender": "Male",
|
101 |
+
"allergies": ["Penicillin"],
|
102 |
+
"social_history": "Non-smoker, occasional alcohol",
|
103 |
+
"past_medical_history": ["Hypertension", "Diabetes"],
|
104 |
+
"encounters": [
|
105 |
+
{
|
106 |
+
"visit_date": "2024-01-15",
|
107 |
+
"chief_complaint": "Chest pain",
|
108 |
+
"symptoms": "Sharp chest pain, shortness of breath",
|
109 |
+
"diagnosis": ["Angina", "Hypertension"],
|
110 |
+
"dr_notes": "Patient reports chest pain for 2 days",
|
111 |
+
"vitals": {"BP": "140/90", "HR": "85", "SpO2": "98%"},
|
112 |
+
"medications": ["Aspirin", "Metoprolol"],
|
113 |
+
"treatment": "Prescribed medications, follow-up in 1 week"
|
114 |
+
}
|
115 |
+
]
|
116 |
+
}
|
117 |
+
}
|
118 |
+
|
119 |
+
summary = agent.generate_clinical_summary(sample_data)
|
120 |
+
logger.info(f"✅ {test_case['name']} test passed: {len(summary)} characters generated")
|
121 |
+
|
122 |
+
except Exception as e:
|
123 |
+
logger.warning(f"⚠️ {test_case['name']} test failed (this may be expected): {e}")
|
124 |
+
|
125 |
+
logger.info("🎉 Patient summarizer tests completed!")
|
126 |
+
return True
|
127 |
+
|
128 |
+
except Exception as e:
|
129 |
+
logger.error(f"❌ Patient summarizer test failed: {e}")
|
130 |
+
return False
|
131 |
+
|
132 |
+
def test_model_config():
|
133 |
+
"""Test the model configuration system"""
|
134 |
+
logger.info("Testing Model Configuration...")
|
135 |
+
|
136 |
+
try:
|
137 |
+
from ai_med_extract.utils.model_config import (
|
138 |
+
detect_model_type,
|
139 |
+
validate_model_config,
|
140 |
+
get_model_info,
|
141 |
+
get_default_model
|
142 |
+
)
|
143 |
+
|
144 |
+
# Test model type detection
|
145 |
+
test_models = [
|
146 |
+
("facebook/bart-base", "text-generation"),
|
147 |
+
("Falconsai/medical_summarization", "summarization"),
|
148 |
+
("microsoft/Phi-3-mini-4k-instruct-gguf", "gguf"),
|
149 |
+
("model.gguf", "gguf"),
|
150 |
+
("unknown/model", "text-generation") # Default fallback
|
151 |
+
]
|
152 |
+
|
153 |
+
for model_name, expected_type in test_models:
|
154 |
+
detected_type = detect_model_type(model_name)
|
155 |
+
if detected_type == expected_type:
|
156 |
+
logger.info(f"✅ Model type detection correct: {model_name} -> {detected_type}")
|
157 |
+
else:
|
158 |
+
logger.warning(f"⚠️ Model type detection mismatch: {model_name} -> {detected_type} (expected {expected_type})")
|
159 |
+
|
160 |
+
# Test model validation
|
161 |
+
validation = validate_model_config("microsoft/Phi-3-mini-4k-instruct-gguf", "gguf")
|
162 |
+
if validation["valid"]:
|
163 |
+
logger.info("✅ Model validation test passed")
|
164 |
+
else:
|
165 |
+
logger.warning(f"⚠️ Model validation warnings: {validation['warnings']}")
|
166 |
+
|
167 |
+
# Test default models
|
168 |
+
default_summary = get_default_model("summarization")
|
169 |
+
logger.info(f"✅ Default summarization model: {default_summary}")
|
170 |
+
|
171 |
+
logger.info("🎉 Model configuration tests completed!")
|
172 |
+
return True
|
173 |
+
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"❌ Model configuration test failed: {e}")
|
176 |
+
return False
|
177 |
+
|
178 |
+
def test_api_endpoints():
|
179 |
+
"""Test the new API endpoints (if server is running)"""
|
180 |
+
logger.info("Testing API Endpoints...")
|
181 |
+
|
182 |
+
base_url = "http://localhost:7860" # Adjust if different
|
183 |
+
|
184 |
+
try:
|
185 |
+
# Test health check
|
186 |
+
response = requests.get(f"{base_url}/api/models/health", timeout=10)
|
187 |
+
if response.status_code == 200:
|
188 |
+
health_data = response.json()
|
189 |
+
logger.info(f"✅ Health check passed: {health_data.get('status', 'unknown')}")
|
190 |
+
logger.info(f" Loaded models: {health_data.get('loaded_models_count', 0)}")
|
191 |
+
if health_data.get('gpu_info', {}).get('available'):
|
192 |
+
logger.info(f" GPU memory: {health_data['gpu_info']['memory_allocated']}")
|
193 |
+
else:
|
194 |
+
logger.warning(f"⚠️ Health check failed with status {response.status_code}")
|
195 |
+
return False
|
196 |
+
|
197 |
+
# Test model info
|
198 |
+
response = requests.get(f"{base_url}/api/models/info", timeout=10)
|
199 |
+
if response.status_code == 200:
|
200 |
+
info_data = response.json()
|
201 |
+
logger.info(f"✅ Model info endpoint working: {info_data.get('total_models', 0)} models loaded")
|
202 |
+
else:
|
203 |
+
logger.warning(f"⚠️ Model info endpoint failed with status {response.status_code}")
|
204 |
+
|
205 |
+
# Test default models
|
206 |
+
response = requests.get(f"{base_url}/api/models/defaults", timeout=10)
|
207 |
+
if response.status_code == 200:
|
208 |
+
defaults_data = response.json()
|
209 |
+
logger.info(f"✅ Default models endpoint working: {len(defaults_data.get('default_models', {}))} model types available")
|
210 |
+
else:
|
211 |
+
logger.warning(f"⚠️ Default models endpoint failed with status {response.status_code}")
|
212 |
+
|
213 |
+
logger.info("🎉 API endpoint tests completed!")
|
214 |
+
return True
|
215 |
+
|
216 |
+
except requests.exceptions.ConnectionError:
|
217 |
+
logger.warning("⚠️ Server not running, skipping API tests")
|
218 |
+
return True
|
219 |
+
except Exception as e:
|
220 |
+
logger.error(f"❌ API endpoint test failed: {e}")
|
221 |
+
return False
|
222 |
+
|
223 |
+
def test_memory_optimization():
|
224 |
+
"""Test memory optimization features"""
|
225 |
+
logger.info("Testing Memory Optimization...")
|
226 |
+
|
227 |
+
try:
|
228 |
+
import torch
|
229 |
+
|
230 |
+
# Check if we're in Hugging Face Spaces
|
231 |
+
is_hf_space = os.environ.get('SPACE_ID') is not None
|
232 |
+
|
233 |
+
if is_hf_space:
|
234 |
+
logger.info("🔄 Detected Hugging Face Space - testing memory optimization...")
|
235 |
+
|
236 |
+
# Test with smaller models
|
237 |
+
from ai_med_extract.utils.model_manager import model_manager
|
238 |
+
|
239 |
+
loader = model_manager.get_model_loader("facebook/bart-base", "text-generation")
|
240 |
+
result = loader.generate("Test prompt for memory optimization", max_new_tokens=50)
|
241 |
+
|
242 |
+
logger.info(f"✅ Memory optimization test passed: {len(result)} characters generated")
|
243 |
+
else:
|
244 |
+
logger.info("🔄 Local environment detected - memory optimization not applicable")
|
245 |
+
|
246 |
+
# Test cache clearing
|
247 |
+
from ai_med_extract.utils.model_manager import model_manager
|
248 |
+
model_manager.clear_cache()
|
249 |
+
logger.info("✅ Cache clearing test passed")
|
250 |
+
|
251 |
+
return True
|
252 |
+
|
253 |
+
except Exception as e:
|
254 |
+
logger.error(f"❌ Memory optimization test failed: {e}")
|
255 |
+
return False
|
256 |
+
|
257 |
+
def main():
|
258 |
+
"""Main test function"""
|
259 |
+
logger.info("🚀 Starting HNTAI Refactored System Tests...")
|
260 |
+
logger.info("=" * 60)
|
261 |
+
|
262 |
+
test_results = []
|
263 |
+
|
264 |
+
# Run all tests
|
265 |
+
tests = [
|
266 |
+
("Model Manager", test_model_manager),
|
267 |
+
("Patient Summarizer", test_patient_summarizer),
|
268 |
+
("Model Configuration", test_model_config),
|
269 |
+
("API Endpoints", test_api_endpoints),
|
270 |
+
("Memory Optimization", test_memory_optimization)
|
271 |
+
]
|
272 |
+
|
273 |
+
for test_name, test_func in tests:
|
274 |
+
logger.info(f"\n🧪 Running {test_name} Test...")
|
275 |
+
try:
|
276 |
+
result = test_func()
|
277 |
+
test_results.append((test_name, result))
|
278 |
+
except Exception as e:
|
279 |
+
logger.error(f"❌ {test_name} test crashed: {e}")
|
280 |
+
test_results.append((test_name, False))
|
281 |
+
|
282 |
+
# Summary
|
283 |
+
logger.info("\n" + "=" * 60)
|
284 |
+
logger.info("📊 TEST SUMMARY")
|
285 |
+
logger.info("=" * 60)
|
286 |
+
|
287 |
+
passed = 0
|
288 |
+
total = len(test_results)
|
289 |
+
|
290 |
+
for test_name, result in test_results:
|
291 |
+
status = "✅ PASS" if result else "❌ FAIL"
|
292 |
+
logger.info(f"{test_name}: {status}")
|
293 |
+
if result:
|
294 |
+
passed += 1
|
295 |
+
|
296 |
+
logger.info(f"\nOverall: {passed}/{total} tests passed")
|
297 |
+
|
298 |
+
if passed == total:
|
299 |
+
logger.info("🎉 All tests passed! The refactored system is working correctly.")
|
300 |
+
logger.info("✨ You can now use any model name and type, including GGUF models!")
|
301 |
+
else:
|
302 |
+
logger.warning(f"⚠️ {total - passed} tests failed. Check the logs above for details.")
|
303 |
+
|
304 |
+
# Recommendations
|
305 |
+
logger.info("\n💡 RECOMMENDATIONS:")
|
306 |
+
if passed >= total * 0.8:
|
307 |
+
logger.info("✅ System is ready for production use")
|
308 |
+
logger.info("✅ GGUF models are supported for patient summaries")
|
309 |
+
logger.info("✅ Dynamic model loading is working")
|
310 |
+
elif passed >= total * 0.6:
|
311 |
+
logger.info("⚠️ System is mostly working but has some issues")
|
312 |
+
logger.info("⚠️ Check failed tests and fix issues")
|
313 |
+
else:
|
314 |
+
logger.error("❌ System has significant issues")
|
315 |
+
logger.error("❌ Review and fix failed tests before use")
|
316 |
+
|
317 |
+
return passed == total
|
318 |
+
|
319 |
+
if __name__ == "__main__":
|
320 |
+
success = main()
|
321 |
+
sys.exit(0 if success else 1)
|