dev-3 committed
Commit f93d309 · 1 Parent(s): b9c4cf8
.dockerignore CHANGED
@@ -1,3 +1,50 @@
  # This file tells Hugging Face Spaces to use Docker
  # and exposes the correct port for Flask/Gradio/FastAPI
  # No further config needed if Dockerfile is present
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.so
+ *.egg-info/
+ *.egg
+
+ # Exclude datasets, model weights, and large files
+ *.pt
+ *.pth
+ *.ckpt
+ *.h5
+ *.onnx
+ *.npz
+ *.npy
+ *.tar.gz
+ *.zip
+ *.tar
+ *.gz
+ *.bz2
+ *.7z
+ *.rar
+
+ # Exclude logs and outputs
+ *.log
+ *.out
+ *.tmp
+ *.swp
+
+ # Exclude Jupyter notebooks (if not needed)
+ *.ipynb
+
+ # Exclude local environment files
+ .env
+ .venv/
+ venv/
+
+ # Exclude OS files
+ .DS_Store
+ Thumbs.db
+
+ # Exclude other unnecessary folders/files
+ node_modules/
+ datasets/
+ models/
+ outputs/
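
Everything matched by the new patterns above is excluded from the Docker build context, so datasets and model weights are never baked into the image; at runtime the app pulls whatever it needs instead. As a small sketch of that behavior (using google-t5/t5-small, one of the models this repo already references), transformers downloads and caches weights on first use:

from transformers import pipeline

# Weights excluded from the image are fetched and cached at first use.
summarizer = pipeline("summarization", model="google-t5/t5-small")
print(summarizer("Blood pressure was elevated at the last two visits.",
                 max_length=20, min_length=5, do_sample=False)[0]["summary_text"])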
DEPLOYMENT.md CHANGED
@@ -1,13 +1,13 @@
  # Hugging Face Spaces Docker deployment instructions
 
- # 1. Make sure your Dockerfile exposes port 5000 and runs your app on 0.0.0.0:5000
- # 2. Your Flask app should listen on host='0.0.0.0' and port=5000
+ # 1. Make sure your Dockerfile exposes port 7860 and runs your app on 0.0.0.0:7860
+ # 2. Your Flask app should listen on host='0.0.0.0' and port=7860
  # 3. requirements.txt should include all dependencies
  # 4. .huggingface.yaml with 'runtime: docker' is present
  # 5. .dockerignore and .gitignore are present
 
  # To test locally:
  # docker build -t hntai-app .
- # docker run -p 5000:5000 hntai-app
+ # docker run -p 7860:7860 hntai-app
 
- # Your app will be available at http://localhost:5000
+ # Your app will be available at http://localhost:7860
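
For reference, a minimal standalone Flask app satisfying steps 1 and 2 could look like the sketch below (a hypothetical example, not the project's actual ai_med_extract.app module); 7860 is the port Hugging Face Docker Spaces expect by default:

from flask import Flask

app = Flask(__name__)

@app.route("/")
def home():
    return "OK"

if __name__ == "__main__":
    # Bind to all interfaces on 7860 so the Spaces proxy can reach the app.
    app.run(host="0.0.0.0", port=7860)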
Dockerfile CHANGED
@@ -1,7 +1,6 @@
- # Use a lightweight Python base image
  FROM python:3.10-slim
 
- # Install system dependencies
+ # Install system dependencies and build tools
  RUN apt-get update && apt-get install -y \
      build-essential \
      pkg-config \
@@ -17,25 +16,28 @@ RUN apt-get update && apt-get install -y \
      libgl1 \
      && rm -rf /var/lib/apt/lists/*
 
-
- # Set the working directory
  WORKDIR /app
 
  # Create uploads directory and set permissions
  RUN mkdir -p /app/uploads && chmod 777 /app/uploads
 
  # Copy only dependency files first for better caching
- COPY requirements.txt .
+ COPY requirements.txt .
 
  # Install pip and dependencies
  RUN pip install --upgrade pip \
-     && pip install -r requirements.txt --no-cache-dir --retries 10 --timeout 120
-
- # Copy rest of your code (this is after deps so doesn't bust cache)
+     && pip install -r requirements.txt --no-cache-dir \
+     # Remove build tools and clean up to reduce image size
+     && apt-get remove -y build-essential pkg-config libsystemd-dev libcairo2-dev \
+     && apt-get autoremove -y \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the rest of your code
  COPY . .
 
- # Expose port 5000 (required by HF Spaces)
- EXPOSE 5000
+ # Expose port 7860 (required by HF Spaces)
+ EXPOSE 7860
 
  # Run the Flask app
- CMD ["gunicorn", "-b", "0.0.0.0:5000", "ai_med_extract.app:app"]
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "ai_med_extract.app:app"]
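
The cleanup only shrinks the image because the apt-get remove happens in the same RUN instruction that installed the build tools; a separate later RUN would merely stack a new layer on top of the one that still contains them. To smoke-test the rebuilt container, a quick sketch (assuming the image is running as described in DEPLOYMENT.md and the app serves a root route):

import urllib.request

# Assumes: docker build -t hntai-app .  followed by  docker run -p 7860:7860 hntai-app
with urllib.request.urlopen("http://localhost:7860/", timeout=10) as resp:
    print(resp.status, resp.read(80))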
ai_med_extract/__main__.py CHANGED
@@ -2,4 +2,4 @@ from .app import app
 
  # Entrypoint for running the app as a module
  if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=5000, debug=True)
+     app.run(host="0.0.0.0", port=7860, debug=True)
ai_med_extract/app.py CHANGED
@@ -56,4 +56,4 @@ from .api.routes import register_routes
  register_routes(app, agents)
 
  if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=5000, debug=True)
+     app.run(host="0.0.0.0", port=7860, debug=True)
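
The port now lives in four places (both entrypoints, EXPOSE, and the gunicorn CMD), which is exactly how the 5000/7860 drift happened. A hypothetical refactor of ai_med_extract/__main__.py (not part of this commit) could read it from an environment variable instead:

import os

from .app import app

# Hypothetical PORT override; the HF Spaces Docker runtime defaults to 7860.
PORT = int(os.getenv("PORT", "7860"))

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=PORT, debug=True)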
combined1.py DELETED
@@ -1,880 +0,0 @@
1
- import json
2
- import os
3
- import re
4
- import logging
5
- from dotenv import load_dotenv
6
- from flask import Flask, request, jsonify, abort
7
- from werkzeug.utils import secure_filename
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
9
- import pytesseract
10
- import cv2
11
- import pdfplumber
12
- import pandas as pd
13
- from PIL import Image
14
- from docx import Document
15
- from flask_cors import CORS
16
- from flask_executor import Executor
17
- from sentence_transformers import SentenceTransformer
18
- import faiss
19
- import whisper
20
- from PyPDF2 import PdfReader
21
- from pdf2image import convert_from_path
22
- from concurrent.futures import ThreadPoolExecutor
23
- import tempfile
24
- import tensorflow.keras.layers as KL # Instead of keras.layers as KL
25
- import numpy as np
26
-
27
- # Load environment variables
28
- load_dotenv()
29
-
30
- # Set Tesseract OCR Path
31
- pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
32
-
33
- # Initialize Flask app
34
- app = Flask(__name__)
35
- CORS(app)
36
-
37
- # Configure logging
38
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
39
-
40
- # Configure upload directory and max file size
41
- UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
42
- os.makedirs(UPLOAD_DIR, exist_ok=True)
43
- app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
44
- app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16 MB max file size
45
-
46
- # Initialize Flask-Executor for asynchronous tasks
47
- executor = Executor(app)
48
- whisper_model = whisper.load_model("tiny")
49
-
50
- # Allowed file extensions
51
- ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
52
- ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}
53
-
54
- UPLOAD_FOLDER = 'Uploads'
55
- ALLOWED_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'svg', 'docx', 'doc'}
56
- app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
57
-
58
- # Set file size limits
59
- MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024 # 1GB
60
- MAX_SIZE_IMAGES = 500 * 1024 * 1024 # 500MB
61
-
62
- # Lazy model loading to save resources
63
- class LazyModelLoader:
64
- def __init__(self, model_name, task, tokenizer=None):
65
- self.model_name = model_name
66
- self.task = task
67
- self.tokenizer = tokenizer
68
- self._model = None
69
-
70
- def load(self):
71
- """Load the model if not already loaded."""
72
- if self._model is None:
73
- logging.info(f"Loading model: {self.model_name}")
74
- if self.task == "text-generation":
75
- self._model = AutoModelForCausalLM.from_pretrained(
76
- self.model_name, device_map="auto", torch_dtype="auto"
77
- )
78
- self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)
79
- if self._model.generation_config.pad_token_id is None or self._model.generation_config.pad_token_id < 0:
80
- if self._tokenizer.eos_token_id is not None:
81
- self._model.generation_config.pad_token_id = self._tokenizer.eos_token_id
82
- logging.info(f"Set pad_token_id to {self._tokenizer.eos_token_id}")
83
- else:
84
- logging.warning("No valid eos_token_id found. Setting pad_token_id to 0 as a fallback.")
85
- self._model.generation_config.pad_token_id = 0
86
- else:
87
- self._model = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
88
- return self._model
89
-
90
- # Text extraction agents
91
- class TextExtractorAgent:
92
- @staticmethod
93
- def extract_text(filepath, ext):
94
- """Extract text based on file type."""
95
- try:
96
- if ext == "pdf":
97
- return TextExtractorAgent.extract_text_from_pdf(filepath)
98
- elif ext in {"jpg", "jpeg", "png"}:
99
- return TextExtractorAgent.extract_text_from_image(filepath)
100
- elif ext == "docx":
101
- return TextExtractorAgent.extract_text_from_docx(filepath)
102
- elif ext in {"xlsx", "xls"}:
103
- return TextExtractorAgent.extract_text_from_excel(filepath)
104
- return None
105
- except Exception as e:
106
- logging.error(f"Text extraction failed: {e}")
107
- return None
108
-
109
- @staticmethod
110
- def extract_text_from_pdf(filepath):
111
- """Extract text from a PDF file."""
112
- text = ""
113
- with pdfplumber.open(filepath) as pdf:
114
- for page in pdf.pages:
115
- page_text = page.extract_text()
116
- if page_text:
117
- text += page_text + "\n"
118
- return text.strip() or None
119
-
120
- @staticmethod
121
- def extract_text_from_image(filepath):
122
- """Extract text from an image using OCR."""
123
- image = cv2.imread(filepath)
124
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
125
- _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
126
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
127
- processed_path = temp_file.name
128
- cv2.imwrite(processed_path, processed)
129
- text = pytesseract.image_to_string(Image.open(processed_path), lang='eng')
130
- os.remove(processed_path)
131
- return text.strip() or None
132
-
133
- @staticmethod
134
- def extract_text_from_docx(filepath):
135
- """Extract text from a DOCX file."""
136
- doc = Document(filepath)
137
- text = "\n".join([para.text for para in doc.paragraphs])
138
- return text.strip() or None
139
-
140
- @staticmethod
141
- def extract_text_from_excel(filepath):
142
- """Extract text from an Excel file."""
143
- dfs = pd.read_excel(filepath, sheet_name=None)
144
- text = "\n".join([
145
- "\n".join([
146
- " ".join(map(str, df[col].dropna()))
147
- for col in df.columns
148
- ])
149
- for df in dfs.values()
150
- ])
151
- return text.strip() or None
152
-
153
- # PHI scrubbing agent
154
- class PHIScrubberAgent:
155
- @staticmethod
156
- def scrub_phi(text):
157
- """Remove sensitive personal health information (PHI)."""
158
- try:
159
- text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
160
- text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
161
- text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
162
- text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]', text, flags=re.IGNORECASE)
163
- text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
164
- text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
165
- except Exception as e:
166
- logging.error(f"PHI scrubbing failed: {e}")
167
- return text
168
-
169
- # Summarization agent
170
- class SummarizerAgent:
171
- def __init__(self, summarization_model_loader):
172
- self.summarization_model_loader = summarization_model_loader
173
-
174
- def generate_summary(self, text):
175
- """Generate a summary of the provided text."""
176
- model = self.summarization_model_loader.load()
177
- try:
178
- summary_result = model(text, do_sample=False)
179
- return summary_result[0]['summary_text'].strip()
180
- except Exception as e:
181
- logging.error(f"Summary generation failed: {e}")
182
- return "Summary generation failed."
183
-
184
- def allowed_file(filename, allowed_extensions=ALLOWED_EXTENSIONS):
185
- """Check if the file extension is allowed."""
186
- return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
187
-
188
- # Knowledge Base
189
- class KnowledgeBase:
190
- def __init__(self, documents):
191
- self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
192
- self.documents = documents
193
- self.embeddings = self.embedding_model.encode(documents)
194
- self.dimension = self.embedding_model.get_sentence_embedding_dimension()
195
- self.index = faiss.IndexFlatL2(self.dimension)
196
- self.index.add(self.embeddings)
197
-
198
- def retrieve_relevant_info(self, query, top_k=3):
199
- """Retrieve relevant medical information from the knowledge base."""
200
- query_embedding = self.embedding_model.encode([query])
201
- distances, indices = self.index.search(query_embedding, top_k)
202
- relevant_texts = [self.documents[i] for i in indices[0]]
203
- return relevant_texts
204
-
205
- # Medical data extraction agent
206
- class MedicalDataExtractorAgent:
207
- def __init__(self, model_loader, knowledge_base):
208
- self.model_loader = model_loader
209
- self.knowledge_base = knowledge_base
210
-
211
- def retrieve_relevant_info(self, query, top_k=3):
212
- """Retrieve relevant medical information from the knowledge base."""
213
- query_embedding = self.knowledge_base.embedding_model.encode([query])
214
- distances, indices = self.knowledge_base.index.search(query_embedding, top_k)
215
- relevant_texts = [self.knowledge_base.documents[i] for i in indices[0]]
216
- return relevant_texts
217
-
218
- def extract_medical_data(self, text):
219
- """Extract structured medical data from text using Agentic RAG."""
220
- try:
221
- default_schema = {
222
- "patient_name": "[NAME]",
223
- "age": None,
224
- "gender": None,
225
- "diagnosis": [],
226
- "symptoms": [],
227
- "medications": [],
228
- "allergies": [],
229
- "vitals": {
230
- "blood_pressure": None,
231
- "heart_rate": None,
232
- "temperature": None
233
- },
234
- "notes": ""
235
- }
236
- prompt = f"""
237
- ### Instruction:
238
- Extract structured medical data from the following text as a JSON whose parameters are enclosed in "" and without any \.
239
- The JSON should include patientname, age, gender, medications, allergies, diagnosis, symptoms, vitals, and notes.
240
- ### Text:
241
- {text}
242
- ### Response:
243
- """
244
- model = self.model_loader.load()
245
- tokenizer = self.model_loader._tokenizer
246
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
247
- outputs = model.generate(
248
- inputs.input_ids,
249
- num_return_sequences=1,
250
- temperature=0.7,
251
- top_p=0.9,
252
- do_sample=True
253
- )
254
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
255
- logging.info(f"Model response: {response}")
256
- json_start = response.find("{")
257
- json_end = response.rfind("}") + 1
258
- if json_start == -1 or json_end == -1:
259
- raise ValueError("No JSON found in the model response.")
260
- structured_data = json.loads(response[json_start:json_end])
261
- normalized_data = self.normalize_json_output(structured_data, default_schema)
262
- if normalized_data["vitals"]["blood_pressure"] and isinstance(normalized_data["vitals"]["blood_pressure"], str):
263
- normalized_data["vitals"]["blood_pressure"] = normalized_data["vitals"]["blood_pressure"].strip('"')
264
- return json.dumps(normalized_data)
265
- except json.JSONDecodeError as e:
266
- logging.error(f"JSON parsing failed: {e}")
267
- return json.dumps({"error": f"Failed to parse JSON: {str(e)}"})
268
- except Exception as e:
269
- logging.error(f"Error extracting medical data: {e}")
270
- return json.dumps({"error": f"Failed to extract medical data: {str(e)}"})
271
-
272
- @staticmethod
273
- def normalize_json_output(model_output, default_schema):
274
- """Normalize the model's JSON output to match the default schema."""
275
- try:
276
- normalized_output = default_schema.copy()
277
- for key in normalized_output:
278
- if key in model_output:
279
- normalized_output[key] = model_output[key]
280
- return normalized_output
281
- except Exception as e:
282
- logging.error(f"Failed to normalize JSON: {e}")
283
- return default_schema
284
-
285
- # Initialize lazy loaders
286
- medalpaca_model_loader = LazyModelLoader(
287
- model_name="stanford-crfm/BioMedLM",
288
- task="text-generation"
289
- )
290
- summarization_model_loader = LazyModelLoader("google-t5/t5-small", "summarization")
291
-
292
- # Initialize knowledge base
293
- medical_documents = [
294
- "Hypertension is a chronic condition characterized by elevated blood pressure.",
295
- "Diabetes is a metabolic disorder that affects blood sugar levels.",
296
- "Common symptoms of chest pain include pressure, tightness, or discomfort in the chest."
297
- ]
298
- knowledge_base = KnowledgeBase(medical_documents)
299
-
300
- # Initialize agents
301
- text_extractor_agent = TextExtractorAgent()
302
- phi_scrubber_agent = PHIScrubberAgent()
303
- medical_data_extractor_agent = MedicalDataExtractorAgent(medalpaca_model_loader, knowledge_base)
304
- summarizer_agent = SummarizerAgent(summarization_model_loader)
305
-
306
- # NER to Detect medical info
307
- CONFIDENCE_THRESHOLD = 0.80
308
-
309
- def extract_medical_entities(text, ner_pipeline):
310
- if not text or not text.strip():
311
- return ["No medical entities found"]
312
- if ner_pipeline is None:
313
- print("⚠️ NER model is not loaded, skipping entity extraction.")
314
- return ["No medical entities found"]
315
-
316
- ner_results = ner_pipeline(text)
317
- relevant_entities = {
318
- "Disease", "MedicalCondition", "Symptom", "Sign_or_Symptom",
319
- "B-DISEASE", "I-DISEASE",
320
- "Test", "Measurement", "B-TEST", "I-TEST", "Lab_value", "B-Lab_value", "I-Lab_value",
321
- "Medication", "B-MEDICATION", "I-MEDICATION", "Treatment",
322
- "Procedure", "B-Diagnostic_procedure", "I-Diagnostic_procedure",
323
- "Anatomical_site", "Body_Part", "Organ_or_Tissue",
324
- "Diagnostic_procedure", "Surgical_Procedure", "Therapeutic_Procedure",
325
- "Health_condition", "B-Health_condition", "I-Health_condition",
326
- "Pathological_Condition", "Clinical_Event",
327
- "Chemical_Substance", "B-Chemical_Substance", "I-Chemical_Substance",
328
- "Biological_Entity", "B-Biological_Entity", "I-Biological_Entity"
329
- }
330
-
331
- medical_entities = set()
332
- for ent in ner_results:
333
- entity_label = ent.get("entity_group") or ent.get("entity")
334
- if entity_label in relevant_entities and ent["score"] >= CONFIDENCE_THRESHOLD:
335
- word = ent["word"].lower().strip().replace("-", "")
336
- if len(word) > 2:
337
- medical_entities.add(word)
338
-
339
- if len(medical_entities) >= 5:
340
- return list(medical_entities)
341
-
342
- return ["No medical entities found"]
343
-
344
- # Validation: Check File Size
345
- def check_file_size(file):
346
- file.seek(0, os.SEEK_END)
347
- size = file.tell()
348
- file.seek(0)
349
- extension = file.filename.rsplit('.', 1)[-1].lower()
350
- if extension in {'pdf', 'docx'} and size > MAX_SIZE_PDF_DOCS:
351
- return False, f"File {file.filename} exceeds 1GB size limit"
352
- elif extension in {'jpg', 'jpeg', 'png'} and size > MAX_SIZE_IMAGES:
353
- return False, f"Image {file.filename} exceeds 500MB size limit"
354
- return True, None
355
-
356
- def extract_patient_name(text, qa_pipeline):
357
- """Extracts patient name using the given QA pipeline."""
358
- if not text or not qa_pipeline:
359
- return None
360
- try:
361
- result = qa_pipeline(
362
- question="What is the patient's name?",
363
- context=text
364
- )
365
- return result.get("answer", "").strip()
366
- except Exception as e:
367
- print(f"⚠️ Error extracting patient name: {e}")
368
- return None
369
-
370
- def normalize_name(name):
371
- """Cleans and normalizes names for comparison, removing salutations dynamically."""
372
- if not name:
373
- return ""
374
- name = name.lower().strip()
375
- name = re.sub(r"[^\w\s]", "", name)
376
- name = re.sub(r"^\b\w{1,5}\b\s+", "", name)
377
- return name
378
-
379
- def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
380
- """Validates if the extracted name matches the registered patient name."""
381
- detected_name = extract_patient_name(extracted_text, qa_pipeline)
382
- if not detected_name:
383
- return jsonify({"error": f"Could not determine patient name from {filename}"}), 400
384
- normalized_detected_name = normalize_name(detected_name)
385
- normalized_patient_name = normalize_name(patient_name)
386
- if normalized_detected_name not in normalized_patient_name:
387
- return jsonify({
388
- "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
389
- }), 400
390
- return None
391
-
392
- def is_blurred(image_path, variance_threshold=150):
393
- try:
394
- image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
395
- if image is None:
396
- print(f"❌ Error: Unable to read image {image_path}")
397
- return True
398
- laplacian_var = cv2.Laplacian(image, cv2.CV_64F).var()
399
- print(f"🔍 Blur Check: Variance={laplacian_var} (Threshold={variance_threshold})")
400
- edges = cv2.Canny(image, 50, 150)
401
- edge_density = np.mean(edges)
402
- print(f"📏 Edge Density: {edge_density}")
403
- return laplacian_var < variance_threshold and edge_density < 10
404
- except Exception as e:
405
- print(f"❌ Error detecting blur: {e}")
406
- return True
407
-
408
- def extract_text_from_image(filepath):
409
- try:
410
- if is_blurred(filepath):
411
- return "Image is too blurry, OCR failed."
412
- image = cv2.imread(filepath)
413
- if image is None:
414
- return "Image could not be read."
415
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
416
- gray = cv2.GaussianBlur(gray, (5, 5), 0)
417
- gray = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
418
- cv2.THRESH_BINARY, 11, 2)
419
- kernel = np.ones((2,2), np.uint8)
420
- gray = cv2.dilate(gray, kernel, iterations=1)
421
- processed_path = f"{filepath}_processed.png"
422
- cv2.imwrite(processed_path, gray)
423
- text = pytesseract.image_to_string(Image.open(processed_path), lang='eng').strip()
424
- words = text.split()
425
- if len(words) < 5:
426
- return "OCR failed to extract meaningful text."
427
- return text
428
- except Exception as e:
429
- print(f"❌ Error processing {filepath}: {e}")
430
- return "Failed to extract text"
431
-
432
- def extract_text_from_pdf(filepath, password=None):
433
- """Extract text from PDFs using pdfplumber (faster) or OCR (if needed)."""
434
- text = ""
435
- try:
436
- reader = PdfReader(filepath)
437
- if reader.is_encrypted:
438
- if not password:
439
- print("🔒 PDF is encrypted but no password was provided.")
440
- return {"error": "File is password-protected. Please provide a password."}, 401
441
- decryption_result = reader.decrypt(password)
442
- if decryption_result == 0:
443
- print("❌ Incorrect password provided!")
444
- return {"error": "Invalid password provided."}, 403
445
- else:
446
- print("✅ PDF successfully decrypted!")
447
- text = "\n".join([page.extract_text() or "" for page in reader.pages])
448
- if text.strip():
449
- return text.strip(), 200
450
- with pdfplumber.open(filepath) as pdf:
451
- for page in pdf.pages:
452
- page_text = page.extract_text()
453
- if page_text:
454
- text += page_text + "\n"
455
- if text.strip():
456
- return text.strip(), 200
457
- images = convert_from_path(filepath)
458
- with ThreadPoolExecutor(max_workers=5) as pool:
459
- ocr_text = list(pool.map(lambda img: pytesseract.image_to_string(img, lang='eng'), images))
460
- return ("\n".join(ocr_text).strip(), 200) if ocr_text else ("No text found", 415)
461
- except Exception as e:
462
- print(f"❌ Error processing PDF {filepath}: {e}")
463
- return "Failed to extract text"
464
-
465
- def extract_text_from_docx(filepath):
466
- doc = Document(filepath)
467
- text = "\n".join([para.text for para in doc.paragraphs])
468
- return text.strip() or None
469
-
470
- def clean_result(value):
471
- value = re.sub(r"\s+", " ", value)
472
- value = re.sub(r"[-_:]+", " ", value)
473
- value = re.sub(r"[^\x00-\x7F]+", " ", value)
474
- return value if value else "Not Available"
475
-
476
- def mask_sensitive_info(text):
477
- text = re.sub(r'(?<=\b\w{2})\w+(?=\s\w{2,})', '***', text)
478
- text = re.sub(r'\b(\d{2})\d{2}-(\d{2})\d{2}-(\d{2})\d{2}\b', r'**\2-**\3-**', text)
479
- text = re.sub(r'\b(\d{8})(\d{2})\b', r'********\2', text)
480
- return text
481
-
482
- # API Endpoints
483
- @app.route('/extract_medical_data', methods=['POST'])
484
- def extract_medical_data():
485
- """Extract structured medical data from raw text."""
486
- try:
487
- data = request.json
488
- if "text" not in data or not data["text"].strip():
489
- return jsonify({"error": "No valid text provided"}), 400
490
- raw_text = data["text"]
491
- clean_text = phi_scrubber_agent.scrub_phi(raw_text)
492
- structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
493
- return jsonify(json.loads(structured_data)), 200
494
- except Exception as e:
495
- logging.error(f"Failed to extract medical data: {e}")
496
- return jsonify({"error": f"Extraction Error: {str(e)}"}), 500
497
-
498
- @app.route('/api/transcribe', methods=['POST'])
499
- def transcribe_audio():
500
- """Transcribe audio files into text."""
501
- if 'audio' not in request.files:
502
- abort(400, description="No audio file provided")
503
- audio_file = request.files['audio']
504
- if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
505
- abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
506
- filename = secure_filename(audio_file.filename)
507
- audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
508
- audio_file.save(audio_path)
509
- try:
510
- result = whisper_model.transcribe(audio_path)
511
- transcribed_text = result["text"]
512
- os.remove(audio_path)
513
- return jsonify({"transcribed_text": transcribed_text}), 200
514
- except Exception as e:
515
- logging.error(f"Transcription failed: {str(e)}")
516
- return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
517
-
518
- @app.route('/api/generate_summary', methods=['POST'])
519
- def generate_summary():
520
- """Generate a summary from the provided text."""
521
- data = request.json
522
- if "text" not in data or not data["text"].strip():
523
- return jsonify({"error": "No valid text provided"}), 400
524
- context = data["text"]
525
- clean_text = phi_scrubber_agent.scrub_phi(context)
526
- summary = summarizer_agent.generate_summary(clean_text)
527
- return jsonify({"summary": summary}), 200
528
-
529
- @app.route('/api/extract_medical_data_from_audio', methods=['POST'])
530
- def extract_medical_data_from_audio():
531
- """Extract medical data from transcribed audio."""
532
- if 'audio' not in request.files:
533
- abort(400, description="No audio file provided")
534
- audio_file = request.files['audio']
535
- if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
536
- abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
537
- logging.info(audio_file.filename)
538
- logging.info(app.config['UPLOAD_FOLDER'])
539
- filename = secure_filename(audio_file.filename)
540
- logging.info(filename)
541
- audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
542
- logging.info(audio_path)
543
- audio_file.save(audio_path)
544
- try:
545
- result = whisper_model.transcribe(audio_path)
546
- transcribed_text = result["text"]
547
- clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
548
- summary = summarizer_agent.generate_summary(transcribed_text)
549
- structured_data = medical_data_extractor_agent.extract_medical_data(transcribed_text)
550
- response = {
551
- "transcribed_text": transcribed_text,
552
- "summary": summary,
553
- "medical_chart": json.loads(structured_data)
554
- }
555
- os.remove(audio_path)
556
- return jsonify(response), 200
557
- except Exception as e:
558
- logging.error(f"Processing failed: {str(e)}")
559
- return jsonify({"error": f"Processing failed: {str(e)}"}), 500
560
-
561
- @app.route('/upload', methods=['POST'])
562
- def upload_file():
563
- files = request.files.getlist("file")
564
- patient_name = request.form.get("patient_name", "").strip()
565
- password = request.form.get("password")
566
- qa_model_name = request.form.get("qa_model_name")
567
- qa_model_type = request.form.get("qa_model_type")
568
- ner_model_name = request.form.get("ner_model_name")
569
- ner_model_type = request.form.get("ner_model_type")
570
- summarizer_model_name = request.form.get("summarizer_model_name")
571
- summarizer_model_type = request.form.get("summarizer_model_type")
572
-
573
- if not files:
574
- return jsonify({"error": "No file uploaded"}), 400
575
-
576
- try:
577
- qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
578
- print(f"✅ QA Model Loaded: {qa_model_name}")
579
- except Exception as e:
580
- return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
581
-
582
- try:
583
- ner_pipeline = pipeline(task=ner_model_type, model=ner_model_name)
584
- print(f"✅ NER Model Loaded: {ner_model_name}")
585
- except Exception as e:
586
- return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
587
-
588
- try:
589
- summarizer_pipeline = pipeline(task=summarizer_model_type, model=summarizer_model_name)
590
- print(f"✅ Summarizer Model Loaded: {summarizer_model_name}")
591
- except Exception as e:
592
- return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
593
-
594
- extracted_data = []
595
- print(patient_name)
596
-
597
- for file in files:
598
- if file.filename == '':
599
- continue
600
- if not allowed_file(file.filename):
601
- return jsonify({"error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
602
- if not patient_name:
603
- return jsonify({"error": "Patient name is missing"}), 400
604
- valid_size, error_message = check_file_size(file)
605
- if not valid_size:
606
- return jsonify({"error": error_message}), 400
607
-
608
- filename = secure_filename(file.filename)
609
- filepath = os.path.join(UPLOAD_FOLDER, filename)
610
- file.save(filepath)
611
-
612
- extracted_text = None
613
-
614
- if filename.endswith(".pdf"):
615
- result = extract_text_from_pdf(filepath, password)
616
- if isinstance(result, tuple):
617
- extracted_text, status_code = result
618
- else:
619
- extracted_text = result
620
- status_code = 200
621
- if isinstance(extracted_text, dict) and "error" in extracted_text:
622
- return jsonify(extracted_text), status_code
623
- elif filename.endswith(".docx"):
624
- extracted_text = extract_text_from_docx(filepath)
625
- elif filename.endswith((".jpg", ".jpeg", ".png", ".svg")):
626
- extracted_text = extract_text_from_image(filepath)
627
-
628
- if not extracted_text or extracted_text == "No text found":
629
- return jsonify({"error": f"Failed to extract text from {filename}"}), 415
630
- if extracted_text in ["Image is too blurry, OCR failed.", "OCR failed to extract meaningful text."]:
631
- return jsonify({"error": f"'{filename}' is too blurry or text is unreadable."}), 422
632
-
633
- skip_medical_check = request.form.get("skip_medical_check", "false").lower() == "true"
634
- if not skip_medical_check:
635
- ner_results = ner_pipeline(extracted_text)
636
- medical_entities = list(set([r["word"] for r in ner_results if r["entity"].startswith("B-") or r["entity"].startswith("I-")]))
637
- print(f"Medical entities found: {medical_entities}")
638
- if not medical_entities:
639
- return jsonify({"error": f"'{filename}' is not medically relevant"}), 406
640
- else:
641
- print(f"Skipping Medical Validation for {filename}")
642
-
643
- skip_patient_check = request.form.get("skip_patient_check", "false").lower() == "true"
644
- if not skip_patient_check:
645
- try:
646
- error_response = validate_patient_name(extracted_text, patient_name, filename, qa_pipeline)
647
- if error_response:
648
- return error_response
649
- except Exception as e:
650
- return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500
651
- else:
652
- print(f"Skipping Patient Name Validation for {filename}")
653
-
654
- try:
655
- summary = summarizer_pipeline(extracted_text, max_length=350, min_length=50, do_sample=False)[0]["summary_text"]
656
- except Exception as e:
657
- summary = "Summary failed"
658
- print(f"⚠️ Error summarizing: {e}")
659
-
660
- extracted_data.append({
661
- "file": filename,
662
- "extracted_text": extracted_text,
663
- "summary": summary,
664
- "message": "Successful"
665
- })
666
-
667
- extracted_text = None
668
- summary = None
669
-
670
- if not extracted_data:
671
- return jsonify({"error": "No valid medical files processed"}), 400
672
-
673
- return jsonify({"extracted_data": extracted_data}), 200
674
-
675
- @app.route('/extract_medical_data_questions', methods=['POST'])
676
- def extract_medical_data_questions():
677
- """Extract medical data based on predefined questions."""
678
- data = request.json
679
- qa_model_name = data.get("qa_model_name")
680
- qa_model_type = data.get("qa_model_type")
681
- if "extracted_data" not in data:
682
- return jsonify({"error": "Missing 'extracted_data' in request"}), 400
683
-
684
- if not qa_model_name or not qa_model_type:
685
- return jsonify({"error": "Missing 'model_name' or 'model_type'"}), 400
686
-
687
- try:
688
- print(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
689
- qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
690
- loaded_model_name = qa_pipeline.model.config._name_or_path
691
- loaded_model_type = qa_pipeline.task
692
- print(f"✅ Model loaded: {loaded_model_name}")
693
- except Exception as e:
694
- print("❌ Error loading model:", str(e))
695
- return jsonify({"error": f"Could not load model: {str(e)}"}), 500
696
-
697
- questions = {
698
- "Patient Name": "What is the patient's name?",
699
- "Age": "What is the patient's age?",
700
- "Gender": "What is the patient's gender?",
701
- "Date of Birth": "What is the patient's date of birth?",
702
- "Patient ID": "What is the patient ID?",
703
- "Reason for Visit": "What is the reason for the patient's visit?",
704
- "Physician": "Who is the physician in charge of the patient?",
705
- "Test Date": "What is the test date?",
706
- "Hemoglobin": "What is the patient's hemoglobin level?",
707
- "Blood Glucose (Fasting)": "What is the patient's fasting blood glucose level?",
708
- "Total Cholesterol": "What is the total cholesterol level?",
709
- "LDL Cholesterol": "What is the LDL cholesterol level?",
710
- "HDL Cholesterol": "What is the HDL cholesterol level?",
711
- "Serum Creatinine": "What is the serum creatinine level?",
712
- "Vitamin D (25-OH)": "What is the patient's Vitamin D level?",
713
- "Height": "What is the patient's height?",
714
- "Weight": "What is the patient's weight?",
715
- "Blood Pressure (Systolic)": "What is the patient's systolic blood pressure?",
716
- "Blood Pressure (Diastolic)": "What is the patient's diastolic blood pressure?",
717
- "Recommendations": "What are the recommendations based on the test results?"
718
- }
719
-
720
- structured_response = {"extracted_data": []}
721
-
722
- for file_data in data["extracted_data"]:
723
- filename = file_data["file"]
724
- context = file_data["extracted_text"]
725
-
726
- if not context:
727
- structured_response["extracted_data"].append({
728
- "file": filename,
729
- "medical_terms": "No data extracted",
730
- })
731
- continue
732
-
733
- extracted_info = {}
734
-
735
- for key, question in questions.items():
736
- try:
737
- result = qa_pipeline(question=question, context=context)
738
- extracted_info[key] = clean_result(result.get("answer", "Not Available"))
739
- except:
740
- extracted_info[key] = "Error extracting"
741
-
742
- categorized_data = [
743
- {
744
- "name": "Patient Information",
745
- "fields": [
746
- {"label": "Patient Name", "value": extracted_info.get("Patient Name", "")},
747
- {"label": "Date of Birth", "value": extracted_info.get("Date of Birth", "")},
748
- {"label": "Gender", "value": extracted_info.get("Gender", "")},
749
- {"label": "Patient ID", "value": extracted_info.get("Patient ID", "")}
750
- ]
751
- },
752
- {
753
- "name": "Vitals",
754
- "fields": [
755
- {"label": "Height", "value": extracted_info.get("Height", "")},
756
- {"label": "Weight", "value": extracted_info.get("Weight", "")},
757
- {"label": "Blood Pressure", "value": f"{extracted_info.get('Blood Pressure (Systolic)', '')}/{extracted_info.get('Blood Pressure (Diastolic)', '')} mmHg"},
758
- {"label": "Hemoglobin", "value": extracted_info.get("Hemoglobin", "")},
759
- {"label": "Serum Creatinine", "value": extracted_info.get("Serum Creatinine", "")}
760
- ]
761
- },
762
- {
763
- "name": "Lab Results",
764
- "fields": [
765
- {"label": "Blood Glucose (Fasting)", "value": extracted_info.get("Blood Glucose (Fasting)", "")},
766
- {"label": "Total Cholesterol", "value": extracted_info.get("Total Cholesterol", "")},
767
- {"label": "LDL Cholesterol", "value": extracted_info.get("LDL Cholesterol", "")},
768
- {"label": "HDL Cholesterol", "value": extracted_info.get("HDL Cholesterol", "")},
769
- {"label": "Vitamin D (25-OH)", "value": extracted_info.get("Vitamin D (25-OH)", "")}
770
- ]
771
- },
772
- {
773
- "name": "Medical Notes",
774
- "fields": [
775
- {"label": "Reason for Visit", "value": extracted_info.get("Reason for Visit", "")},
776
- {"label": "Physician", "value": extracted_info.get("Physician", "")},
777
- {"label": "Test Date", "value": extracted_info.get("Test Date", "")},
778
- {"label": "Recommendations", "value": extracted_info.get("Recommendations", "")}
779
- ]
780
- }
781
- ]
782
- structured_response["extracted_data"].append({
783
- "file": filename,
784
- "medical_terms": extracted_info,
785
- "categorized_data": categorized_data,
786
- "model_used": loaded_model_name,
787
- "model_type": loaded_model_type
788
- })
789
-
790
- save_data_to_storage(filename, structured_response)
791
- print(f"✅ Extracted data saved to: {os.path.join(UPLOAD_FOLDER, f'{filename}.json')}")
792
-
793
- return jsonify(structured_response), 200
794
-
795
- def get_data_from_storage(filename):
796
- try:
797
- filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
798
- print(f"🔍 Looking for file at: {filepath}")
799
- if not os.path.exists(filepath):
800
- print(f"🚫 File not found at: {filepath}")
801
- return None
802
- with open(filepath, "r") as file:
803
- data = json.load(file)
804
- print(f"✅ File found and loaded: {filepath}")
805
- return data
806
- except Exception as e:
807
- print(f"🚨 Error loading data: {e}")
808
- return None
809
-
810
- def save_data_to_storage(filename, data):
811
- try:
812
- filename = filename.rsplit(".", 1)[0]
813
- filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
814
- print(f"Saving to: {filepath}")
815
- print(f"Directory exists: {os.path.exists(UPLOAD_FOLDER)}")
816
- if not os.path.exists(UPLOAD_FOLDER):
817
- print(f"Directory not found. Creating: {UPLOAD_FOLDER}")
818
- os.makedirs(UPLOAD_FOLDER, exist_ok=True)
819
- with open(filepath, "w") as file:
820
- json.dump(data, file)
821
- print(f"✅ Data saved successfully to {filepath}")
822
- except Exception as e:
823
- print(f"🚨 Exception during save: {e}")
824
-
825
- @app.route('/get_updated_medical_data', methods=['GET'])
826
- def get_updated_data():
827
- file_name = request.args.get('file')
828
- if not file_name:
829
- return jsonify({"error": "File name is required"}), 400
830
- file_name = file_name.rsplit(".", 1)[0]
831
- updated_data = get_data_from_storage(file_name)
832
- if updated_data:
833
- return jsonify({"file": file_name, "data": updated_data}), 200
834
- else:
835
- return jsonify({"error": f"File '{file_name}' not found"}), 404
836
-
837
- @app.route('/update_medical_data', methods=['PUT'])
838
- def update_medical_data():
839
- try:
840
- data = request.json
841
- print("Received data:", data)
842
- filename = data.get("file")
843
- filename = filename.rsplit(".", 1)[0]
844
- updates = data.get("updates", [])
845
- if not filename or not updates:
846
- return jsonify({"error": "File name or updates missing"}), 400
847
- existing_data = get_data_from_storage(filename)
848
- if not existing_data:
849
- return jsonify({"error": f"File '{filename}' not found"}), 404
850
- for update in updates:
851
- category = update.get("category")
852
- field = update.get("field")
853
- new_value = update.get("value")
854
- updated = False
855
- for cat in existing_data.get("extracted_data", []):
856
- for categorized_data in cat.get("categorized_data", []):
857
- if categorized_data.get("name") == category:
858
- for fld in categorized_data.get("fields", []):
859
- if fld.get("label") == field:
860
- print(f"🔄 Updating {category} -> {field} from '{fld['value']}' to '{new_value}'")
861
- fld["value"] = new_value
862
- updated = True
863
- break
864
- if updated:
865
- break
866
- if updated:
867
- break
868
- save_data_to_storage(filename, existing_data)
869
- print("✅ Updated data:", existing_data)
870
- return jsonify({"message": "Data updated successfully", "updated_data": existing_data}), 200
871
- except Exception as e:
872
- print("❌ Error:", str(e))
873
- return jsonify({"error": str(e)}), 500
874
-
875
- @app.route('/')
876
- def home():
877
- return "Medical Data Extraction API is running!"
878
-
879
- if __name__ == '__main__':
880
- app.run(host='0.0.0.0', port=5000, debug=True)
document_based_extraction.py DELETED
@@ -1,1188 +0,0 @@
1
- import os, re, json
2
- import time, logging, functools
3
- import pytesseract
4
- import cv2
5
- import pdfplumber
6
- import numpy as np
7
- from PIL import Image
8
- from PyPDF2 import PdfReader
9
- from pdf2image import convert_from_path
10
- from flask import Flask, request, jsonify
11
- from flask_cors import CORS
12
- import torch
13
- from werkzeug.utils import secure_filename
14
- from docx import Document
15
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
16
- from concurrent.futures import ThreadPoolExecutor, as_completed
17
- from collections import defaultdict
18
- from huggingface_hub import login
19
-
20
-
21
- # -------------------- Logging Config -------------------- #
22
- logging.basicConfig(
23
- level=logging.INFO,
24
- format="%(asctime)s - %(levelname)s - %(message)s",
25
- handlers=[
26
- logging.FileHandler("app.log"),
27
- logging.StreamHandler()
28
- ]
29
- )
30
- logger = logging.getLogger(__name__)
31
-
32
- # -------------------- Execution Time Decorator -------------------- #
33
- def log_execution_time(level=logging.INFO):
34
- def decorator(func):
35
- @functools.wraps(func)
36
- def wrapper(*args, **kwargs):
37
- start_time = time.time()
38
- try:
39
- result = func(*args, **kwargs)
40
- duration = time.time() - start_time
41
- logger.log(level, f"⏱️ {func.__name__} executed in {duration:.6f} seconds")
42
- return result
43
- except Exception as e:
44
- duration = time.time() - start_time
45
- logger.exception(f"❌ Exception in {func.__name__} after {duration:.6f} seconds: {e}")
46
- raise
47
- return wrapper
48
- return decorator
49
-
50
-
51
- login(
52
- "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"
53
- ) # 🧠 This will store it and every model load will use it
54
-
55
- executor = ThreadPoolExecutor(max_workers=5)
56
- logger.info("Executor initialized with 5 workers")
57
-
58
- # Set Tesseract OCR Path
59
- # in Windows
60
- # pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
61
- # in Linux
62
- pytesseract.pytesseract.tesseract_cmd = "/usr/local/bin/tesseract"
63
-
64
- # Set up Flask app
65
- app = Flask(__name__)
66
- CORS(app)
67
-
68
- UPLOAD_FOLDER = "uploads"
69
- ALLOWED_EXTENSIONS = {"pdf", "jpg", "jpeg", "png", "svg", "docx", "doc"}
70
- app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER
71
-
72
- # Set file size limits
73
- MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024 # *1GB*
74
- MAX_SIZE_IMAGES = 500 * 1024 * 1024 # *500MB*
75
-
76
-
77
- # # Load ClinicalBERT Model for Classification
78
- # try:
79
- # zero_shot_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
80
- # print("✅ zero_shot_classifier Model Loaded Successfully")
81
- # except Exception as e:
82
- # zero_shot_classifier = None
83
- # print("❌ Error loading ClinicalBERT Model:", str(e))
84
-
85
-
86
- if not os.path.exists(UPLOAD_FOLDER):
87
- os.makedirs(UPLOAD_FOLDER, exist_ok=True)
88
-
89
- # NER to Detect medical info
90
- CONFIDENCE_THRESHOLD = 0.80
91
-
92
-
93
- @log_execution_time()
94
- def extract_medical_entities(text):
95
- if not text or not text.strip():
96
- return ["No medical entities found"]
97
- if ner_pipeline is None: # type: ignore
98
- logger.warning("NER model is not loaded, skipping entity extraction.")
99
- return ["No medical entities found"]
100
-
101
- ner_results = ner_pipeline(text) # type: ignore
102
- relevant_entities = {
103
- # Diseases & Symptoms
104
- "Disease",
105
- "MedicalCondition",
106
- "Symptom",
107
- "Sign_or_Symptom",
108
- "B-DISEASE",
109
- "I-DISEASE",
110
- # Tests, Measurements, and Lab Values
111
- "Test",
112
- "Measurement",
113
- "B-TEST",
114
- "I-TEST",
115
- "Lab_value",
116
- "B-Lab_value",
117
- "I-Lab_value",
118
- # Medications, Treatments, and Procedures
119
- "Medication",
120
- "B-MEDICATION",
121
- "I-MEDICATION",
122
- "Treatment",
123
- "Procedure",
124
- "B-Diagnostic_procedure",
125
- "I-Diagnostic_procedure",
126
- # Body Parts & Medical Anatomy
127
- "Anatomical_site",
128
- "Body_Part",
129
- "Organ_or_Tissue",
130
- # Medical Procedures
131
- "Diagnostic_procedure",
132
- "Surgical_Procedure",
133
- "Therapeutic_Procedure",
134
- # Clinical Terms
135
- "Health_condition",
136
- "B-Health_condition",
137
- "I-Health_condition",
138
- "Pathological_Condition",
139
- "Clinical_Event",
140
- # Biological & Chemical Substances (Relevant to Lab Reports)
141
- "Chemical_Substance",
142
- "B-Chemical_Substance",
143
- "I-Chemical_Substance",
144
- "Biological_Entity",
145
- "B-Biological_Entity",
146
- "I-Biological_Entity",
147
- }
148
-
149
- medical_entities = set()
150
- for ent in ner_results:
151
- entity_label = ent.get("entity_group") or ent.get("entity")
152
- if entity_label in relevant_entities and ent["score"] >= CONFIDENCE_THRESHOLD:
153
- word = ent["word"].lower().strip().replace("-", "") # Normalize text
154
- if len(word) > 2: # Ignore short/junk words
155
- medical_entities.add(word)
156
-
157
- if len(medical_entities) >= 5:
158
- logger.info(f"Extracted {len(medical_entities)} medical entities")
159
- return list(medical_entities)
160
-
161
- logger.info("Not enough medical entities found")
162
- return ["No medical entities found"]
163
-
164
-
165
- # Validation: Check Allowed File Types
166
- def allowed_file(filename):
167
- return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
168
-
169
-
170
- # Validation: Check File Size
171
- def check_file_size(file):
172
- file.seek(0, os.SEEK_END)
173
- size = file.tell()
174
- file.seek(0)
175
- extension = file.filename.rsplit(".", 1)[-1].lower()
176
- logger.info(f"Checking file size for '{file.filename}' - Size: {size} bytes")
177
- if extension in {"pdf", "docx"} and size > MAX_SIZE_PDF_DOCS:
178
- logger.warning(f"{file.filename} exceeds 1GB limit")
179
- return False, f"File {file.filename} exceeds 1MB size limit"
180
- elif extension in {"jpg", "jpeg", "png"} and size > MAX_SIZE_IMAGES:
181
- logger.warning(f"{file.filename} exceeds 500MB image limit")
182
- return False, f"Image {file.filename} exceeds 500KB size limit"
183
- return True, None
184
-
185
-
186
- @log_execution_time()
187
- def extract_patient_name(text, qa_pipeline):
188
- if not text or not qa_pipeline:
189
- return None
190
- try:
191
- result = qa_pipeline(question="What is the patient's name?", context=text)
192
- answer = result.get("answer", "").strip()
193
- logger.info(f"Extracted patient name: {answer}")
194
- return answer
195
- except Exception as e:
196
- logger.error(f"Error extracting patient name: {e}")
197
- return None
198
-
199
-
200
- def normalize_name(name):
201
- """Cleans and normalizes names for comparison, removing salutations dynamically"""
202
- if not name:
203
- return ""
204
- name = name.lower().strip()
205
- name = re.sub(r"[^\w\s]", "", name)
206
- name = re.sub(r"^\b\w{1,5}\b\s+", "", name) # Matches short words at the start
207
- return name
208
-
209
-
210
- @log_execution_time()
211
- def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
212
- """Validates if the extracted name matches the registered patient name"""
213
- detected_name = extract_patient_name(extracted_text, qa_pipeline)
214
- if not detected_name:
215
- logger.warning(f"Could not determine patient name from {filename}")
216
- return (
217
- jsonify({"error": f"Could not determine patient name from {filename}"}),
218
- 400,
219
- )
220
-
221
- normalized_detected_name = normalize_name(detected_name)
222
- normalized_patient_name = normalize_name(patient_name)
223
-
224
- if normalized_detected_name not in normalized_patient_name:
225
- logger.warning(
226
- f"Patient mismatch in file '{filename}': Found '{detected_name}'"
227
- )
228
- return (
229
- jsonify(
230
- {
231
- "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
232
- }
233
- ),
234
- 400,
235
- )
236
- logger.info(f"Patient name validation passed for '{filename}'")
237
- return None # No error, validation passed
238
-
239
-
240
- # Check if the image is blurred using the Laplacian method
241
- def is_blurred(image_path, variance_threshold=150):
242
- try:
243
- image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
244
- if image is None:
245
- logger.error(f"Unable to read image: {image_path}")
246
- return True # Assume it's blurry if not readable
247
-
248
- # Compute Laplacian variance
249
- laplacian_var = cv2.Laplacian(image, cv2.CV_64F).var()
250
- logger.info(
251
- f"Blur Check on '{image_path}': Laplacian Variance = {laplacian_var:.2f} (Threshold = {variance_threshold})"
252
- )
253
-
254
- # Compute Edge Density (Additional Check)
255
- edges = cv2.Canny(image, 50, 150)
256
- edge_density = np.mean(edges)
257
- logger.info(f"Edge Density for '{image_path}': {edge_density:.2f}")
258
-
259
- is_blurry = laplacian_var < variance_threshold and edge_density < 10
260
- if is_blurry:
261
- logger.warning(f"Image '{image_path}' flagged as blurry.")
262
- return is_blurry
263
- except Exception as e:
264
- logger.exception(f"Exception during blur detection for '{image_path}': {e}")
265
- return True # Assume it's blurry on failure
266
-
267
-
268
- # Helper Function: Extract Text from Images (OCR) with Blur Detection
269
- @log_execution_time()
270
- def extract_text_from_image(filepath):
271
- try:
272
- # Check if the image is blurry
273
- if is_blurred(filepath):
274
- logger.warning(f"OCR skipped: '{filepath}' is too blurry.")
275
- return "Image is too blurry, OCR failed."
276
-
277
- image = cv2.imread(filepath)
278
- if image is None:
279
- logger.error(f"OCR failed: Unable to read image '{filepath}'.")
280
- return "Image could not be read."
281
-
282
- # Convert to Grayscale and Apply Thresholding
283
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
284
- gray = cv2.GaussianBlur(gray, (5, 5), 0)
285
- gray = cv2.adaptiveThreshold(
286
- gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
287
- )
288
-
289
- # Apply dilation (bolds the text) for better OCR accuracy
290
- kernel = np.ones((2, 2), np.uint8)
291
- gray = cv2.dilate(gray, kernel, iterations=1)
292
- processed_path = f"{filepath}_processed.png"
293
- cv2.imwrite(processed_path, gray)
294
- logger.info(f"Image preprocessed and saved: {processed_path}")
295
- text = pytesseract.image_to_string(
296
- Image.open(processed_path), lang="eng"
297
- ).strip()
298
- # Validate OCR output (Reject if too little text is extracted)
299
- word_count = len(text.split())
300
- logger.info(
301
- f"OCR completed for '{filepath}' with {word_count} words extracted."
302
- )
303
-
304
- if word_count < 5:
305
- logger.warning(f"OCR output too small for '{filepath}'. Might be junk.")
306
- return "OCR failed to extract meaningful text."
307
-
308
- return text
309
-
310
- except Exception as e:
311
- logger.exception(f"Error extracting text from image '{filepath}': {e}")
312
- return "Failed to extract text"
313
-
314
-
315
- # Helper Function: Extract Text from PDF
316
- @log_execution_time()
317
- def extract_text_from_pdf(filepath, password=None):
318
- """Extract text from PDFs using pdfplumber (faster) or OCR (if needed)."""
319
- text = ""
320
-
321
- try:
322
- logger.info(f"Starting PDF extraction: {filepath}")
323
- reader = PdfReader(filepath)
324
-
325
- if reader.is_encrypted:
326
- if not password:
327
- logger.warning("Encrypted PDF without password.")
328
- return {
329
- "error": "File is password-protected. Please provide a password."
330
- }, 401
331
-
332
- # ✅ Attempt to decrypt
333
- decryption_result = reader.decrypt(password)
334
- if decryption_result == 0: # Decryption failed
335
- logger.error("Incorrect password provided.")
336
- return {"error": "Invalid password provided."}, 403
337
- else:
338
- logger.info("PDF decryption successful.")
339
-
340
- text = "\n".join([page.extract_text() or "" for page in reader.pages])
341
- if text.strip():
342
- logger.info("Text extracted from decrypted PDF.")
343
- return text.strip(), 200
344
-
345
- # ✅ Now, use pdfplumber for text extraction
346
- with pdfplumber.open(filepath) as pdf:
347
- for page in pdf.pages:
348
- page_text = page.extract_text()
349
- if page_text:
350
- text += page_text + "\n"
351
-
352
- if text.strip():
353
- logger.info(
354
- f"PDF text extracted using pdfplumber: {len(text.split())} words."
355
- )
356
- return text.strip(), 200 # ✅ Always return a tuple (text, status)
357
-
358
- logger.info("No text found via pdfplumber. Falling back to OCR.")
359
- # ✅ Use OCR if the PDF has no selectable text
360
- images = convert_from_path(filepath)
361
- with ThreadPoolExecutor(max_workers=5) as pool:
362
- ocr_text = list(
363
- pool.map(
364
- lambda img: pytesseract.image_to_string(img, lang="eng"), images
365
- )
366
- )
367
-
368
- full_ocr_text = "\n".join(ocr_text).strip()
369
- logger.info(
370
- f"OCR fallback complete for PDF: {len(full_ocr_text.split())} words extracted."
371
- )
372
-
373
- return (full_ocr_text, 200) if full_ocr_text else ("No text found", 415)
374
-
375
- except Exception as e:
376
- logger.exception(f"Error during PDF processing: {filepath}")
377
- return "Failed to extract text"
378
-
379
-
380
- # Helper Function: Extract Text from DOCX
381
- @log_execution_time()
382
- def extract_text_from_docx(filepath):
383
- try:
384
- doc = Document(filepath)
385
- text = "\n".join([para.text for para in doc.paragraphs])
386
- word_count = len(text.split())
387
- logger.info(f"DOCX extracted from '{filepath}': {word_count} words.")
388
- return text.strip() or None
389
- except Exception as e:
390
- logger.exception(f"Failed to extract text from DOCX: {filepath}")
391
- return None
392
-
393
-
394
- # Masking function to hide sensitive data
395
- def mask_sensitive_info(text):
396
- text = re.sub(r"(?<=\b\w{2})\w+(?=\s\w{2,})", "*", text) # Mask names
397
- text = re.sub(
398
- r"\b(\d{2})\d{2}-(\d{2})\d{2}-(\d{2})\d{2}\b", r"\2-\3-", text
399
- ) # Mask DOB
400
- text = re.sub(r"\b(\d{8})(\d{2})\b", r"\2", text) # Mask phone numbers
401
- return text
402
-
403
-
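An illustrative call for the masking heuristics above. The exact output depends on the surrounding text, since the three regex substitutions interact, so the sample string and expectations here are illustrative only:

```python
sample = "John Doe, DOB 1985-0615-1990, phone 9876543210"
print(mask_sensitive_info(sample))
# Name words lose everything after their first two letters ("Jo*"),
# and matching digit groups collapse to partial fragments.
```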
404
- # ------------------Upload Documents ------------------ #
405
- # API Route: Upload File & Extract Text
406
- @app.route("/upload", methods=["POST"])
407
- @log_execution_time()
408
- def upload_file():
409
- logger.info("📥 Upload request received")
410
- files = request.files.getlist("file")
411
- patient_name = request.form.get("patient_name", "").strip()
412
- password = request.form.get("password") # Get password if provided
413
- # Dynamic model info from form
414
- qa_model_name = request.form.get("qa_model_name")
415
- qa_model_type = request.form.get("qa_model_type")
416
-
417
- ner_model_name = request.form.get("ner_model_name")
418
- ner_model_type = request.form.get("ner_model_type")
419
-
420
- summarizer_model_name = request.form.get("summarizer_model_name")
421
- summarizer_model_type = request.form.get("summarizer_model_type")
422
-
423
- if not files:
424
- logger.warning("No file uploaded")
425
- return jsonify({"error": "No file uploaded"}), 400
426
-
427
- # 🔌 Load models dynamically
428
- try:
429
- qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
430
- logger.info(f"✅ QA model loaded: {qa_model_name}")
431
- except Exception as e:
432
- logger.error(f"❌ QA model load failed: {e}")
433
- return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
434
-
435
- try:
436
- ner_pipeline = pipeline(task=ner_model_type, model=ner_model_name)
437
- logger.info(f"✅ NER model loaded: {ner_model_name}")
438
- except Exception as e:
439
- logger.error(f"❌ NER model load failed: {e}")
440
- return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
441
-
442
- try:
443
- summarizer_pipeline = pipeline(
444
- task=summarizer_model_type, model=summarizer_model_name
445
- )
446
- logger.info(f"✅ Summarizer model loaded: {summarizer_model_name}")
447
- except Exception as e:
448
- logger.error(f"❌ Summarizer model load failed: {e}")
449
- return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
450
-
451
- extracted_data = []
452
- logger.info(f"Patient name received: {patient_name}")
453
-
454
- for file in files:
455
- logger.info(f"📂 Processing file: {file.filename}")
456
-
457
- if file.filename == "":
458
- logger.warning("Skipping unnamed file")
459
- continue # Skip empty file names
460
-
461
- if not allowed_file(file.filename):
462
- logger.warning(f"Unsupported file type: {file.filename}")
463
- return (
464
- jsonify(
465
- {
466
- "error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"
467
- }
468
- ),
469
- 400,
470
- )
471
-
472
- if not patient_name:
473
- logger.warning("Patient name missing")
474
- return jsonify({"error": "Patient name is missing"}), 400
475
-
476
- # *Check file size*
477
- valid_size, error_message = check_file_size(file)
478
- if not valid_size:
479
- logger.warning(f"❌ File size validation failed: {error_message}")
480
- return jsonify({"error": error_message}), 400
481
-
482
- filename = secure_filename(file.filename)
483
- filepath = os.path.join(UPLOAD_FOLDER, filename)
484
- file.save(filepath)
485
- logger.info(f"✅ File saved: {filepath}")
486
-
487
- extracted_text = None
488
-
489
- # ✅ *Extract text based on file type*
490
- if filename.endswith(".pdf"):
491
- logger.info("🧾 Extracting text from PDF")
492
- result = extract_text_from_pdf(filepath, password)
493
-
494
- # ✅ If PDF requires a password, return 401
495
- if isinstance(result, tuple):
496
- extracted_text, status_code = result
497
- else:
498
- extracted_text = result
499
- status_code = 200
500
-
501
- if isinstance(extracted_text, dict) and "error" in extracted_text:
502
- logger.warning(f"⚠️ PDF extraction error: {extracted_text}")
503
- return jsonify(extracted_text), status_code
504
- elif filename.endswith(".docx"):
505
- extracted_text = extract_text_from_docx(filepath)
506
- elif filename.endswith((".jpg", ".jpeg", ".png", ".svg")):
507
- logger.info("🖼️ Extracting text from image")
508
- extracted_text = extract_text_from_image(filepath)
509
-
510
- if not extracted_text or extracted_text == "No text found":
511
- logger.warning(f"⚠️ No text extracted from {filename}")
512
- return (
513
- jsonify({"error": f"Failed to extract text from {filename}"}),
514
- 415,
515
- ) # Unsupported Media Type
516
-
517
- # reject blurred images
518
- if extracted_text in [
519
- "Image is too blurry, OCR failed.",
520
- "OCR failed to extract meaningful text.",
521
- ]:
522
- logger.warning(f"🔍 OCR failed or image too blurry: {filename}")
523
- return (
524
- jsonify(
525
- {"error": f"'{filename}' is too blurry or text is unreadable."}
526
- ),
527
- 422,
528
- ) # Unprocessable Entity
529
-
530
- # ✅ Medical Validation using NER
531
- skip_medical_check = (
532
- request.form.get("skip_medical_check", "false").lower() == "true"
533
- )
534
- if not skip_medical_check:
535
- logger.info("🧠 Running NER medical validation")
536
- start_time = time.time()
537
- ner_results = ner_pipeline(extracted_text)
538
- medical_entities = list(
539
- set(
540
- [
541
- r["word"]
542
- for r in ner_results
543
- if r["entity"].startswith("B-") or r["entity"].startswith("I-")
544
- ]
545
- )
546
- )
547
- elapsed_time = time.time() - start_time
548
- logger.info(f"⏱️ Medical entity validation took {elapsed_time:.2f}s")
549
-
550
- logger.info(f"🩺 Medical entities found: {medical_entities}")
551
- if not medical_entities:
552
- logger.warning(f"❌ No medical relevance in {filename}")
553
- return (
554
- jsonify({"error": f"'{filename}' is not medically relevant"}),
555
- 406,
556
- )
557
- else:
558
- logger.info(f"⏭️ Skipping medical validation for {filename}")
559
-
560
- # # ✅ Patient Name Validation using QA
561
- # skip_patient_check = request.form.get("skip_patient_check", "false").lower() == "true"
562
- # if not skip_patient_check:
563
- # try:
564
- # logger.info("🧍 Validating patient name")
565
- # start_time = time.time()
566
- # error_response = validate_patient_name(extracted_text, patient_name, filename,qa_pipeline)
567
- # elapsed_time = time.time() - start_time
568
- # logger.info(f"⏱️ Patient name validation took {elapsed_time:.2f}s")
569
-
570
- # if error_response:
571
- # return error_response
572
- # except Exception as e:
573
- # logger.error(f"❌ Patient name validation failed: {e}")
574
- # return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500
575
- # else:
576
- # logger.info(f"⏭️ Skipping patient name validation for {filename}")
577
-
578
- # ✨ Generate Summary using Summarizer
579
- try:
580
- logger.info("📝 Generating summary: %s", extracted_text)
581
-
582
- start_time = time.time()
583
- summary = summarizer_pipeline(
584
- extracted_text, max_length=350, min_length=50, do_sample=False
585
- )[0]["summary_text"]
586
- elapsed_time = time.time() - start_time
587
-
588
- logger.info(f"✅ Summary generated: {summary}")
589
- logger.info(f"⏱️ Summary generation took {elapsed_time:.2f} seconds")
590
- except Exception as e:
591
- summary = "Summary failed"
592
- logger.warning(f"⚠ Summary generation failed: {e}")
593
- # # Classify report type
594
- # report_type = classify_medical_document(extracted_text)
595
- # print(report_type)
596
- # ✅ Summarize extracted text
597
- extracted_data.append(
598
- {
599
- "file": filename,
600
- # "document_type": report_type,
601
- "extracted_text": extracted_text,
602
- "summary": summary,
603
- "message": "Successful",
604
- }
605
- )
606
- logger.info(f"✅ Finished processing file: {filename}")
607
-
608
- if not extracted_data:
609
- logger.warning("❌ No valid medical files processed")
610
- return jsonify({"error": "No valid medical files processed"}), 400
611
-
612
- logger.info("📦 Upload processing completed successfully")
613
- return jsonify({"extracted_data": extracted_data}), 200
614
-
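A hedged client sketch for the `/upload` route. The host, file name, and model choices below are assumptions for illustration, not values taken from this repository; any Hub models compatible with the three pipeline tasks should work:

```python
import requests

resp = requests.post(
    "http://localhost:5000/upload",
    files=[("file", open("lab_report.pdf", "rb"))],  # hypothetical document
    data={
        "patient_name": "John Doe",
        "qa_model_name": "distilbert-base-cased-distilled-squad",
        "qa_model_type": "question-answering",
        "ner_model_name": "d4data/biomedical-ner-all",
        "ner_model_type": "ner",
        "summarizer_model_name": "facebook/bart-large-cnn",
        "summarizer_model_type": "summarization",
    },
)
print(resp.status_code, resp.json())  # {"extracted_data": [...]} on success
```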
615
-
616
- # # API Route: Extract Medical Data Based on Predefined Questions
617
- # @app.route('/extract_medical_data', methods=['POST'])
618
- # def extract_medical_data():
619
- # data = request.json
620
- # print(f"📥 Incoming request data: {data}")
621
-
622
- # qa_model_name = data.get("qa_model_name")
623
- # qa_model_type = data.get("qa_model_type")
624
-
625
- # if "extracted_data" not in data:
626
- # return jsonify({"error": "Missing 'extracted_data' in request"}), 400
627
-
628
- # if not qa_model_name or not qa_model_type:
629
- # return jsonify({"error": "Missing 'model_name' or 'model_type'"}), 400
630
-
631
- # try:
632
- # print(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
633
- # qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
634
- # print(f"✅ Model loaded: {qa_pipeline.model.config._name_or_path}")
635
- # except Exception as e:
636
- # print("❌ Error loading model:", str(e))
637
- # return jsonify({"error": f"Could not load model: {str(e)}"}), 500
638
-
639
- # questions = {
640
- # "Patient Name": "What is the patient's name?",
641
- # "Age": "What is the patient's age?",
642
- # "Gender": "What is the patient's gender?",
643
- # "Date of Birth": "What is the patient's date of birth?",
644
- # "Patient ID": "What is the patient ID?",
645
- # "Reason for Visit": "What is the reason for the patient's visit?",
646
- # "Physician": "Who is the physician in charge of the patient?",
647
- # "Test Date": "What is the test date?",
648
- # "Hemoglobin": "What is the patient's hemoglobin level?",
649
- # "Blood Glucose (Fasting)": "What is the patient's fasting blood glucose level?",
650
- # "Total Cholesterol": "What is the total cholesterol level?",
651
- # "LDL Cholesterol": "What is the LDL cholesterol level?",
652
- # "HDL Cholesterol": "What is the HDL cholesterol level?",
653
- # "Serum Creatinine": "What is the serum creatinine level?",
654
- # "Vitamin D (25-OH)": "What is the patient's Vitamin D level?",
655
- # "Height": "What is the patient's height?",
656
- # "Weight": "What is the patient's weight?",
657
- # "Blood Pressure (Systolic)": "What is the patient's systolic blood pressure?",
658
- # "Blood Pressure (Diastolic)": "What is the patient's diastolic blood pressure?",
659
- # "Recommendations": "What are the recommendations based on the test results?"
660
- # }
661
-
662
- # structured_response = {"extracted_data": []}
663
-
664
- # for file_data in data["extracted_data"]:
665
- # filename = file_data["file"]
666
- # context = file_data["extracted_text"]
667
-
668
- # if not context:
669
- # structured_response["extracted_data"].append({
670
- # "file": filename,
671
- # "medical_terms": "No data extracted"
672
- # })
673
- # continue
674
-
675
- # # Prepare batch QA input
676
- # qa_inputs = [
677
- # {"question": q, "context": context}
678
- # for q in questions.values()
679
- # ]
680
-
681
- # try:
682
- # qa_outputs = qa_pipeline(qa_inputs)
683
- # print("📤 Batch QA outputs:", qa_outputs)
684
- # except Exception as e:
685
- # print("⚠️ Batch failed, falling back to loop:", str(e))
686
- # qa_outputs = [qa_pipeline(q) for q in qa_inputs]
687
-
688
- # # Map answers back to questions
689
- # extracted_info = {}
690
- # for i, key in enumerate(questions.keys()):
691
- # answer = qa_outputs[i].get("answer", "").strip()
692
- # score = qa_outputs[i].get("score", 0.0)
693
-
694
- # # If the model returns an empty string or very low confidence, mark as "Not Mentioned"
695
- # if not answer or score < 0.1:
696
- # extracted_info[key] = "Not Mentioned"
697
- # else:
698
- # extracted_info[key] = answer
699
-
700
- # # Optional: Clean results
701
- # # extracted_info = {k: clean_result(v) for k, v in extracted_info.items()}
702
-
703
- # categorized_data = [
704
- # {
705
- # "name": "Patient Information",
706
- # "fields": [
707
- # {"label": "Patient Name", "value": extracted_info.get("Patient Name", "")},
708
- # {"label": "Date of Birth", "value": extracted_info.get("Date of Birth", "")},
709
- # {"label": "Gender", "value": extracted_info.get("Gender", "")},
710
- # {"label": "Patient ID", "value": extracted_info.get("Patient ID", "")}
711
- # ]
712
- # },
713
- # {
714
- # "name": "Vitals",
715
- # "fields": [
716
- # {"label": "Height", "value": extracted_info.get("Height", "")},
717
- # {"label": "Weight", "value": extracted_info.get("Weight", "")},
718
- # {"label": "Blood Pressure", "value": f"{extracted_info.get('Blood Pressure (Systolic)', '')}/{extracted_info.get('Blood Pressure (Diastolic)', '')} mmHg"},
719
- # {"label": "Hemoglobin", "value": extracted_info.get("Hemoglobin", "")},
720
- # {"label": "Serum Creatinine", "value": extracted_info.get("Serum Creatinine", "")}
721
- # ]
722
- # },
723
- # {
724
- # "name": "Lab Results",
725
- # "fields": [
726
- # {"label": "Blood Glucose (Fasting)", "value": extracted_info.get("Blood Glucose (Fasting)", "")},
727
- # {"label": "Total Cholesterol", "value": extracted_info.get("Total Cholesterol", "")},
728
- # {"label": "LDL Cholesterol", "value": extracted_info.get("LDL Cholesterol", "")},
729
- # {"label": "HDL Cholesterol", "value": extracted_info.get("HDL Cholesterol", "")},
730
- # {"label": "Vitamin D (25-OH)", "value": extracted_info.get("Vitamin D (25-OH)", "")}
731
- # ]
732
- # },
733
- # {
734
- # "name": "Medical Notes",
735
- # "fields": [
736
- # {"label": "Reason for Visit", "value": extracted_info.get("Reason for Visit", "")},
737
- # {"label": "Physician", "value": extracted_info.get("Physician", "")},
738
- # {"label": "Test Date", "value": extracted_info.get("Test Date", "")},
739
- # {"label": "Recommendations", "value": extracted_info.get("Recommendations", "")}
740
- # ]
741
- # }
742
- # ]
743
-
744
- # structured_response["extracted_data"].append({
745
- # "file": filename,
746
- # "medical_terms": extracted_info,
747
- # "categorized_data": categorized_data
748
- # })
749
-
750
- # save_data_to_storage(filename, structured_response)
751
- # print(f"✅ Extracted data saved to: {os.path.join(UPLOAD_FOLDER, f'{filename}.json')}")
752
-
753
- # return jsonify(structured_response)
754
-
755
-
756
- # ------------------ CLEAN FUNCTION ------------------ #
757
- @log_execution_time()
758
- def clean_result(value):
759
- logger.debug("Cleaning value: %s", value)
760
- if isinstance(value, str):
761
- value = re.sub(r"\s+", " ", value)
762
- value = re.sub(r"[-_:]+", " ", value)
763
- value = re.sub(r"[^\x00-\x7F]+", " ", value)
764
- value = re.sub(
765
- r"(?<=\d),(?=\d)", "", value
766
- ) # Remove commas in numbers like 250,000
767
- return value.strip() if value.strip() else "Not Available"
768
- elif isinstance(value, list):
769
- cleaned = [clean_result(v) for v in value if v is not None]
770
- return cleaned if cleaned else ["Not Available"]
771
- elif isinstance(value, dict):
772
- return {k: clean_result(v) for k, v in value.items()}
773
- return value
774
-
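Illustrative calls for `clean_result`; outputs are described loosely because the substitutions interact:

```python
print(clean_result("Hb:  13.5  g/dL"))    # runs of whitespace and separator chars normalized
print(clean_result("250,000"))             # comma inside the number removed -> "250000"
print(clean_result(["", None, "stable"]))  # None dropped; empty strings become "Not Available"
```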
775
- # ------------------Group by Category ------------------ #
776
- @log_execution_time()
777
- def group_by_category(data):
778
- logger.info("Grouping extracted items by category")
779
- grouped = defaultdict(list)
780
- category_times = {}
781
-
782
- for item in data:
783
- cat = item.get("category", "General")
784
- start_time = time.time()
785
- grouped[cat].append(
786
- {
787
- "question": item.get("question", "Not Created"),
788
- "label": item.get("label", "Unknown"),
789
- "answer": item.get("answer", "Not Available"),
790
- }
791
- )
792
- elapsed = time.time() - start_time
793
- category_times[cat] = category_times.get(cat, 0) + elapsed
794
-
795
- for cat, details in grouped.items():
796
- logger.info(f"📂 Category '{cat}': {len(details)} items, time taken: {category_times[cat]:.4f}s")
797
-
798
- return [{"category": k, "detail": v} for k, v in grouped.items()]
799
-
800
-
801
- # ------------------detect duplicate to remove it ------------------ #
802
- @log_execution_time()
803
- def deduplicate_extractions(data):
804
- logger.info("Deduplicating extracted data")
805
- seen = set()
806
- unique = []
807
- for item in data:
808
- # Deduplicate on the 'label' field; extend to a tuple of fields if stricter matching is needed
809
- key = item.get("label")
810
- if key not in seen:
811
- seen.add(key)
812
- unique.append(item)
813
- return unique
814
-
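A short demo of `deduplicate_extractions` and `group_by_category` together, on the object shape the extraction prompt produces:

```python
items = [
    {"label": "heart rate", "question": "What is the heart rate?", "answer": "78 bpm", "category": "Vitals"},
    {"label": "heart rate", "question": "What is the heart rate?", "answer": "78 bpm", "category": "Vitals"},
    {"label": "diagnosis", "question": "What is the diagnosis?", "answer": "Hypertension", "category": "Diagnosis"},
]
unique = deduplicate_extractions(items)  # second "heart rate" entry is dropped
grouped = group_by_category(unique)
# -> [{"category": "Vitals", "detail": [...]}, {"category": "Diagnosis", "detail": [...]}]
```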
815
-
816
- # Load tokenizer outside the route for performance
817
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
818
-
819
-
820
- # -----------------------------Split text into overlapping chunks---------------#
821
- @log_execution_time()
822
- def chunk_text(text, tokenizer, max_tokens=512, overlap=50):
823
- """
824
- Splits text into overlapping token-based chunks without using NLTK.
825
-
826
- Args:
827
- text (str): Raw input text.
828
- tokenizer (transformers tokenizer): Hugging Face tokenizer instance.
829
- max_tokens (int): Max tokens per chunk.
830
- overlap (int): Number of overlapping tokens between chunks.
831
-
832
- Returns:
833
- List[str]: List of decoded text chunks.
834
- """
835
- # Tokenize the full text
836
- logger.info("Splitting text into chunks")
837
- input_ids = tokenizer.encode(text, add_special_tokens=False)
838
- chunks = []
839
- start = 0
840
- while start < len(input_ids):
841
- end = start + max_tokens
842
- chunk_ids = input_ids[start:end]
843
- chunk_str = tokenizer.decode(chunk_ids, skip_special_tokens=True)
844
- # Ensure partial continuation isn't cut off mid-sentence
845
- if not chunk_str.endswith(('.', '?', '!', ':')):
846
- chunk_str += "..."
847
-
848
- chunks.append(chunk_str)
849
- start += max_tokens - overlap
850
- logger.info("Created %d chunks", len(chunks))
851
- return chunks
852
-
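With the defaults `max_tokens=512` and `overlap=50`, the window advances by 512 - 50 = 462 tokens per step, so consecutive chunks share 50 tokens of context. A hypothetical 1,000-token document therefore yields three chunks covering tokens [0, 512), [462, 974), and [924, 1000):

```python
chunks = chunk_text(long_report_text, tokenizer)  # long_report_text is hypothetical
print(len(chunks))  # 3 for a 1,000-token input with the default settings
```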
853
- # ------------------ PARSE JSON OBJECTS FROM OUTPUT ------------------ #
854
- @log_execution_time()
855
- def extract_json_objects(text):
856
- logger.info("Extracting JSON objects from text")
857
- extracted = []
858
- try:
859
- json_start = text.index('[')
860
- json_text = text[json_start:]
861
- except ValueError:
862
- logger.warning("⚠ '[' not found in output")
863
- return []
864
-
865
- # Try parsing full array first
866
- try:
867
- parsed = json.loads(json_text)
868
- if isinstance(parsed, list):
869
- return parsed
870
- except Exception:
871
- pass # fallback to manual parsing
872
-
873
- # Manual recovery via brace matching
874
- stack = 0
875
- obj_start = None
876
- for i, char in enumerate(json_text):
877
- if char == '{':
878
- if stack == 0:
879
- obj_start = i
880
- stack += 1
881
- elif char == '}':
882
- stack -= 1
883
- if stack == 0 and obj_start is not None:
884
- obj_str = json_text[obj_start:i+1]
885
- try:
886
- obj = json.loads(obj_str)
887
- extracted.append(obj)
888
- except Exception as e:
889
- logger.error(f"❌ Invalid JSON object: {e}")
890
- obj_start = None
891
-
892
- return extracted
893
-
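`extract_json_objects` tolerates chatter before the array and damage inside it: if `json.loads` fails on the whole array, the brace-matching fallback still recovers every complete object. For example:

```python
noisy = 'Sure! Here is the data: [{"label": "bp", "answer": "120/80"},, {"label": "hr"'
print(extract_json_objects(noisy))
# -> [{'label': 'bp', 'answer': '120/80'}]  (the truncated second object is skipped)
```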
894
-
895
- # ------------------ PROCESS A SINGLE CHUNK ------------------ #
896
- @log_execution_time()
897
- def process_chunk(generator, chunk, idx):
898
- logger.info("Processing chunk %d", idx + 1)
899
- prompt = f"""
900
- [INST] <<SYS>>
901
- You are a clinical data extraction assistant.
902
-
903
- Your job is to:
904
- 1. Read the following medical report.
905
- 2. Extract all medically relevant facts as a list of JSON objects.
906
- 3. Each object must include:
907
- - "label": a short field name (e.g., "blood pressure", "diagnosis")
908
- - "question": a question related to that field
909
- - "answer": the answer from the text
910
- 4. After extracting the list, categorize each object under one of the following fixed categories:
911
-
912
- - Patient Info
913
- - Vitals
914
- - Symptoms
915
- - Allergies
916
- - Habits
917
- - Comorbidities
918
- - Diagnosis
919
- - Medication
920
- - Laboratory
921
- - Radiology
922
- - Doctor Note
923
-
924
- Example format for structure only — do not include in output:
925
- [
926
- {{
927
- "label": "patient name",
928
- "question": "What is the patient's name?",
929
- "answer": "John Doe",
930
- "category": "Patient Info"
931
- }},
932
- {{
933
- "label": "heart rate",
934
- "question": "What is the heart rate?",
935
- "answer": "78 bpm",
936
- "category": "Vitals"
937
- }}
938
- ]
939
-
940
- ⚠ Use the categories listed above. If an item does not fit any of them, create a new category for it.
941
-
942
- Text:
943
- {chunk}
944
-
945
- Return a single valid JSON array of all extracted objects.
946
- Do not include any explanations or commentary.
947
- Only output the JSON array.
948
- <</SYS>> [/INST]
949
- """
950
-
951
- try:
952
- output = generator(
953
- prompt,
954
- max_new_tokens=1024,
955
- do_sample=True,
956
- temperature=0.3
957
- )[0]["generated_text"]
958
- print("----------------------------------")
959
- logger.info(f"📤 Output from chunk {idx}: {output}...")
960
- return idx, output
961
- except Exception as e:
962
- logger.error("Error processing chunk %d: %s", idx, e)
963
- return idx, None
964
-
965
-
966
- # ------------------Extract Medical Data ------------------ #
967
- @app.route("/extract_medical_data", methods=["POST"])
968
- @log_execution_time()
969
- def extract_medical_data():
970
- data = request.json
971
- logger.info("Received request: %s", json.dumps(data, indent=2))
972
-
973
- qa_model_name = data.get("qa_model_name")
974
- qa_model_type = data.get("qa_model_type")
975
- extracted_files = data.get("extracted_data")
976
-
977
- if not qa_model_name or not qa_model_type:
978
- return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400
979
-
980
- if not extracted_files:
981
- return jsonify({"error": "Missing 'extracted_data' in request"}), 400
982
-
983
- try:
984
- logger.info(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
985
- model = AutoModelForCausalLM.from_pretrained(qa_model_name, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
986
- generator = pipeline(task=qa_model_type, model=model, tokenizer=tokenizer)  # note: reuses the module-level Qwen3-4B tokenizer, so qa_model_name should share that vocabulary
987
- logger.info(f"✅ Model loaded successfully: {generator.model.config._name_or_path}")
988
- except Exception as e:
989
- logger.error("❌ Model load failure")
990
- return jsonify({"error": f"Could not load model: {str(e)}"}), 500
991
-
992
- structured_response = {"extracted_data": []}
993
-
994
- for file_data in extracted_files:
995
- filename = file_data.get("file", "unknown_file")
996
- context = file_data.get("extracted_text", "").strip()
997
- logger.info("Processing file: %s", filename)
998
-
999
- if not context:
1000
- logger.warning("No text found in file: %s", filename)
1001
- structured_response["extracted_data"].append(
1002
- {"file": filename, "medical_fields": "No data extracted"}
1003
- )
1004
- continue
1005
-
1006
- chunks = chunk_text(context, tokenizer)
1007
- logger.info(f"📚 Chunked into {len(chunks)} parts for {filename}")
1008
-
1009
- all_extracted = []
1010
- # for idx,chunk in enumerate(chunks):
1011
- # print(f"Processing chunk {idx+1}/{len(chunks)}")
1012
-
1013
- with ThreadPoolExecutor(max_workers=4) as executor:
1014
- futures = {
1015
- executor.submit(process_chunk, generator, chunk, idx): idx
1016
- for idx, chunk in enumerate(chunks)
1017
- }
1018
- for future in as_completed(futures):
1019
- idx = futures[future]
1020
- _, output = future.result()
1021
-
1022
- if not output:
1023
- continue
1024
-
1025
- try:
1026
- objs = extract_json_objects(output)
1027
- if objs:
1028
- all_extracted.extend(objs)
1029
- else:
1030
- logger.error(f"⚠ Chunk {idx+1} yielded no valid JSON.")
1031
- except Exception as e:
1032
- logger.error(f"❌ Error extracting JSON from chunk {idx+1}")
1033
-
1034
- # Clean and group results for this file
1035
- if all_extracted:
1036
- deduped = deduplicate_extractions(all_extracted)
1037
- # cleaned_json = clean_result()
1038
- grouped_data = group_by_category(deduped)
1039
- else:
1040
- grouped_data = {"error": "No valid data extracted"}
1041
-
1042
- structured_response["extracted_data"].append(
1043
- {"file": filename, "medical_fields": grouped_data}
1044
- )
1045
-
1046
- try:
1047
- save_data_to_storage(filename, grouped_data)
1048
- except Exception as e:
1049
- logger.error(f"⚠ Failed to save data for {filename}: {e}")
1050
-
1051
- logger.info("✅ Extraction complete.")
1052
- return jsonify(structured_response)
1053
-
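A hedged client sketch for `/extract_medical_data`; it replays the output of `/upload`. The host is an assumption, and the model name is chosen to match the module-level Qwen3-4B tokenizer loaded above:

```python
import requests

payload = {
    "qa_model_name": "Qwen/Qwen3-4B",
    "qa_model_type": "text-generation",
    "extracted_data": [
        {"file": "lab_report.pdf",  # hypothetical file from a prior /upload call
         "extracted_text": "BP 120/80. Diagnosis: hypertension."}
    ],
}
resp = requests.post("http://localhost:5000/extract_medical_data", json=payload)
print(resp.json())  # {"extracted_data": [{"file": ..., "medical_fields": [...]}]}
```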
1054
-
1055
-
1056
- # -------------------------- save data to a JSON file----------------------#
1057
- @log_execution_time()
1058
- def save_data_to_storage(filename, data):
1059
- try:
1060
- filename = filename.rsplit(".", 1)[0] # Remove extension
1061
- filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
1062
- logger.info(f"💾 Saving to: {filepath}")
1063
- with open(filepath, "w") as file:
1064
- json.dump(data, file)
1065
- logger.info(f"✅ Data saved successfully to {filepath}")
1066
- except Exception as e:
1067
- logger.error(f"🚨 Exception during save: {e}")
1068
-
1069
-
1070
- # Function to get data from a JSON file
1071
- # 🔍 Get data from storage
1072
- @log_execution_time()
1073
- def get_data_from_storage(filename):
1074
- try:
1075
- filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
1076
- logger.info(f"🔍 Looking for file at: {filepath}")
1077
- if not os.path.exists(filepath):
1078
- logger.warning(f"🚫 File not found at: {filepath}")
1079
- return None
1080
- with open(filepath, "r") as file:
1081
- data = json.load(file)
1082
- logger.info(f"✅ File found and loaded: {filepath}")
1083
- return data
1084
- except Exception as e:
1085
- logger.error(f"🚨 Error loading data: {e}")
1086
- return None
1087
-
1088
-
1089
- # 🔹 Fetch updated medical data
1090
- @app.route("/get_updated_medical_data", methods=["GET"])
1091
- @log_execution_time()
1092
- def get_updated_data():
1093
- file_name = request.args.get("file")
1094
-
1095
- if not file_name:
1096
- return jsonify({"error": "File name is required"}), 400
1097
-
1098
- # 🔥 Strip extension if present
1099
- file_name = file_name.rsplit(".", 1)[0]
1100
-
1101
- # ✅ Load updated JSON data from storage
1102
- updated_data = get_data_from_storage(file_name)
1103
-
1104
- if updated_data:
1105
- return jsonify({"file": file_name, "data": updated_data}), 200
1106
- else:
1107
- return jsonify({"error": f"File '{file_name}' not found"}), 404
1108
-
1109
-
1110
-
1111
- @app.route("/update_medical_data", methods=["PUT"])
1112
- @log_execution_time()
1113
- def update_medical_data():
1114
- try:
1115
- data = request.json
1116
- logger.info("Received update: %s", json.dumps(data, indent=2))
1117
-
1118
- filename = data.get("file", "").rsplit(".", 1)[0] # Strip extension like .pdf
1119
- updates = data.get("updates", [])
1120
-
1121
- if not filename or not updates:
1122
- return jsonify({"error": "File name or updates missing"}), 400
1123
-
1124
- # Load current stored data
1125
- existing_data = get_data_from_storage(filename)
1126
- if not existing_data:
1127
- return jsonify({"error": f"File '{filename}' not found"}), 404
1128
-
1129
- # Loop through updates and modify categorized_data
1130
- for update in updates:
1131
- category = update.get("category")
1132
- field = update.get("field")
1133
- new_value = update.get("value")
1134
- updated = False
1135
-
1136
- for extracted in existing_data.get("extracted_data", []):
1137
- for cat in extracted.get("categorized_data", []):
1138
- if cat.get("name") == category:
1139
- for fld in cat.get("fields", []):
1140
- if fld.get("label") == field:
1141
- logger.info("Updating [%s] %s → %s", category, field, new_value)
1142
- fld["value"] = new_value
1143
- updated = True
1144
- break
1145
- if updated:
1146
- break
1147
- if updated:
1148
- break
1149
-
1150
- # 🧠 Sync medical_terms with categorized_data
1151
- for extracted in existing_data.get("extracted_data", []):
1152
- if "categorized_data" in extracted:
1153
- new_terms = {}
1154
- for category in extracted["categorized_data"]:
1155
- for field in category.get("fields", []):
1156
- label = field.get("label")
1157
- value = field.get("value", "")
1158
- new_terms[label] = value
1159
- extracted["medical_terms"] = new_terms
1160
- logger.info("Synced 'medical_terms' with 'categorized_data'")
1161
-
1162
- # Save updated data to file
1163
- save_data_to_storage(filename, existing_data)
1164
- logger.info("✅ Updated data saved successfully")
1165
-
1166
- return (
1167
- jsonify(
1168
- {"message": "Data updated successfully", "updated_data": existing_data}
1169
- ),
1170
- 200,
1171
- )
1172
-
1173
- except Exception as e:
1174
- logger.error("Update error: %s", e)
1175
- return jsonify({"error": str(e)}), 500
1176
-
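A hedged request sketch for `/update_medical_data`; the category and field names mirror the `categorized_data` layout shown in the commented-out route above and are assumptions about the stored JSON:

```python
import requests

resp = requests.put(
    "http://localhost:5000/update_medical_data",
    json={
        "file": "lab_report.pdf",  # extension is stripped server-side
        "updates": [
            {"category": "Vitals", "field": "Hemoglobin", "value": "13.5 g/dL"}
        ],
    },
)
print(resp.status_code)  # 200 with the updated document on success
```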
1177
- # Test Route
1178
- @app.route("/")
1179
- def home():
1180
- return "Medical Data Extraction API is running!"
1181
-
1182
-
1183
- if __name__ == "__main__":
1184
- app.run(host="0.0.0.0", port=5000, debug=True)
1185
- # if __name__ == '__main__':
1186
- # from gevent.pywsgi import WSGIServer # type: ignore
1187
- # http_server = WSGIServer(('0.0.0.0', 5000), app)
1188
- # http_server.serve_forever()
speech_to_chart.py DELETED
@@ -1,638 +0,0 @@
1
-
2
-
3
- import json
4
- import os
5
- import re
6
- import logging
7
- import shutil
8
- from flask import Flask, request, jsonify, abort
9
- from werkzeug.utils import secure_filename
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
- import torch
12
- import whisper
13
- from dotenv import load_dotenv
14
- import pytesseract
15
- import cv2
16
- import pdfplumber
17
- import pandas as pd
18
- from PIL import Image
19
- from docx import Document
20
- from flask_cors import CORS
21
-
22
- # Load environment variables
23
- load_dotenv()
24
-
25
- # Initialize Flask app
26
- app = Flask(__name__)
27
- CORS(app)
28
-
29
- # Configure logging
30
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
31
-
32
- # Configure upload directory and max file size
33
- UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
34
- os.makedirs(UPLOAD_DIR, exist_ok=True)
35
- app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
36
- app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16 MB max file size
37
-
38
- # Allowed file extensions
39
- ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
40
- ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}
41
-
42
- # Ensure ffmpeg is in PATH
43
- ffmpeg_path = shutil.which("ffmpeg") or "C:\\ffmpeg\\bin\\ffmpeg.exe"
44
- if not os.path.exists(ffmpeg_path):
45
- raise RuntimeError("FFmpeg not found! Please install FFmpeg and set the correct path.")
46
- os.environ["PATH"] += os.pathsep + os.path.dirname(ffmpeg_path)\
47
-
48
- def allowed_file(filename, allowed_extensions):
49
- """Check if the file extension is allowed."""
50
- return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
51
-
52
- class LazyModelLoader:
53
- def __init__(self, model_name, task, tokenizer=None, apply_quantization=False):
54
- self.model_name = model_name
55
- self.task = task
56
- self.tokenizer = tokenizer
57
- self.apply_quantization = apply_quantization
58
- self._pipeline = None
59
-
60
- def load(self):
61
- if self._pipeline is None:
62
- logging.info(f"Loading pipeline for task: {self.task} | model: {self.model_name}")
63
- if self.task == "question-answering":
64
- model = AutoModelForCausalLM.from_pretrained(self.model_name)  # caution: a causal LM cannot back a "question-answering" pipeline; this branch is unused by the loaders below
65
- tokenizer = AutoTokenizer.from_pretrained(self.model_name)
66
- if self.apply_quantization:
67
- logging.info("Applying quantization...")
68
- model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
69
- self._pipeline = pipeline(self.task, model=model, tokenizer=tokenizer)
70
- else:
71
- self._pipeline = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
72
- return self._pipeline
73
-
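The loader defers pipeline construction until first use, so importing this module stays cheap, and repeated `load()` calls return the cached pipeline. A minimal usage sketch, assuming a small summarization model:

```python
loader = LazyModelLoader("google-t5/t5-small", "summarization")
summarizer = loader.load()           # model downloads/loads here, not at import time
assert loader.load() is summarizer   # subsequent calls reuse the cached pipeline
```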
74
- # PHI scrubbing agent
75
- class PHIScrubberAgent:
76
- @staticmethod
77
- def scrub_phi(text):
78
- try:
79
- text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
80
- text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
81
- text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
82
- text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]',
83
- text, flags=re.IGNORECASE)
84
- text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
85
- text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
86
- except Exception as e:
87
- logging.error(f"PHI scrubbing failed: {e}")
88
- return text
89
-
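An illustrative call for the scrubber; the patterns are regex heuristics, so capitalized word pairs are masked as names even when they are not people:

```python
note = "Dr. Jane Smith saw the patient. Call 555-123-4567 or mail jane@clinic.org."
print(PHIScrubberAgent.scrub_phi(note))
# -> "Dr. [NAME] saw the patient. Call [PHONE] or mail [EMAIL]."
```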
90
- # Summarization Agent
91
- class SummarizerAgent:
92
- def __init__(self, summarization_model_loader):
93
- self.summarization_model_loader = summarization_model_loader
94
-
95
- def generate_summary(self, text):
96
- model = self.summarization_model_loader.load()
97
- try:
98
- summary_result = model(text, max_length=150, min_length=30, do_sample=False)
99
- return summary_result[0]['summary_text'].strip()
100
- except Exception as e:
101
- logging.error(f"Summary generation failed: {e}")
102
- return "Summary generation failed."
103
-
104
- # Medical Data Extraction Agent
105
- class MedicalDataExtractorAgent:
106
- def __init__(self, gen_model_loader):
107
- self.gen_model_loader = gen_model_loader
108
-
109
- def extract_medical_data(self, text):
110
- try:
111
- generator = self.gen_model_loader.load()
112
- prompt = (
113
- "Extract structured medical information from the following clinical note.\n\n"
114
- "Return the result in JSON format with the following fields:\n"
115
- "patient_condition, symptoms, current_problems, allergies, dr_notes, "
116
- "prescription, investigations, follow_up_instructions.\n\n"
117
- f"Clinical Note:\n{text}\n\n"
118
- "Structured JSON Output:\n"
119
- )
120
- response = generator(prompt, max_new_tokens=256)[0]["generated_text"]
121
- logging.debug(f"Raw model output: {response}")
122
-
123
- json_start = response.find("{")
124
- json_end = response.rfind("}") + 1
125
- if json_start == -1 or json_end == 0:  # rfind returns -1, so the +1 above yields 0 when no "}" exists
126
- raise ValueError("No JSON found in the model response.")
127
-
128
- json_str = response[json_start:json_end]
129
- return json.loads(json_str)
130
-
131
- except Exception as e:
132
- logging.error(f"Error extracting medical data: {e}")
133
- return {"error": f"Failed to extract medical data: {str(e)}"}
134
-
135
- # Initialize lazy loaders
136
- gen_model_loader = LazyModelLoader(
137
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
138
- "text-generation",
139
- )
140
- summarization_model_loader = LazyModelLoader("google-t5/t5-large", "summarization", apply_quantization=True)
141
- whisper_model = whisper.load_model("base")
142
-
143
- # Initialize agents
144
- phi_scrubber_agent = PHIScrubberAgent()
145
- medical_data_extractor_agent = MedicalDataExtractorAgent(gen_model_loader)
146
- summarizer_agent = SummarizerAgent(summarization_model_loader)
147
-
148
- # API Endpoints
149
- @app.route('/api/extract_medical_data', methods=['POST'])
150
- def extract_medical_data():
151
- try:
152
- data = request.json
153
- if "text" not in data or not data["text"].strip():
154
- return jsonify({"error": "No valid text provided"}), 400
155
- raw_text = data["text"]
156
- clean_text = phi_scrubber_agent.scrub_phi(raw_text)
157
- structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
158
- return jsonify(structured_data), 200
159
- except Exception as e:
160
- logging.error(f"Failed to extract medical data: {e}")
161
- return jsonify({"error": f"Extraction Error: {str(e)}"}), 500
162
-
163
- @app.route('/api/transcribe', methods=['POST'])
164
- def transcribe_audio():
165
- if 'audio' not in request.files:
166
- abort(400, description="No audio file provided")
167
- audio_file = request.files['audio']
168
- if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
169
- abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
170
- filename = secure_filename(audio_file.filename)
171
- audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
172
- audio_file.save(audio_path)
173
- try:
174
- result = whisper_model.transcribe(audio_path)
175
- transcribed_text = result["text"]
176
- os.remove(audio_path)
177
- return jsonify({"transcribed_text": transcribed_text}), 200
178
- except Exception as e:
179
- logging.error(f"Transcription failed: {str(e)}")
180
- return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
181
-
182
- @app.route('/api/generate_summary', methods=['POST'])
183
- def generate_summary():
184
- data = request.json
185
- if "text" not in data or not data["text"].strip():
186
- return jsonify({"error": "No valid text provided"}), 400
187
- context = data["text"]
188
- clean_text = phi_scrubber_agent.scrub_phi(context)
189
- summary = summarizer_agent.generate_summary(clean_text)
190
- return jsonify({"summary": summary}), 200
191
-
192
- @app.route('/api/extract_medical_data_from_audio', methods=['POST'])
193
- def extract_medical_data_from_audio():
194
- if 'audio' not in request.files:
195
- abort(400, description="No audio file provided")
196
- audio_file = request.files['audio']
197
- if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
198
- abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
199
- filename = secure_filename(audio_file.filename)
200
- audio_path = os.path.join(UPLOAD_DIR, filename)
201
- audio_file.save(audio_path)
202
- try:
203
- result = whisper_model.transcribe(audio_path)
204
- transcribed_text = result["text"]
205
- clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
206
- summary = summarizer_agent.generate_summary(clean_text)
207
- structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
208
- response = {
209
- "transcribed_text": clean_text,
210
- "summary": summary,
211
- "medical_chart": structured_data
212
- }
213
- os.remove(audio_path)
214
- return jsonify(response), 200
215
- except Exception as e:
216
- logging.error(f"Processing failed: {str(e)}")
217
- return jsonify({"error": f"Processing failed: {str(e)}"}), 500
218
-
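A hedged end-to-end client sketch for `/api/extract_medical_data_from_audio`: one request covers transcription, PHI scrubbing, summarization, and chart extraction. The host and audio file name are assumptions:

```python
import requests

with open("visit.wav", "rb") as f:  # hypothetical recording; mp3/wav/flac accepted
    resp = requests.post(
        "http://localhost:5000/api/extract_medical_data_from_audio",
        files={"audio": f},
    )
print(resp.json())  # {"transcribed_text": ..., "summary": ..., "medical_chart": {...}}
```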
219
- if __name__ == '__main__':
220
- app.run(host='0.0.0.0', port=5000, debug=False)
221
-
222
- # import json
223
- # import os
224
- # import re
225
- # import logging
226
- # import shutil
227
- # from dotenv import load_dotenv
228
- # from flask import Flask, request, jsonify, abort
229
- # from werkzeug.utils import secure_filename
230
- # from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
231
- # import pytesseract
232
- # import cv2
233
- # import pdfplumber
234
- # import pandas as pd
235
- # from PIL import Image
236
- # from docx import Document
237
- # from flask_cors import CORS
238
- # from flask_executor import Executor
239
- # from sentence_transformers import SentenceTransformer
240
- # import faiss
241
- # import whisper
242
- # from PyPDF2 import PdfReader
243
- # from pdf2image import convert_from_path
244
- # from concurrent.futures import ThreadPoolExecutor
245
- # import tempfile
246
-
247
- # # Load environment variables
248
- # load_dotenv()
249
-
250
- # # Initialize Flask app
251
- # app = Flask(__name__)
252
- # CORS(app)
253
-
254
- # # Configure logging
255
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
256
-
257
- # # Configure upload directory and max file size
258
- # UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
259
- # os.makedirs(UPLOAD_DIR, exist_ok=True)
260
- # app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
261
- # app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16 MB max file size
262
-
263
- # # Initialize Flask-Executor for asynchronous tasks
264
- # executor = Executor(app)
265
- # whisper_model = whisper.load_model("tiny")
266
- # # Allowed file extensions
267
- # ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
268
- # ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}
269
-
270
- # # Ensure ffmpeg is in PATH
271
- # ffmpeg_path = shutil.which("ffmpeg") or "C:\\ffmpeg\\bin\\ffmpeg.exe"
272
- # if not os.path.exists(ffmpeg_path):
273
- # raise RuntimeError("FFmpeg not found! Please install FFmpeg and set the correct path.")
274
- # os.environ["PATH"] += os.pathsep + os.path.dirname(ffmpeg_path)
275
-
276
- # # Lazy model loading to save resources
277
- # class LazyModelLoader:
278
- # def __init__(self, model_name, task, tokenizer=None):
279
- # self.model_name = model_name
280
- # self.task = task
281
- # self.tokenizer = tokenizer
282
- # self._model = None
283
-
284
- # def load(self):
285
- # """Load the model if not already loaded."""
286
- # if self._model is None:
287
- # logging.info(f"Loading model: {self.model_name}")
288
- # if self.task == "text-generation":
289
- # self._model = AutoModelForCausalLM.from_pretrained(
290
- # self.model_name, device_map="auto", torch_dtype="auto"
291
- # )
292
- # self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)
293
- # # Set pad_token_id if it's not already set
294
- # if self._model.generation_config.pad_token_id is None or self._model.generation_config.pad_token_id < 0:
295
- # if self._tokenizer.eos_token_id is not None:
296
- # self._model.generation_config.pad_token_id = self._tokenizer.eos_token_id
297
- # logging.info(f"Set pad_token_id to {self._tokenizer.eos_token_id}")
298
- # else:
299
- # logging.warning("No valid eos_token_id found. Setting pad_token_id to 0 as a fallback.")
300
- # self._model.generation_config.pad_token_id = 0
301
- # else:
302
- # self._model = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
303
- # return self._model
304
-
305
- # # Text extraction agents
306
- # class TextExtractorAgent:
307
- # @staticmethod
308
- # def extract_text(filepath, ext):
309
- # """Extract text based on file type."""
310
- # try:
311
- # if ext == "pdf":
312
- # return TextExtractorAgent.extract_text_from_pdf(filepath)
313
- # elif ext in {"jpg", "jpeg", "png"}:
314
- # return TextExtractorAgent.extract_text_from_image(filepath)
315
- # elif ext == "docx":
316
- # return TextExtractorAgent.extract_text_from_docx(filepath)
317
- # elif ext in {"xlsx", "xls"}:
318
- # return TextExtractorAgent.extract_text_from_excel(filepath)
319
- # return None
320
- # except Exception as e:
321
- # logging.error(f"Text extraction failed: {e}")
322
- # return None
323
-
324
- # @staticmethod
325
- # def extract_text_from_pdf(filepath):
326
- # """Extract text from a PDF file."""
327
- # text = ""
328
- # with pdfplumber.open(filepath) as pdf:
329
- # for page in pdf.pages:
330
- # page_text = page.extract_text()
331
- # if page_text:
332
- # text += page_text + "\n"
333
- # return text.strip() or None
334
-
335
- # @staticmethod
336
- # def extract_text_from_image(filepath):
337
- # """Extract text from an image using OCR."""
338
- # image = cv2.imread(filepath)
339
- # gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
340
- # _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
341
- # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
342
- # processed_path = temp_file.name
343
- # cv2.imwrite(processed_path, processed)
344
- # text = pytesseract.image_to_string(Image.open(processed_path), lang='eng')
345
- # os.remove(processed_path)
346
- # return text.strip() or None
347
-
348
- # @staticmethod
349
- # def extract_text_from_docx(filepath):
350
- # """Extract text from a DOCX file."""
351
- # doc = Document(filepath)
352
- # text = "\n".join([para.text for para in doc.paragraphs])
353
- # return text.strip() or None
354
-
355
- # @staticmethod
356
- # def extract_text_from_excel(filepath):
357
- # """Extract text from an Excel file."""
358
- # dfs = pd.read_excel(filepath, sheet_name=None)
359
- # text = "\n".join([
360
- # "\n".join([
361
- # " ".join(map(str, df[col].dropna()))
362
- # for col in df.columns
363
- # ])
364
- # for df in dfs.values()
365
- # ])
366
- # return text.strip() or None
367
-
368
- # # PHI scrubbing agent
369
- # class PHIScrubberAgent:
370
- # @staticmethod
371
- # def scrub_phi(text):
372
- # """Remove sensitive personal health information (PHI)."""
373
- # try:
374
- # text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
375
- # text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
376
- # text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
377
- # text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]', text, flags=re.IGNORECASE)
378
- # text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
379
- # text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
380
- # except Exception as e:
381
- # logging.error(f"PHI scrubbing failed: {e}")
382
- # return text
383
-
384
- # # Summarization agent
385
- # class SummarizerAgent:
386
- # def __init__(self, summarization_model_loader):
387
- # self.summarization_model_loader = summarization_model_loader
388
-
389
- # def generate_summary(self, text):
390
- # """Generate a summary of the provided text."""
391
- # model = self.summarization_model_loader.load()
392
- # try:
393
- # summary_result = model(text, do_sample=False)
394
- # return summary_result[0]['summary_text'].strip()
395
- # except Exception as e:
396
- # logging.error(f"Summary generation failed: {e}")
397
- # return "Summary generation failed."
398
-
399
- # def allowed_file(filename, allowed_extensions):
400
- # """Check if the file extension is allowed."""
401
- # return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
402
- # # Knowledge Base
403
- # class KnowledgeBase:
404
- # def __init__(self, documents):
405
- # self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
406
- # self.documents = documents
407
- # self.embeddings = self.embedding_model.encode(documents)
408
- # self.dimension = self.embedding_model.get_sentence_embedding_dimension()
409
- # self.index = faiss.IndexFlatL2(self.dimension)
410
- # self.index.add(self.embeddings)
411
-
412
- # def retrieve_relevant_info(self, query, top_k=3):
413
- # """Retrieve relevant medical information from the knowledge base."""
414
- # query_embedding = self.embedding_model.encode([query])
415
- # distances, indices = self.index.search(query_embedding, top_k)
416
- # relevant_texts = [self.documents[i] for i in indices[0]]
417
- # return relevant_texts
418
- # # Medical data extraction agent
419
- # class MedicalDataExtractorAgent:
420
- # def __init__(self, model_loader, knowledge_base):
421
- # self.model_loader = model_loader
422
- # self.knowledge_base = knowledge_base
423
-
424
- # def retrieve_relevant_info(self, query, top_k=3):
425
- # """Retrieve relevant medical information from the knowledge base."""
426
- # query_embedding = self.knowledge_base.embedding_model.encode([query])
427
- # distances, indices = self.knowledge_base.index.search(query_embedding, top_k)
428
- # relevant_texts = [self.knowledge_base.documents[i] for i in indices[0]]
429
- # return relevant_texts
430
-
431
- # def extract_medical_data(self, text):
432
- # """Extract structured medical data from text using Agentic RAG."""
433
- # try:
434
- # # Define the default JSON schema
435
- # default_schema = {
436
- # "patient_name": "[NAME]",
437
- # "age": None,
438
- # "gender": None,
439
- # "diagnosis": [],
440
- # "symptoms": [],
441
- # "medications": [],
442
- # "allergies": [],
443
- # "vitals": {
444
- # "blood_pressure": None,
445
- # "heart_rate": None,
446
- # "temperature": None
447
- # },
448
- # "notes": ""
449
- # }
450
- # # Construct the prompt with the input text
451
- # prompt = f"""
452
- # ### Instruction:
453
- # Extract structured medical data from the following text as a JSON whose parameters are enclosed in "" and without any \.
454
- # The JSON should include patientname, age, gender, medications, allergies, diagnosis, symptoms, vitals, and notes.
455
- # ### Text:
456
- # {text}
457
- # ### Response:
458
- # """
459
- # # Tokenize and generate the response
460
- # model = self.model_loader.load()
461
- # tokenizer = self.model_loader._tokenizer
462
- # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
463
- # outputs = model.generate(
464
- # inputs.input_ids,
465
- # num_return_sequences=1,
466
- # temperature=0.7,
467
- # top_p=0.9,
468
- # do_sample=True
469
- # )
470
- # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
471
- # logging.info(f"Model response: {response}")
472
- # # Parse and normalize the JSON output
473
- # json_start = response.find("{")
474
- # json_end = response.rfind("}") + 1
475
- # if json_start == -1 or json_end == -1:
476
- # raise ValueError("No JSON found in the model response.")
477
- # # Extract the JSON substring
478
- # structured_data = json.loads(response[json_start:json_end])
479
- # # Normalize the JSON output
480
- # normalized_data = self.normalize_json_output(structured_data, default_schema)
481
- # # Ensure blood pressure is a string
482
- # if normalized_data["vitals"]["blood_pressure"] and isinstance(normalized_data["vitals"]["blood_pressure"], str):
483
- # normalized_data["vitals"]["blood_pressure"] = normalized_data["vitals"]["blood_pressure"].strip('"')
484
- # return json.dumps(normalized_data)
485
- # except json.JSONDecodeError as e:
486
- # logging.error(f"JSON parsing failed: {e}")
487
- # return json.dumps({"error": f"Failed to parse JSON: {str(e)}"})
488
- # except Exception as e:
489
- # logging.error(f"Error extracting medical data: {e}")
490
- # return json.dumps({"error": f"Failed to extract medical data: {str(e)}"})
491
-
492
- # @staticmethod
493
- # def normalize_json_output(model_output, default_schema):
494
- # """
495
- # Normalize the model's JSON output to match the default schema.
496
- # """
497
- # try:
498
- # normalized_output = default_schema.copy()
499
- # for key in normalized_output:
500
- # if key in model_output:
501
- # normalized_output[key] = model_output[key]
502
- # return normalized_output
503
- # except Exception as e:
504
- # logging.error(f"Failed to normalize JSON: {e}")
505
- # return default_schema # Return the default schema in case of errors
506
-
507
- # # Initialize lazy loaders
508
- # medalpaca_model_loader = LazyModelLoader("lmsys/vicuna-7b-v1.5", "text-generation")
509
- # summarization_model_loader = LazyModelLoader("google-t5/t5-small", "summarization")
510
- # whisper_model = whisper.load_model("tiny")
511
-
512
- # # Initialize knowledge base
513
- # medical_documents = [
514
- # "Hypertension is a chronic condition characterized by elevated blood pressure.",
515
- # "Diabetes is a metabolic disorder that affects blood sugar levels.",
516
- # "Common symptoms of chest pain include pressure, tightness, or discomfort in the chest."
517
- # ]
518
- # knowledge_base = KnowledgeBase(medical_documents)
519
-
520
- # # Initialize agents
521
- # text_extractor_agent = TextExtractorAgent()
522
- # phi_scrubber_agent = PHIScrubberAgent()
523
- # medical_data_extractor_agent = MedicalDataExtractorAgent(medalpaca_model_loader, knowledge_base)
524
- # summarizer_agent = SummarizerAgent(summarization_model_loader)
525
-
526
- # # API Endpoints
527
- # @app.route('/api/extract_medical_data', methods=['POST'])
528
- # def extract_medical_data():
529
- # """Extract structured medical data from raw text."""
530
- # try:
531
- # data = request.json
532
- # if "text" not in data or not data["text"].strip():
533
- # return jsonify({"error": "No valid text provided"}), 400
534
- # raw_text = data["text"]
535
- # clean_text = phi_scrubber_agent.scrub_phi(raw_text)
536
- # structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
537
- # return jsonify(json.loads(structured_data)), 200
538
- # except Exception as e:
539
- # logging.error(f"Failed to extract medical data: {e}")
540
- # return jsonify({"error": f"Extraction Error: {str(e)}"}), 500
541
-
542
- # @app.route('/api/transcribe', methods=['POST'])
543
- # def transcribe_audio():
544
- # """Transcribe audio files into text."""
545
- # if 'audio' not in request.files:
546
- # abort(400, description="No audio file provided")
547
- # audio_file = request.files['audio']
548
- # if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
549
- # abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
550
- # filename = secure_filename(audio_file.filename)
551
- # audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
552
- # audio_file.save(audio_path)
553
- # try:
554
- # result = whisper_model.transcribe(audio_path)
555
- # transcribed_text = result["text"]
556
- # os.remove(audio_path)
557
- # return jsonify({"transcribed_text": transcribed_text}), 200
558
- # except Exception as e:
559
- # logging.error(f"Transcription failed: {str(e)}")
560
- # return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
561
-
562
- # @app.route('/api/generate_summary', methods=['POST'])
563
- # def generate_summary():
564
- # """Generate a summary from the provided text."""
565
- # data = request.json
566
- # if "text" not in data or not data["text"].strip():
567
- # return jsonify({"error": "No valid text provided"}), 400
568
- # context = data["text"]
569
- # clean_text = phi_scrubber_agent.scrub_phi(context)
570
- # summary = summarizer_agent.generate_summary(clean_text)
571
- # return jsonify({"summary": summary}), 200
572
-
573
- # @app.route('/api/extract_medical_data_from_audio', methods=['POST'])
574
- # def extract_medical_data_from_audio():
575
- # """Extract medical data from transcribed audio."""
576
- # if 'audio' not in request.files:
577
- # abort(400, description="No audio file provided")
578
- # audio_file = request.files['audio']
579
- # if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
580
- # abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
581
- # filename = secure_filename(audio_file.filename)
582
- # audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
583
- # audio_file.save(audio_path)
584
- # try:
585
- # result = whisper_model.transcribe(audio_path)
586
- # transcribed_text = result["text"]
587
- # clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
588
- # summary = summarizer_agent.generate_summary(transcribed_text)
589
- # structured_data = medical_data_extractor_agent.extract_medical_data(transcribed_text)
590
- # response = {
591
- # "transcribed_text": transcribed_text,
592
- # "summary": summary,
593
- # "medical_chart": json.loads(structured_data)
594
- # }
595
- # os.remove(audio_path)
596
- # return jsonify(response), 200
597
- # except Exception as e:
598
- # logging.error(f"Processing failed: {str(e)}")
599
- # return jsonify({"error": f"Processing failed: {str(e)}"}), 500
600
-
601
- # @app.route('/upload_document', methods=['POST'])
602
- # def upload_document():
603
- # """Upload and extract text from documents."""
604
- # if 'file' not in request.files:
605
- # return jsonify({"error": "No file uploaded"}), 400
606
- # file = request.files['file']
607
- # if file.filename == '':
608
- # return jsonify({"error": "No file selected"}), 400
609
- # if file and allowed_file(file.filename, ALLOWED_DOCUMENT_EXTENSIONS):
610
- # filename = secure_filename(file.filename)
611
- # filepath = os.path.join(UPLOAD_DIR, filename)
612
- # file.save(filepath)
613
- # ext = filename.rsplit('.', 1)[1].lower()
614
- # extracted_text = text_extractor_agent.extract_text(filepath, ext)
615
- # if not extracted_text:
616
- # return jsonify({"error": "No text found in file."}), 400
617
- # response_data = {
618
- # "file": filename,
619
- # "extracted_text": extracted_text[:500],
620
- # "message": "Click to extract medical terms"
621
- # }
622
- # os.remove(filepath)
623
- # return jsonify(response_data), 200
624
- # return jsonify({"error": "Invalid file type"}), 400
625
-
626
- # @app.route('/extract_medical_data_from_document', methods=['POST'])
627
- # def extract_medical_data_from_document():
628
- # """Extract medical data from document text."""
629
- # data = request.json
630
- # if "text" not in data or not data["text"].strip():
631
- # return jsonify({"error": "No valid text provided"}), 400
632
- # context = data["text"]
633
- # clean_text = phi_scrubber_agent.scrub_phi(context)
634
- # structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
635
- # return jsonify(json.loads(structured_data)), 200
636
-
637
- # if __name__ == '__main__':
638
- # app.run(host='0.0.0.0', port=5000, debug=True)