sachinchandrankallar committed on
Commit 4abf821
1 Parent(s): 89a714b

Initial commit to Hugging Face Space
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ # Use a lightweight Python base image
+ FROM python:3.10-slim
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     pkg-config \
+     libsystemd-dev \
+     libcairo2-dev \
+     tesseract-ocr \
+     libglib2.0-0 \
+     libsm6 \
+     libxrender1 \
+     libxext6 \
+     poppler-utils \
+     gettext \
+     && rm -rf /var/lib/apt/lists/*
+
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy only dependency files first for better caching
+ COPY requirements.txt .
+
+ # Install pip and dependencies
+ RUN pip install --upgrade pip \
+     && pip install torch==2.6.0 --no-cache-dir \
+     && pip install -r requirements.txt --no-cache-dir
+
+ # Copy the rest of the code (after dependencies, so it doesn't bust the cache)
+ COPY . .
+
+ # Expose port 7860 (required by HF Spaces)
+ EXPOSE 7860
+
+ # Run the Flask app
+ CMD ["python", "-m", "ai_med_extract", "--port=7860", "--host=0.0.0.0"]
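A minimal local smoke test of this image (the ai-med-extract tag is illustrative):

    docker build -t ai-med-extract .
    docker run -p 7860:7860 ai-med-extract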
ai_med_extract/__init__.py ADDED
@@ -0,0 +1 @@
+ # ai_med_extract/__init__.py
ai_med_extract/__main__.py ADDED
@@ -0,0 +1,10 @@
+ import argparse
+ from .app import app
+
+ # Module entrypoint; parse --host/--port so the Dockerfile CMD takes effect
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=5000)
+     args = parser.parse_args()
+     app.run(host=args.host, port=args.port, debug=True)
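With the argument parsing above, the container command from the Dockerfile maps onto a plain module run; the local equivalent is:

    python -m ai_med_extract --port=7860 --host=0.0.0.0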
ai_med_extract/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (134 Bytes).
 
ai_med_extract/__pycache__/__main__.cpython-313.pyc ADDED
Binary file (305 Bytes).
 
ai_med_extract/__pycache__/app.cpython-313.pyc ADDED
Binary file (2.19 kB).
 
ai_med_extract/agents/__init__.py ADDED
@@ -0,0 +1 @@
+ # ai_med_extract/agents/__init__.py
ai_med_extract/agents/medical_data_extractor.py ADDED
@@ -0,0 +1,29 @@
+ import logging
+ import json
+
+ class MedicalDataExtractorAgent:
+     def __init__(self, gen_model_loader):
+         self.gen_model_loader = gen_model_loader
+
+     def extract_medical_data(self, text):
+         try:
+             generator = self.gen_model_loader.load()
+             prompt = (
+                 "Extract structured medical information from the following clinical note.\n\n"
+                 "Return the result in JSON format with the following fields:\n"
+                 "patient_condition, symptoms, current_problems, allergies, dr_notes, "
+                 "prescription, investigations, follow_up_instructions.\n\n"
+                 f"Clinical Note:\n{text}\n\n"
+                 "Structured JSON Output:\n"
+             )
+             response = generator(prompt, max_new_tokens=256)[0]["generated_text"]
+             logging.debug(f"Raw model output: {response}")
+             json_start = response.find("{")
+             json_end = response.rfind("}") + 1
+             if json_start == -1 or json_end == 0:  # rfind failure yields -1, so +1 gives 0
+                 raise ValueError("No JSON found in the model response.")
+             json_str = response[json_start:json_end]
+             return json.loads(json_str)
+         except Exception as e:
+             logging.error(f"Error extracting medical data: {e}")
+             return {"error": f"Failed to extract medical data: {str(e)}"}
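The agent only requires a loader object exposing .load() that returns a text-generation callable, so a stub is enough to exercise it locally (a sketch; StubLoader and the distilgpt2 model choice are illustrative, and small models rarely emit clean JSON, which is what the error fallback is for):

    from transformers import pipeline
    from ai_med_extract.agents.medical_data_extractor import MedicalDataExtractorAgent

    class StubLoader:
        # Hypothetical stand-in for a lazy model loader
        def load(self):
            return pipeline("text-generation", model="distilgpt2")

    agent = MedicalDataExtractorAgent(StubLoader())
    print(agent.extract_medical_data("Patient reports fever and a dry cough for three days."))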
ai_med_extract/agents/phi_scrubber.py ADDED
@@ -0,0 +1,16 @@
+ import re
+ import logging
+
+ class PHIScrubberAgent:
+     @staticmethod
+     def scrub_phi(text):
+         try:
+             text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
+             text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
+             text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
+             text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]', text, flags=re.IGNORECASE)
+             text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
+             text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
+         except Exception as e:
+             logging.error(f"PHI scrubbing failed: {e}")
+         return text
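A quick sanity check of the scrubber (sample values are invented; the substitutions run in the order listed, so phone and email are masked before the generic two-capitalized-words name rule fires):

    from ai_med_extract.agents.phi_scrubber import PHIScrubberAgent

    sample = "reach John Smith at 555-123-4567 or john.smith@example.com"
    print(PHIScrubberAgent.scrub_phi(sample))
    # -> "reach [NAME] at [PHONE] or [EMAIL]"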
ai_med_extract/agents/summarizer.py ADDED
@@ -0,0 +1,14 @@
+ import logging
+
+ class SummarizerAgent:
+     def __init__(self, summarization_model_loader):
+         self.summarization_model_loader = summarization_model_loader
+
+     def generate_summary(self, text):
+         model = self.summarization_model_loader.load()
+         try:
+             summary_result = model(text, max_length=150, min_length=30, do_sample=False)
+             return summary_result[0]['summary_text'].strip()
+         except Exception as e:
+             logging.error(f"Summary generation failed: {e}")
+             return "Summary generation failed."
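As with the extractor agent, any object with a .load() returning a summarization pipeline works; a sketch with a stub loader (google-t5/t5-small is the checkpoint combined1.py uses, shown here only as an example):

    from transformers import pipeline
    from ai_med_extract.agents.summarizer import SummarizerAgent

    class StubLoader:
        # Hypothetical stand-in for a lazy summarization loader
        def load(self):
            return pipeline("summarization", model="google-t5/t5-small")

    agent = SummarizerAgent(StubLoader())
    print(agent.generate_summary("Patient presented with three days of fever and a productive cough; chest X-ray was ordered and amoxicillin started."))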
ai_med_extract/agents/text_extractor.py ADDED
@@ -0,0 +1,66 @@
+ import pdfplumber
+ import pytesseract
+ import cv2
+ import pandas as pd
+ from PIL import Image
+ from docx import Document
+ import tempfile
+ import os
+ import logging
+
+ class TextExtractorAgent:
+     @staticmethod
+     def extract_text(filepath, ext):
+         try:
+             if ext == "pdf":
+                 return TextExtractorAgent.extract_text_from_pdf(filepath)
+             elif ext in {"jpg", "jpeg", "png"}:
+                 return TextExtractorAgent.extract_text_from_image(filepath)
+             elif ext == "docx":
+                 return TextExtractorAgent.extract_text_from_docx(filepath)
+             elif ext in {"xlsx", "xls"}:
+                 return TextExtractorAgent.extract_text_from_excel(filepath)
+             return None
+         except Exception as e:
+             logging.error(f"Text extraction failed: {e}")
+             return None
+
+     @staticmethod
+     def extract_text_from_pdf(filepath, password=None):
+         text = ""
+         with pdfplumber.open(filepath) as pdf:
+             for page in pdf.pages:
+                 page_text = page.extract_text()
+                 if page_text:
+                     text += page_text + "\n"
+         return text.strip() or None
+
+     @staticmethod
+     def extract_text_from_image(filepath):
+         image = cv2.imread(filepath)
+         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+         _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+             processed_path = temp_file.name
+         cv2.imwrite(processed_path, processed)
+         text = pytesseract.image_to_string(Image.open(processed_path), lang='eng')
+         os.remove(processed_path)
+         return text.strip() or None
+
+     @staticmethod
+     def extract_text_from_docx(filepath):
+         doc = Document(filepath)
+         text = "\n".join([para.text for para in doc.paragraphs])
+         return text.strip() or None
+
+     @staticmethod
+     def extract_text_from_excel(filepath):
+         dfs = pd.read_excel(filepath, sheet_name=None)
+         text = "\n".join([
+             "\n".join([
+                 " ".join(map(str, df[col].dropna()))
+                 for col in df.columns
+             ])
+             for df in dfs.values()
+         ])
+         return text.strip() or None
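Dispatch is keyed on the extension string the caller supplies, so typical use derives it from the filename (the path is illustrative):

    from ai_med_extract.agents.text_extractor import TextExtractorAgent

    path = "uploads/report.pdf"  # illustrative path
    ext = path.rsplit(".", 1)[-1].lower()
    print(TextExtractorAgent.extract_text(path, ext) or "No text extracted")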
ai_med_extract/api/__init__.py ADDED
@@ -0,0 +1 @@
+ # ai_med_extract/api/__init__.py
ai_med_extract/api/routes.py ADDED
@@ -0,0 +1,133 @@
+ from flask import request, jsonify, abort, current_app
+ from transformers import pipeline
+ from ..app import app
+ from ..agents.text_extractor import TextExtractorAgent
+ from ..agents.phi_scrubber import PHIScrubberAgent
+ from ..agents.summarizer import SummarizerAgent
+ from ..agents.medical_data_extractor import MedicalDataExtractorAgent
+ from ..utils.file_utils import ALLOWED_EXTENSIONS, allowed_file, check_file_size, save_data_to_storage, get_data_from_storage
+ from ..utils.validation import clean_result, validate_patient_name
+ import os
+ import logging
+
+ @app.route("/upload", methods=["POST"])
+ def upload_file():
+     files = request.files.getlist("file")
+     patient_name = request.form.get("patient_name", "").strip()
+     password = request.form.get("password")
+     qa_model_name = request.form.get("qa_model_name")
+     qa_model_type = request.form.get("qa_model_type")
+     ner_model_name = request.form.get("ner_model_name")
+     ner_model_type = request.form.get("ner_model_type")
+     summarizer_model_name = request.form.get("summarizer_model_name")
+     summarizer_model_type = request.form.get("summarizer_model_type")
+     if not files:
+         return jsonify({"error": "No file uploaded"}), 400
+     # Model loading (example, adjust as needed)
+     try:
+         qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
+     except Exception as e:
+         return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
+     try:
+         ner_pipeline = pipeline(task=ner_model_type, model=ner_model_name)
+     except Exception as e:
+         return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
+     try:
+         summarizer_pipeline = pipeline(task=summarizer_model_type, model=summarizer_model_name)
+     except Exception as e:
+         return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
+     extracted_data = []
+     for file in files:
+         if file.filename == '':
+             continue
+         if not allowed_file(file.filename):
+             return jsonify({"error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
+         if not patient_name:
+             return jsonify({"error": "Patient name is missing"}), 400
+         valid_size, error_message = check_file_size(file)
+         if not valid_size:
+             return jsonify({"error": error_message}), 400
+         filename = file.filename
+         filepath = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
+         file.save(filepath)
+         ext = filename.rsplit('.', 1)[-1].lower()
+         extracted_text = TextExtractorAgent.extract_text(filepath, ext)
+         if not extracted_text or extracted_text == "No text found":
+             return jsonify({"error": f"Failed to extract text from {filename}"}), 415
+         skip_medical_check = request.form.get("skip_medical_check", "false").lower() == "true"
+         if not skip_medical_check:
+             ner_results = ner_pipeline(extracted_text)
+             medical_entities = list(set([r["word"] for r in ner_results if r["entity"].startswith("B-") or r["entity"].startswith("I-")]))
+             if not medical_entities:
+                 return jsonify({"error": f"'{filename}' is not medically relevant"}), 406
+         skip_patient_check = request.form.get("skip_patient_check", "false").lower() == "true"
+         if not skip_patient_check:
+             try:
+                 error_response = validate_patient_name(extracted_text, patient_name, filename, qa_pipeline)
+                 if error_response:
+                     return error_response
+             except Exception as e:
+                 return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500
+         try:
+             summary = summarizer_pipeline(extracted_text, max_length=350, min_length=50, do_sample=False)[0]["summary_text"]
+         except Exception as e:
+             summary = "Summary failed"
+         extracted_data.append({
+             "file": filename,
+             "extracted_text": extracted_text,
+             "summary": summary,
+             "message": "Successful"
+         })
+     if not extracted_data:
+         return jsonify({"error": "No valid medical files processed"}), 400
+     return jsonify({"extracted_data": extracted_data}), 200
+
+ @app.route("/get_updated_medical_data", methods=["GET"])
+ def get_updated_data():
+     file_name = request.args.get('file')
+     if not file_name:
+         return jsonify({"error": "File name is required"}), 400
+     file_name = file_name.rsplit(".", 1)[0]
+     updated_data = get_data_from_storage(file_name)
+     if updated_data:
+         return jsonify({"file": file_name, "data": updated_data}), 200
+     else:
+         return jsonify({"error": f"File '{file_name}' not found"}), 404
+
+ @app.route("/update_medical_data", methods=["PUT"])
+ def update_medical_data():
+     try:
+         data = request.json
+         filename = data.get("file")
+         filename = filename.rsplit(".", 1)[0]
+         updates = data.get("updates", [])
+         if not filename or not updates:
+             return jsonify({"error": "File name or updates missing"}), 400
+         existing_data = get_data_from_storage(filename)
+         if not existing_data:
+             return jsonify({"error": f"File '{filename}' not found"}), 404
+         for update in updates:
+             category = update.get("category")
+             field = update.get("field")
+             new_value = update.get("value")
+             updated = False
+             for cat in existing_data.get("extracted_data", []):
+                 for categorized_data in cat.get("categorized_data", []):
+                     if categorized_data.get("name") == category:
+                         for fld in categorized_data.get("fields", []):
+                             if fld.get("label") == field:
+                                 fld["value"] = new_value
+                                 updated = True
+                                 break
+                     if updated:
+                         break
+                 if updated:
+                     break
+         save_data_to_storage(filename, existing_data)
+         return jsonify({"message": "Data updated successfully", "updated_data": existing_data}), 200
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+ @app.route("/")
+ def home():
+     return "Medical Data Extraction API is running!"
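A sketch of calling /upload once the Space is running; the three model names are illustrative placeholders, since the route loads whatever the form fields specify:

    import requests

    resp = requests.post(
        "http://localhost:7860/upload",
        files={"file": open("report.pdf", "rb")},  # illustrative file
        data={
            "patient_name": "Jane Doe",
            "qa_model_name": "deepset/roberta-base-squad2",
            "qa_model_type": "question-answering",
            "ner_model_name": "dslim/bert-base-NER",
            "ner_model_type": "token-classification",
            "summarizer_model_name": "google-t5/t5-small",
            "summarizer_model_type": "summarization",
        },
    )
    print(resp.status_code, resp.json())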
ai_med_extract/app.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ import logging
+ from flask import Flask, request, jsonify, abort
+ from flask_cors import CORS
+ from werkzeug.utils import secure_filename
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ import whisper
+ from dotenv import load_dotenv
+ from .agents.text_extractor import TextExtractorAgent
+ from .agents.phi_scrubber import PHIScrubberAgent
+ from .agents.summarizer import SummarizerAgent
+ from .agents.medical_data_extractor import MedicalDataExtractorAgent
+ from .utils.file_utils import allowed_file, check_file_size, save_data_to_storage, get_data_from_storage
+ from .utils.validation import clean_result, validate_patient_name
+
+ # Load environment variables
+ load_dotenv()
+
+ app = Flask(__name__)
+ CORS(app)
+
+ UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB max file size
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+ # Model loaders (example, adjust as needed)
+ medalpaca_model_loader = None  # TODO: Implement LazyModelLoader if needed
+ summarization_model_loader = None  # TODO: Implement LazyModelLoader if needed
+ whisper_model = whisper.load_model("tiny")
+
+ # Initialize agents
+ text_extractor_agent = TextExtractorAgent()
+ phi_scrubber_agent = PHIScrubberAgent()
+ medical_data_extractor_agent = MedicalDataExtractorAgent(medalpaca_model_loader)
+ summarizer_agent = SummarizerAgent(summarization_model_loader)
+
+ from .api import routes  # Import routes to register endpoints
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=5000, debug=True)
ai_med_extract/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # ai_med_extract/utils/__init__.py
ai_med_extract/utils/file_utils.py ADDED
@@ -0,0 +1,53 @@
+ import os
+ import re
+ import json
+ import logging
+ from werkzeug.utils import secure_filename
+ from flask import current_app
+
+ ALLOWED_EXTENSIONS = {"pdf", "jpg", "jpeg", "png", "svg", "docx", "doc", "xlsx", "xls"}
+ MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024  # 1GB
+ MAX_SIZE_IMAGES = 500 * 1024 * 1024  # 500MB
+
+
+ def allowed_file(filename):
+     return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
+
+
+ def check_file_size(file):
+     file.seek(0, os.SEEK_END)
+     size = file.tell()
+     file.seek(0)
+     extension = file.filename.rsplit('.', 1)[-1].lower()
+     if extension in {"pdf", "docx"} and size > MAX_SIZE_PDF_DOCS:
+         return False, f"File {file.filename} exceeds 1GB size limit"
+     elif extension in {"jpg", "jpeg", "png"} and size > MAX_SIZE_IMAGES:
+         return False, f"Image {file.filename} exceeds 500MB size limit"
+     return True, None
+
+
+ def save_data_to_storage(filename, data):
+     try:
+         upload_folder = current_app.config.get("UPLOAD_FOLDER", "uploads")
+         if not os.path.exists(upload_folder):
+             os.makedirs(upload_folder, exist_ok=True)
+         filename = filename.rsplit(".", 1)[0]
+         filepath = os.path.join(upload_folder, f"{filename}.json")
+         with open(filepath, "w") as file:
+             json.dump(data, file)
+     except Exception as e:
+         logging.error(f"Exception during save: {e}")
+
+
+ def get_data_from_storage(filename):
+     try:
+         upload_folder = current_app.config.get("UPLOAD_FOLDER", "uploads")
+         filepath = os.path.join(upload_folder, f"{filename}.json")
+         if not os.path.exists(filepath):
+             return None
+         with open(filepath, "r") as file:
+             data = json.load(file)
+         return data
+     except Exception as e:
+         logging.error(f"Error loading data: {e}")
+         return None
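The two storage helpers round-trip JSON keyed by the base filename under the configured upload folder; inside an application context (the payload is illustrative):

    from ai_med_extract.app import app
    from ai_med_extract.utils.file_utils import save_data_to_storage, get_data_from_storage

    with app.app_context():
        save_data_to_storage("report.pdf", {"extracted_data": []})  # written to uploads/report.json
        print(get_data_from_storage("report"))  # -> {'extracted_data': []}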
ai_med_extract/utils/validation.py ADDED
@@ -0,0 +1,40 @@
+ import re
+ from flask import jsonify
+
+ def clean_result(value):
+     value = re.sub(r"\s+", " ", value)
+     value = re.sub(r"[-_:]+", " ", value)
+     value = re.sub(r"[^\x00-\x7F]+", " ", value)
+     return value if value else "Not Available"
+
+ def normalize_name(name):
+     if not name:
+         return ""
+     name = name.lower().strip()
+     name = re.sub(r"[^\w\s]", "", name)
+     name = re.sub(r"^\b\w{1,5}\b\s+", "", name)
+     return name
+
+ def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
+     detected_name = extract_patient_name(extracted_text, qa_pipeline)
+     if not detected_name:
+         return jsonify({"error": f"Could not determine patient name from {filename}"}), 400
+     normalized_detected_name = normalize_name(detected_name)
+     normalized_patient_name = normalize_name(patient_name)
+     if normalized_detected_name not in normalized_patient_name:
+         return jsonify({
+             "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
+         }), 400
+     return None
+
+ def extract_patient_name(text, qa_pipeline):
+     if not text or not qa_pipeline:
+         return None
+     try:
+         result = qa_pipeline(
+             question="What is the patient's name?",
+             context=text
+         )
+         return result.get("answer", "").strip()
+     except Exception as e:
+         return None
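Note that normalize_name treats any leading word of up to five characters as a salutation, which also strips short first names; a quick check (values illustrative):

    from ai_med_extract.utils.validation import normalize_name

    print(normalize_name("Dr. Jane Doe"))   # -> "jane doe"
    print(normalize_name("Mrs. Jane Doe"))  # -> "jane doe"
    print(normalize_name("Jane Doe"))       # -> "doe" (short first name removed too)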
combined1.py ADDED
@@ -0,0 +1,880 @@
+ import json
+ import os
+ import re
+ import logging
+ from dotenv import load_dotenv
+ from flask import Flask, request, jsonify, abort
+ from werkzeug.utils import secure_filename
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import pytesseract
+ import cv2
+ import pdfplumber
+ import pandas as pd
+ from PIL import Image
+ from docx import Document
+ from flask_cors import CORS
+ from flask_executor import Executor
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import whisper
+ from PyPDF2 import PdfReader
+ from pdf2image import convert_from_path
+ from concurrent.futures import ThreadPoolExecutor
+ import tempfile
+ import tensorflow.keras.layers as KL  # Instead of keras.layers as KL
+ import numpy as np
+
+ # Load environment variables
+ load_dotenv()
+
+ # Set Tesseract OCR path (Windows-specific; remove or adjust when running on Linux)
+ pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
+
+ # Initialize Flask app
+ app = Flask(__name__)
+ CORS(app)
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+ # Configure upload directory and max file size
+ UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB max file size
+
+ # Initialize Flask-Executor for asynchronous tasks
+ executor = Executor(app)
+ whisper_model = whisper.load_model("tiny")
+
+ # Allowed file extensions
+ ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
+ ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}
+
+ UPLOAD_FOLDER = 'Uploads'
+ ALLOWED_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'svg', 'docx', 'doc'}
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+
+ # Set file size limits
+ MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024  # 1GB
+ MAX_SIZE_IMAGES = 500 * 1024 * 1024  # 500MB
+
+ # Lazy model loading to save resources
+ class LazyModelLoader:
+     def __init__(self, model_name, task, tokenizer=None):
+         self.model_name = model_name
+         self.task = task
+         self.tokenizer = tokenizer
+         self._model = None
+
+     def load(self):
+         """Load the model if not already loaded."""
+         if self._model is None:
+             logging.info(f"Loading model: {self.model_name}")
+             if self.task == "text-generation":
+                 self._model = AutoModelForCausalLM.from_pretrained(
+                     self.model_name, device_map="auto", torch_dtype="auto"
+                 )
+                 self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)
+                 if self._model.generation_config.pad_token_id is None or self._model.generation_config.pad_token_id < 0:
+                     if self._tokenizer.eos_token_id is not None:
+                         self._model.generation_config.pad_token_id = self._tokenizer.eos_token_id
+                         logging.info(f"Set pad_token_id to {self._tokenizer.eos_token_id}")
+                     else:
+                         logging.warning("No valid eos_token_id found. Setting pad_token_id to 0 as a fallback.")
+                         self._model.generation_config.pad_token_id = 0
+             else:
+                 self._model = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
+         return self._model
+
+ # Text extraction agents
+ class TextExtractorAgent:
+     @staticmethod
+     def extract_text(filepath, ext):
+         """Extract text based on file type."""
+         try:
+             if ext == "pdf":
+                 return TextExtractorAgent.extract_text_from_pdf(filepath)
+             elif ext in {"jpg", "jpeg", "png"}:
+                 return TextExtractorAgent.extract_text_from_image(filepath)
+             elif ext == "docx":
+                 return TextExtractorAgent.extract_text_from_docx(filepath)
+             elif ext in {"xlsx", "xls"}:
+                 return TextExtractorAgent.extract_text_from_excel(filepath)
+             return None
+         except Exception as e:
+             logging.error(f"Text extraction failed: {e}")
+             return None
+
+     @staticmethod
+     def extract_text_from_pdf(filepath):
+         """Extract text from a PDF file."""
+         text = ""
+         with pdfplumber.open(filepath) as pdf:
+             for page in pdf.pages:
+                 page_text = page.extract_text()
+                 if page_text:
+                     text += page_text + "\n"
+         return text.strip() or None
+
+     @staticmethod
+     def extract_text_from_image(filepath):
+         """Extract text from an image using OCR."""
+         image = cv2.imread(filepath)
+         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+         _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+             processed_path = temp_file.name
+         cv2.imwrite(processed_path, processed)
+         text = pytesseract.image_to_string(Image.open(processed_path), lang='eng')
+         os.remove(processed_path)
+         return text.strip() or None
+
+     @staticmethod
+     def extract_text_from_docx(filepath):
+         """Extract text from a DOCX file."""
+         doc = Document(filepath)
+         text = "\n".join([para.text for para in doc.paragraphs])
+         return text.strip() or None
+
+     @staticmethod
+     def extract_text_from_excel(filepath):
+         """Extract text from an Excel file."""
+         dfs = pd.read_excel(filepath, sheet_name=None)
+         text = "\n".join([
+             "\n".join([
+                 " ".join(map(str, df[col].dropna()))
+                 for col in df.columns
+             ])
+             for df in dfs.values()
+         ])
+         return text.strip() or None
+
+ # PHI scrubbing agent
+ class PHIScrubberAgent:
+     @staticmethod
+     def scrub_phi(text):
+         """Remove sensitive personal health information (PHI)."""
+         try:
+             text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
+             text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
+             text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
+             text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]', text, flags=re.IGNORECASE)
+             text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
+             text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
+         except Exception as e:
+             logging.error(f"PHI scrubbing failed: {e}")
+         return text
+
+ # Summarization agent
+ class SummarizerAgent:
+     def __init__(self, summarization_model_loader):
+         self.summarization_model_loader = summarization_model_loader
+
+     def generate_summary(self, text):
+         """Generate a summary of the provided text."""
+         model = self.summarization_model_loader.load()
+         try:
+             summary_result = model(text, do_sample=False)
+             return summary_result[0]['summary_text'].strip()
+         except Exception as e:
+             logging.error(f"Summary generation failed: {e}")
+             return "Summary generation failed."
+
+ def allowed_file(filename, allowed_extensions=ALLOWED_EXTENSIONS):
+     """Check if the file extension is allowed."""
+     return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
+
+ # Knowledge Base
+ class KnowledgeBase:
+     def __init__(self, documents):
+         self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+         self.documents = documents
+         self.embeddings = self.embedding_model.encode(documents)
+         self.dimension = self.embedding_model.get_sentence_embedding_dimension()
+         self.index = faiss.IndexFlatL2(self.dimension)
+         self.index.add(self.embeddings)
+
+     def retrieve_relevant_info(self, query, top_k=3):
+         """Retrieve relevant medical information from the knowledge base."""
+         query_embedding = self.embedding_model.encode([query])
+         distances, indices = self.index.search(query_embedding, top_k)
+         relevant_texts = [self.documents[i] for i in indices[0]]
+         return relevant_texts
+
+ # Medical data extraction agent
+ class MedicalDataExtractorAgent:
+     def __init__(self, model_loader, knowledge_base):
+         self.model_loader = model_loader
+         self.knowledge_base = knowledge_base
+
+     def retrieve_relevant_info(self, query, top_k=3):
+         """Retrieve relevant medical information from the knowledge base."""
+         query_embedding = self.knowledge_base.embedding_model.encode([query])
+         distances, indices = self.knowledge_base.index.search(query_embedding, top_k)
+         relevant_texts = [self.knowledge_base.documents[i] for i in indices[0]]
+         return relevant_texts
+
+     def extract_medical_data(self, text):
+         """Extract structured medical data from text using Agentic RAG."""
+         try:
+             default_schema = {
+                 "patient_name": "[NAME]",
+                 "age": None,
+                 "gender": None,
+                 "diagnosis": [],
+                 "symptoms": [],
+                 "medications": [],
+                 "allergies": [],
+                 "vitals": {
+                     "blood_pressure": None,
+                     "heart_rate": None,
+                     "temperature": None
+                 },
+                 "notes": ""
+             }
+             prompt = f"""
+             ### Instruction:
+             Extract structured medical data from the following text as a JSON whose parameters are enclosed in "" and without any \.
+             The JSON should include patientname, age, gender, medications, allergies, diagnosis, symptoms, vitals, and notes.
+             ### Text:
+             {text}
+             ### Response:
+             """
+             model = self.model_loader.load()
+             tokenizer = self.model_loader._tokenizer
+             inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+             outputs = model.generate(
+                 inputs.input_ids,
+                 num_return_sequences=1,
+                 temperature=0.7,
+                 top_p=0.9,
+                 do_sample=True
+             )
+             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             logging.info(f"Model response: {response}")
+             json_start = response.find("{")
+             json_end = response.rfind("}") + 1
+             if json_start == -1 or json_end == 0:  # rfind failure yields -1, so +1 gives 0
+                 raise ValueError("No JSON found in the model response.")
+             structured_data = json.loads(response[json_start:json_end])
+             normalized_data = self.normalize_json_output(structured_data, default_schema)
+             if normalized_data["vitals"]["blood_pressure"] and isinstance(normalized_data["vitals"]["blood_pressure"], str):
+                 normalized_data["vitals"]["blood_pressure"] = normalized_data["vitals"]["blood_pressure"].strip('"')
+             return json.dumps(normalized_data)
+         except json.JSONDecodeError as e:
+             logging.error(f"JSON parsing failed: {e}")
+             return json.dumps({"error": f"Failed to parse JSON: {str(e)}"})
+         except Exception as e:
+             logging.error(f"Error extracting medical data: {e}")
+             return json.dumps({"error": f"Failed to extract medical data: {str(e)}"})
+
+     @staticmethod
+     def normalize_json_output(model_output, default_schema):
+         """Normalize the model's JSON output to match the default schema."""
+         try:
+             normalized_output = default_schema.copy()
+             for key in normalized_output:
+                 if key in model_output:
+                     normalized_output[key] = model_output[key]
+             return normalized_output
+         except Exception as e:
+             logging.error(f"Failed to normalize JSON: {e}")
+             return default_schema
+
+ # Initialize lazy loaders
+ medalpaca_model_loader = LazyModelLoader(
+     model_name="stanford-crfm/BioMedLM",
+     task="text-generation"
+ )
+ summarization_model_loader = LazyModelLoader("google-t5/t5-small", "summarization")
+
+ # Initialize knowledge base
+ medical_documents = [
+     "Hypertension is a chronic condition characterized by elevated blood pressure.",
+     "Diabetes is a metabolic disorder that affects blood sugar levels.",
+     "Common symptoms of chest pain include pressure, tightness, or discomfort in the chest."
+ ]
+ knowledge_base = KnowledgeBase(medical_documents)
+
+ # Initialize agents
+ text_extractor_agent = TextExtractorAgent()
+ phi_scrubber_agent = PHIScrubberAgent()
+ medical_data_extractor_agent = MedicalDataExtractorAgent(medalpaca_model_loader, knowledge_base)
+ summarizer_agent = SummarizerAgent(summarization_model_loader)
+
+ # NER to detect medical info
+ CONFIDENCE_THRESHOLD = 0.80
+
+ def extract_medical_entities(text, ner_pipeline):
+     if not text or not text.strip():
+         return ["No medical entities found"]
+     if ner_pipeline is None:
+         print("⚠️ NER model is not loaded, skipping entity extraction.")
+         return ["No medical entities found"]
+
+     ner_results = ner_pipeline(text)
+     relevant_entities = {
+         "Disease", "MedicalCondition", "Symptom", "Sign_or_Symptom",
+         "B-DISEASE", "I-DISEASE",
+         "Test", "Measurement", "B-TEST", "I-TEST", "Lab_value", "B-Lab_value", "I-Lab_value",
+         "Medication", "B-MEDICATION", "I-MEDICATION", "Treatment",
+         "Procedure", "B-Diagnostic_procedure", "I-Diagnostic_procedure",
+         "Anatomical_site", "Body_Part", "Organ_or_Tissue",
+         "Diagnostic_procedure", "Surgical_Procedure", "Therapeutic_Procedure",
+         "Health_condition", "B-Health_condition", "I-Health_condition",
+         "Pathological_Condition", "Clinical_Event",
+         "Chemical_Substance", "B-Chemical_Substance", "I-Chemical_Substance",
+         "Biological_Entity", "B-Biological_Entity", "I-Biological_Entity"
+     }
+
+     medical_entities = set()
+     for ent in ner_results:
+         entity_label = ent.get("entity_group") or ent.get("entity")
+         if entity_label in relevant_entities and ent["score"] >= CONFIDENCE_THRESHOLD:
+             word = ent["word"].lower().strip().replace("-", "")
+             if len(word) > 2:
+                 medical_entities.add(word)
+
+     if len(medical_entities) >= 5:
+         return list(medical_entities)
+
+     return ["No medical entities found"]
+
+ # Validation: check file size
+ def check_file_size(file):
+     file.seek(0, os.SEEK_END)
+     size = file.tell()
+     file.seek(0)
+     extension = file.filename.rsplit('.', 1)[-1].lower()
+     if extension in {'pdf', 'docx'} and size > MAX_SIZE_PDF_DOCS:
+         return False, f"File {file.filename} exceeds 1GB size limit"
+     elif extension in {'jpg', 'jpeg', 'png'} and size > MAX_SIZE_IMAGES:
+         return False, f"Image {file.filename} exceeds 500MB size limit"
+     return True, None
+
+ def extract_patient_name(text, qa_pipeline):
+     """Extracts the patient name using the given QA pipeline."""
+     if not text or not qa_pipeline:
+         return None
+     try:
+         result = qa_pipeline(
+             question="What is the patient's name?",
+             context=text
+         )
+         return result.get("answer", "").strip()
+     except Exception as e:
+         print(f"⚠️ Error extracting patient name: {e}")
+         return None
+
+ def normalize_name(name):
+     """Cleans and normalizes names for comparison, removing salutations dynamically."""
+     if not name:
+         return ""
+     name = name.lower().strip()
+     name = re.sub(r"[^\w\s]", "", name)
+     name = re.sub(r"^\b\w{1,5}\b\s+", "", name)
+     return name
+
+ def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
+     """Validates that the extracted name matches the registered patient name."""
+     detected_name = extract_patient_name(extracted_text, qa_pipeline)
+     if not detected_name:
+         return jsonify({"error": f"Could not determine patient name from {filename}"}), 400
+     normalized_detected_name = normalize_name(detected_name)
+     normalized_patient_name = normalize_name(patient_name)
+     if normalized_detected_name not in normalized_patient_name:
+         return jsonify({
+             "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
+         }), 400
+     return None
+
+ def is_blurred(image_path, variance_threshold=150):
+     try:
+         image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
+         if image is None:
+             print(f"❌ Error: Unable to read image {image_path}")
+             return True
+         laplacian_var = cv2.Laplacian(image, cv2.CV_64F).var()
+         print(f"🔍 Blur Check: Variance={laplacian_var} (Threshold={variance_threshold})")
+         edges = cv2.Canny(image, 50, 150)
+         edge_density = np.mean(edges)
+         print(f"📏 Edge Density: {edge_density}")
+         return laplacian_var < variance_threshold and edge_density < 10
+     except Exception as e:
+         print(f"❌ Error detecting blur: {e}")
+         return True
+
+ def extract_text_from_image(filepath):
+     try:
+         if is_blurred(filepath):
+             return "Image is too blurry, OCR failed."
+         image = cv2.imread(filepath)
+         if image is None:
+             return "Image could not be read."
+         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+         gray = cv2.GaussianBlur(gray, (5, 5), 0)
+         gray = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+                                      cv2.THRESH_BINARY, 11, 2)
+         kernel = np.ones((2, 2), np.uint8)
+         gray = cv2.dilate(gray, kernel, iterations=1)
+         processed_path = f"{filepath}_processed.png"
+         cv2.imwrite(processed_path, gray)
+         text = pytesseract.image_to_string(Image.open(processed_path), lang='eng').strip()
+         words = text.split()
+         if len(words) < 5:
+             return "OCR failed to extract meaningful text."
+         return text
+     except Exception as e:
+         print(f"❌ Error processing {filepath}: {e}")
+         return "Failed to extract text"
+
+ def extract_text_from_pdf(filepath, password=None):
+     """Extract text from PDFs using pdfplumber (faster) or OCR (if needed)."""
+     text = ""
+     try:
+         reader = PdfReader(filepath)
+         if reader.is_encrypted:
+             if not password:
+                 print("🔒 PDF is encrypted but no password was provided.")
+                 return {"error": "File is password-protected. Please provide a password."}, 401
+             decryption_result = reader.decrypt(password)
+             if decryption_result == 0:
+                 print("❌ Incorrect password provided!")
+                 return {"error": "Invalid password provided."}, 403
+             else:
+                 print("✅ PDF successfully decrypted!")
+             text = "\n".join([page.extract_text() or "" for page in reader.pages])
+             if text.strip():
+                 return text.strip(), 200
+         with pdfplumber.open(filepath) as pdf:
+             for page in pdf.pages:
+                 page_text = page.extract_text()
+                 if page_text:
+                     text += page_text + "\n"
+         if text.strip():
+             return text.strip(), 200
+         images = convert_from_path(filepath)
+         with ThreadPoolExecutor(max_workers=5) as pool:
+             ocr_text = list(pool.map(lambda img: pytesseract.image_to_string(img, lang='eng'), images))
+         return ("\n".join(ocr_text).strip(), 200) if ocr_text else ("No text found", 415)
+     except Exception as e:
+         print(f"❌ Error processing PDF {filepath}: {e}")
+         return "Failed to extract text"
+
+ def extract_text_from_docx(filepath):
+     doc = Document(filepath)
+     text = "\n".join([para.text for para in doc.paragraphs])
+     return text.strip() or None
+
+ def clean_result(value):
+     value = re.sub(r"\s+", " ", value)
+     value = re.sub(r"[-_:]+", " ", value)
+     value = re.sub(r"[^\x00-\x7F]+", " ", value)
+     return value if value else "Not Available"
+
+ def mask_sensitive_info(text):
+     text = re.sub(r'(?<=\b\w{2})\w+(?=\s\w{2,})', '***', text)
+     text = re.sub(r'\b(\d{2})\d{2}-(\d{2})\d{2}-(\d{2})\d{2}\b', r'**\2-**\3-**', text)
+     text = re.sub(r'\b(\d{8})(\d{2})\b', r'********\2', text)
+     return text
+
+ # API Endpoints
+ @app.route('/extract_medical_data', methods=['POST'])
+ def extract_medical_data():
+     """Extract structured medical data from raw text."""
+     try:
+         data = request.json
+         if "text" not in data or not data["text"].strip():
+             return jsonify({"error": "No valid text provided"}), 400
+         raw_text = data["text"]
+         clean_text = phi_scrubber_agent.scrub_phi(raw_text)
+         structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
+         return jsonify(json.loads(structured_data)), 200
+     except Exception as e:
+         logging.error(f"Failed to extract medical data: {e}")
+         return jsonify({"error": f"Extraction Error: {str(e)}"}), 500
+
+ @app.route('/api/transcribe', methods=['POST'])
+ def transcribe_audio():
+     """Transcribe audio files into text."""
+     if 'audio' not in request.files:
+         abort(400, description="No audio file provided")
+     audio_file = request.files['audio']
+     if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
+         abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
+     filename = secure_filename(audio_file.filename)
+     audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+     audio_file.save(audio_path)
+     try:
+         result = whisper_model.transcribe(audio_path)
+         transcribed_text = result["text"]
+         os.remove(audio_path)
+         return jsonify({"transcribed_text": transcribed_text}), 200
+     except Exception as e:
+         logging.error(f"Transcription failed: {str(e)}")
+         return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
+
+ @app.route('/api/generate_summary', methods=['POST'])
+ def generate_summary():
+     """Generate a summary from the provided text."""
+     data = request.json
+     if "text" not in data or not data["text"].strip():
+         return jsonify({"error": "No valid text provided"}), 400
+     context = data["text"]
+     clean_text = phi_scrubber_agent.scrub_phi(context)
+     summary = summarizer_agent.generate_summary(clean_text)
+     return jsonify({"summary": summary}), 200
+
+ @app.route('/api/extract_medical_data_from_audio', methods=['POST'])
+ def extract_medical_data_from_audio():
+     """Extract medical data from transcribed audio."""
+     if 'audio' not in request.files:
+         abort(400, description="No audio file provided")
+     audio_file = request.files['audio']
+     if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
+         abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
+     logging.info(audio_file.filename)
+     logging.info(app.config['UPLOAD_FOLDER'])
+     filename = secure_filename(audio_file.filename)
+     logging.info(filename)
+     audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+     logging.info(audio_path)
+     audio_file.save(audio_path)
+     try:
+         result = whisper_model.transcribe(audio_path)
+         transcribed_text = result["text"]
+         clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
+         summary = summarizer_agent.generate_summary(transcribed_text)
+         structured_data = medical_data_extractor_agent.extract_medical_data(transcribed_text)
+         response = {
+             "transcribed_text": transcribed_text,
+             "summary": summary,
+             "medical_chart": json.loads(structured_data)
+         }
+         os.remove(audio_path)
+         return jsonify(response), 200
+     except Exception as e:
+         logging.error(f"Processing failed: {str(e)}")
+         return jsonify({"error": f"Processing failed: {str(e)}"}), 500
+
+ @app.route('/upload', methods=['POST'])
+ def upload_file():
+     files = request.files.getlist("file")
+     patient_name = request.form.get("patient_name", "").strip()
+     password = request.form.get("password")
+     qa_model_name = request.form.get("qa_model_name")
+     qa_model_type = request.form.get("qa_model_type")
+     ner_model_name = request.form.get("ner_model_name")
+     ner_model_type = request.form.get("ner_model_type")
+     summarizer_model_name = request.form.get("summarizer_model_name")
+     summarizer_model_type = request.form.get("summarizer_model_type")
+
+     if not files:
+         return jsonify({"error": "No file uploaded"}), 400
+
+     try:
+         qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
+         print(f"✅ QA Model Loaded: {qa_model_name}")
+     except Exception as e:
+         return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
+
+     try:
+         ner_pipeline = pipeline(task=ner_model_type, model=ner_model_name)
+         print(f"✅ NER Model Loaded: {ner_model_name}")
+     except Exception as e:
+         return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
+
+     try:
+         summarizer_pipeline = pipeline(task=summarizer_model_type, model=summarizer_model_name)
+         print(f"✅ Summarizer Model Loaded: {summarizer_model_name}")
+     except Exception as e:
+         return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
+
+     extracted_data = []
+     print(patient_name)
+
+     for file in files:
+         if file.filename == '':
+             continue
+         if not allowed_file(file.filename):
+             return jsonify({"error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
+         if not patient_name:
+             return jsonify({"error": "Patient name is missing"}), 400
+         valid_size, error_message = check_file_size(file)
+         if not valid_size:
+             return jsonify({"error": error_message}), 400
+
+         filename = secure_filename(file.filename)
+         filepath = os.path.join(UPLOAD_FOLDER, filename)
+         file.save(filepath)
+
+         extracted_text = None
+
+         if filename.endswith(".pdf"):
+             result = extract_text_from_pdf(filepath, password)
+             if isinstance(result, tuple):
+                 extracted_text, status_code = result
+             else:
+                 extracted_text = result
+                 status_code = 200
+             if isinstance(extracted_text, dict) and "error" in extracted_text:
+                 return jsonify(extracted_text), status_code
+         elif filename.endswith(".docx"):
+             extracted_text = extract_text_from_docx(filepath)
+         elif filename.endswith((".jpg", ".jpeg", ".png", ".svg")):
+             extracted_text = extract_text_from_image(filepath)
+
+         if not extracted_text or extracted_text == "No text found":
+             return jsonify({"error": f"Failed to extract text from {filename}"}), 415
+         if extracted_text in ["Image is too blurry, OCR failed.", "OCR failed to extract meaningful text."]:
+             return jsonify({"error": f"'{filename}' is too blurry or text is unreadable."}), 422
+
+         skip_medical_check = request.form.get("skip_medical_check", "false").lower() == "true"
+         if not skip_medical_check:
+             ner_results = ner_pipeline(extracted_text)
+             medical_entities = list(set([r["word"] for r in ner_results if r["entity"].startswith("B-") or r["entity"].startswith("I-")]))
+             print(f"Medical entities found: {medical_entities}")
+             if not medical_entities:
+                 return jsonify({"error": f"'{filename}' is not medically relevant"}), 406
+         else:
+             print(f"Skipping Medical Validation for {filename}")
+
+         skip_patient_check = request.form.get("skip_patient_check", "false").lower() == "true"
+         if not skip_patient_check:
+             try:
+                 error_response = validate_patient_name(extracted_text, patient_name, filename, qa_pipeline)
+                 if error_response:
+                     return error_response
+             except Exception as e:
+                 return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500
+         else:
+             print(f"Skipping Patient Name Validation for {filename}")
+
+         try:
+             summary = summarizer_pipeline(extracted_text, max_length=350, min_length=50, do_sample=False)[0]["summary_text"]
+         except Exception as e:
+             summary = "Summary failed"
+             print(f"⚠️ Error summarizing: {e}")
+
+         extracted_data.append({
+             "file": filename,
+             "extracted_text": extracted_text,
+             "summary": summary,
+             "message": "Successful"
+         })
+
+         extracted_text = None
+         summary = None
+
+     if not extracted_data:
+         return jsonify({"error": "No valid medical files processed"}), 400
+
+     return jsonify({"extracted_data": extracted_data}), 200
+
+ @app.route('/extract_medical_data_questions', methods=['POST'])
+ def extract_medical_data_questions():
+     """Extract medical data based on predefined questions."""
+     data = request.json
+     qa_model_name = data.get("qa_model_name")
+     qa_model_type = data.get("qa_model_type")
+     if "extracted_data" not in data:
+         return jsonify({"error": "Missing 'extracted_data' in request"}), 400
+
+     if not qa_model_name or not qa_model_type:
+         return jsonify({"error": "Missing 'model_name' or 'model_type'"}), 400
+
+     try:
+         print(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
+         qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
+         loaded_model_name = qa_pipeline.model.config._name_or_path
+         loaded_model_type = qa_pipeline.task
+         print(f"✅ Model loaded: {loaded_model_name}")
+     except Exception as e:
+         print("❌ Error loading model:", str(e))
+         return jsonify({"error": f"Could not load model: {str(e)}"}), 500
+
+     questions = {
+         "Patient Name": "What is the patient's name?",
+         "Age": "What is the patient's age?",
+         "Gender": "What is the patient's gender?",
+         "Date of Birth": "What is the patient's date of birth?",
+         "Patient ID": "What is the patient ID?",
+         "Reason for Visit": "What is the reason for the patient's visit?",
+         "Physician": "Who is the physician in charge of the patient?",
+         "Test Date": "What is the test date?",
+         "Hemoglobin": "What is the patient's hemoglobin level?",
+         "Blood Glucose (Fasting)": "What is the patient's fasting blood glucose level?",
+         "Total Cholesterol": "What is the total cholesterol level?",
+         "LDL Cholesterol": "What is the LDL cholesterol level?",
+         "HDL Cholesterol": "What is the HDL cholesterol level?",
+         "Serum Creatinine": "What is the serum creatinine level?",
+         "Vitamin D (25-OH)": "What is the patient's Vitamin D level?",
+         "Height": "What is the patient's height?",
+         "Weight": "What is the patient's weight?",
+         "Blood Pressure (Systolic)": "What is the patient's systolic blood pressure?",
+         "Blood Pressure (Diastolic)": "What is the patient's diastolic blood pressure?",
+         "Recommendations": "What are the recommendations based on the test results?"
+     }
+
+     structured_response = {"extracted_data": []}
+
+     for file_data in data["extracted_data"]:
+         filename = file_data["file"]
+         context = file_data["extracted_text"]
+
+         if not context:
+             structured_response["extracted_data"].append({
+                 "file": filename,
+                 "medical_terms": "No data extracted",
+             })
+             continue
+
+         extracted_info = {}
+
+         for key, question in questions.items():
+             try:
+                 result = qa_pipeline(question=question, context=context)
+                 extracted_info[key] = clean_result(result.get("answer", "Not Available"))
+             except Exception:
+                 extracted_info[key] = "Error extracting"
+
+         categorized_data = [
+             {
+                 "name": "Patient Information",
+                 "fields": [
+                     {"label": "Patient Name", "value": extracted_info.get("Patient Name", "")},
+                     {"label": "Date of Birth", "value": extracted_info.get("Date of Birth", "")},
+                     {"label": "Gender", "value": extracted_info.get("Gender", "")},
+                     {"label": "Patient ID", "value": extracted_info.get("Patient ID", "")}
+                 ]
+             },
+             {
+                 "name": "Vitals",
+                 "fields": [
+                     {"label": "Height", "value": extracted_info.get("Height", "")},
+                     {"label": "Weight", "value": extracted_info.get("Weight", "")},
+                     {"label": "Blood Pressure", "value": f"{extracted_info.get('Blood Pressure (Systolic)', '')}/{extracted_info.get('Blood Pressure (Diastolic)', '')} mmHg"},
+                     {"label": "Hemoglobin", "value": extracted_info.get("Hemoglobin", "")},
+                     {"label": "Serum Creatinine", "value": extracted_info.get("Serum Creatinine", "")}
+                 ]
+             },
+             {
+                 "name": "Lab Results",
+                 "fields": [
+                     {"label": "Blood Glucose (Fasting)", "value": extracted_info.get("Blood Glucose (Fasting)", "")},
+                     {"label": "Total Cholesterol", "value": extracted_info.get("Total Cholesterol", "")},
+                     {"label": "LDL Cholesterol", "value": extracted_info.get("LDL Cholesterol", "")},
+                     {"label": "HDL Cholesterol", "value": extracted_info.get("HDL Cholesterol", "")},
+                     {"label": "Vitamin D (25-OH)", "value": extracted_info.get("Vitamin D (25-OH)", "")}
+                 ]
+             },
+             {
+                 "name": "Medical Notes",
+                 "fields": [
+                     {"label": "Reason for Visit", "value": extracted_info.get("Reason for Visit", "")},
+                     {"label": "Physician", "value": extracted_info.get("Physician", "")},
+                     {"label": "Test Date", "value": extracted_info.get("Test Date", "")},
+                     {"label": "Recommendations", "value": extracted_info.get("Recommendations", "")}
+                 ]
+             }
+         ]
+         structured_response["extracted_data"].append({
+             "file": filename,
+             "medical_terms": extracted_info,
+             "categorized_data": categorized_data,
+             "model_used": loaded_model_name,
+             "model_type": loaded_model_type
+         })
+
+         save_data_to_storage(filename, structured_response)
+         print(f"✅ Extracted data saved to: {os.path.join(UPLOAD_FOLDER, f'{filename}.json')}")
+
+     return jsonify(structured_response), 200
+
+ def get_data_from_storage(filename):
+     try:
+         filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
+         print(f"🔍 Looking for file at: {filepath}")
+         if not os.path.exists(filepath):
+             print(f"🚫 File not found at: {filepath}")
+             return None
+         with open(filepath, "r") as file:
+             data = json.load(file)
+         print(f"✅ File found and loaded: {filepath}")
+         return data
+     except Exception as e:
+         print(f"🚨 Error loading data: {e}")
+         return None
+
+ def save_data_to_storage(filename, data):
+     try:
+         filename = filename.rsplit(".", 1)[0]
+         filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
+         print(f"Saving to: {filepath}")
+         print(f"Directory exists: {os.path.exists(UPLOAD_FOLDER)}")
+         if not os.path.exists(UPLOAD_FOLDER):
+             print(f"Directory not found. Creating: {UPLOAD_FOLDER}")
+             os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+         with open(filepath, "w") as file:
+             json.dump(data, file)
+         print(f"✅ Data saved successfully to {filepath}")
+     except Exception as e:
+         print(f"🚨 Exception during save: {e}")
+
+ @app.route('/get_updated_medical_data', methods=['GET'])
+ def get_updated_data():
+     file_name = request.args.get('file')
+     if not file_name:
+         return jsonify({"error": "File name is required"}), 400
+     file_name = file_name.rsplit(".", 1)[0]
+     updated_data = get_data_from_storage(file_name)
+     if updated_data:
+         return jsonify({"file": file_name, "data": updated_data}), 200
+     else:
+         return jsonify({"error": f"File '{file_name}' not found"}), 404
+
+ @app.route('/update_medical_data', methods=['PUT'])
+ def update_medical_data():
+     try:
+         data = request.json
+         print("Received data:", data)
+         filename = data.get("file")
+         filename = filename.rsplit(".", 1)[0]
+         updates = data.get("updates", [])
+         if not filename or not updates:
+             return jsonify({"error": "File name or updates missing"}), 400
+         existing_data = get_data_from_storage(filename)
+         if not existing_data:
+             return jsonify({"error": f"File '{filename}' not found"}), 404
+         for update in updates:
+             category = update.get("category")
+             field = update.get("field")
+             new_value = update.get("value")
+             updated = False
+             for cat in existing_data.get("extracted_data", []):
+                 for categorized_data in cat.get("categorized_data", []):
+                     if categorized_data.get("name") == category:
+                         for fld in categorized_data.get("fields", []):
+                             if fld.get("label") == field:
+                                 print(f"🔄 Updating {category} -> {field} from '{fld['value']}' to '{new_value}'")
+                                 fld["value"] = new_value
+                                 updated = True
+                                 break
+                     if updated:
+                         break
+                 if updated:
+                     break
+         save_data_to_storage(filename, existing_data)
+         print("✅ Updated data:", existing_data)
+         return jsonify({"message": "Data updated successfully", "updated_data": existing_data}), 200
+     except Exception as e:
+         print("❌ Error:", str(e))
+         return jsonify({"error": str(e)}), 500
+
+ @app.route('/')
+ def home():
+     return "Medical Data Extraction API is running!"
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=5000, debug=True)
document_based_extraction.py ADDED
@@ -0,0 +1,1188 @@
1
+ import os, re, json
2
+ import time, logging, functools
3
+ import pytesseract
4
+ import cv2
5
+ import pdfplumber
6
+ import numpy as np
7
+ from PIL import Image
8
+ from PyPDF2 import PdfReader
9
+ from pdf2image import convert_from_path
10
+ from flask import Flask, request, jsonify
11
+ from flask_cors import CORS
12
+ import torch
13
+ from werkzeug.utils import secure_filename
14
+ from docx import Document
15
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ from collections import defaultdict
18
+ from huggingface_hub import login
19
+
20
+
21
+ # -------------------- Logging Config -------------------- #
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format="%(asctime)s - %(levelname)s - %(message)s",
25
+ handlers=[
26
+ logging.FileHandler("app.log"),
27
+ logging.StreamHandler()
28
+ ]
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # -------------------- Execution Time Decorator -------------------- #
33
+ def log_execution_time(level=logging.INFO):
34
+ def decorator(func):
35
+ @functools.wraps(func)
36
+ def wrapper(*args, **kwargs):
37
+ start_time = time.time()
38
+ try:
39
+ result = func(*args, **kwargs)
40
+ duration = time.time() - start_time
41
+ logger.log(level, f"⏱️ {func.__name__} executed in {duration:.6f} seconds")
42
+ return result
43
+ except Exception as e:
44
+ duration = time.time() - start_time
45
+ logger.exception(f"❌ Exception in {func.__name__} after {duration:.6f} seconds: {e}")
46
+ raise
47
+ return wrapper
48
+ return decorator
49
+
50
+
51
+ # Read the Hugging Face token from the environment instead of hardcoding it —
52
+ # a token committed to source control is leaked and should be revoked.
53
+ hf_token = os.environ.get("HF_TOKEN")
+ if hf_token:
+ login(hf_token)  # 🧠 stored once; every model load will use it
54
+
55
+ executor = ThreadPoolExecutor(max_workers=5)
56
+ logger.info("Executor initialized with 5 workers")
57
+
58
+ # Set Tesseract OCR Path
59
+ # in Windows
60
+ # pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
61
+ # in Linux
62
+ pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"  # apt's tesseract-ocr package installs here on Debian/Ubuntu
63
+
64
+ # Set up Flask app
65
+ app = Flask(__name__)
66
+ CORS(app)
67
+
68
+ UPLOAD_FOLDER = "uploads"
69
+ ALLOWED_EXTENSIONS = {"pdf", "jpg", "jpeg", "png", "svg", "docx", "doc"}
70
+ app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER
71
+
72
+ # Set file size limits
73
+ MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024 # 1 GB
74
+ MAX_SIZE_IMAGES = 500 * 1024 * 1024 # 500 MB
75
+
76
+
77
+ # # Load ClinicalBERT Model for Classification
78
+ # try:
79
+ # zero_shot_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
80
+ # print("✅ zero_shot_classifier Model Loaded Successfully")
81
+ # except Exception as e:
82
+ # zero_shot_classifier = None
83
+ # print("❌ Error loading ClinicalBERT Model:", str(e))
84
+
85
+
86
+ if not os.path.exists(UPLOAD_FOLDER):
87
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
88
+
89
+ # NER to Detect medical info
90
+ CONFIDENCE_THRESHOLD = 0.80
91
+
92
+
93
+ @log_execution_time()
94
+ def extract_medical_entities(text):
95
+ if not text or not text.strip():
96
+ return ["No medical entities found"]
97
+ if ner_pipeline is None: # type: ignore
98
+ logger.warning("NER model is not loaded, skipping entity extraction.")
99
+ return ["No medical entities found"]
100
+
101
+ ner_results = ner_pipeline(text) # type: ignore
102
+ relevant_entities = {
103
+ # Diseases & Symptoms
104
+ "Disease",
105
+ "MedicalCondition",
106
+ "Symptom",
107
+ "Sign_or_Symptom",
108
+ "B-DISEASE",
109
+ "I-DISEASE",
110
+ # Tests, Measurements, and Lab Values
111
+ "Test",
112
+ "Measurement",
113
+ "B-TEST",
114
+ "I-TEST",
115
+ "Lab_value",
116
+ "B-Lab_value",
117
+ "I-Lab_value",
118
+ # Medications, Treatments, and Procedures
119
+ "Medication",
120
+ "B-MEDICATION",
121
+ "I-MEDICATION",
122
+ "Treatment",
123
+ "Procedure",
124
+ "B-Diagnostic_procedure",
125
+ "I-Diagnostic_procedure",
126
+ # Body Parts & Medical Anatomy
127
+ "Anatomical_site",
128
+ "Body_Part",
129
+ "Organ_or_Tissue",
130
+ # Medical Procedures
131
+ "Diagnostic_procedure",
132
+ "Surgical_Procedure",
133
+ "Therapeutic_Procedure",
134
+ # Clinical Terms
135
+ "Health_condition",
136
+ "B-Health_condition",
137
+ "I-Health_condition",
138
+ "Pathological_Condition",
139
+ "Clinical_Event",
140
+ # Biological & Chemical Substances (Relevant to Lab Reports)
141
+ "Chemical_Substance",
142
+ "B-Chemical_Substance",
143
+ "I-Chemical_Substance",
144
+ "Biological_Entity",
145
+ "B-Biological_Entity",
146
+ "I-Biological_Entity",
147
+ }
148
+
149
+ medical_entities = set()
150
+ for ent in ner_results:
151
+ entity_label = ent.get("entity_group") or ent.get("entity")
152
+ if entity_label in relevant_entities and ent["score"] >= CONFIDENCE_THRESHOLD:
153
+ word = ent["word"].lower().strip().replace("-", "") # Normalize text
154
+ if len(word) > 2: # Ignore short/junk words
155
+ medical_entities.add(word)
156
+
157
+ if len(medical_entities) >= 5:
158
+ logger.info(f"Extracted {len(medical_entities)} medical entities")
159
+ return list(medical_entities)
160
+
161
+ logger.info("Not enough medical entities found")
162
+ return ["No medical entities found"]
163
+
164
+
165
+ # Validation: Check Allowed File Types
166
+ def allowed_file(filename):
167
+ return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
168
+
169
+
170
+ # Validation: Check File Size
171
+ def check_file_size(file):
172
+ file.seek(0, os.SEEK_END)
173
+ size = file.tell()
174
+ file.seek(0)
175
+ extension = file.filename.rsplit(".", 1)[-1].lower()
176
+ logger.info(f"Checking file size for '{file.filename}' - Size: {size} bytes")
177
+ if extension in {"pdf", "docx"} and size > MAX_SIZE_PDF_DOCS:
178
+ logger.warning(f"{file.filename} exceeds 1GB limit")
179
+ return False, f"File {file.filename} exceeds 1MB size limit"
180
+ elif extension in {"jpg", "jpeg", "png"} and size > MAX_SIZE_IMAGES:
181
+ logger.warning(f"{file.filename} exceeds 500MB image limit")
182
+ return False, f"Image {file.filename} exceeds 500KB size limit"
183
+ return True, None
184
+
185
+
186
+ @log_execution_time()
187
+ def extract_patient_name(text, qa_pipeline):
188
+ if not text or not qa_pipeline:
189
+ return None
190
+ try:
191
+ result = qa_pipeline(question="What is the patient's name?", context=text)
192
+ answer = result.get("answer", "").strip()
193
+ logger.info(f"Extracted patient name: {answer}")
194
+ return answer
195
+ except Exception as e:
196
+ logger.error(f"Error extracting patient name: {e}")
197
+ return None
198
+
199
+
200
+ def normalize_name(name):
201
+ """Cleans and normalizes names for comparison, removing salutations dynamically"""
202
+ if not name:
203
+ return ""
204
+ name = name.lower().strip()
205
+ name = re.sub(r"[^\w\s]", "", name)
206
+ name = re.sub(r"^\b\w{1,5}\b\s+", "", name) # Matches short words at the start
207
+ return name
208
+
209
+
210
+ @log_execution_time()
211
+ def validate_patient_name(extracted_text, patient_name, filename, qa_pipeline):
212
+ """Validates if the extracted name matches the registered patient name"""
213
+ detected_name = extract_patient_name(extracted_text, qa_pipeline)
214
+ if not detected_name:
215
+ logger.warning(f"Could not determine patient name from {filename}")
216
+ return (
217
+ jsonify({"error": f"Could not determine patient name from {filename}"}),
218
+ 400,
219
+ )
220
+
221
+ normalized_detected_name = normalize_name(detected_name)
222
+ normalized_patient_name = normalize_name(patient_name)
223
+
224
+ if normalized_detected_name not in normalized_patient_name:
225
+ logger.warning(
226
+ f"Patient mismatch in file '{filename}': Found '{detected_name}'"
227
+ )
228
+ return (
229
+ jsonify(
230
+ {
231
+ "error": f"Document '{filename}' does not belong to {patient_name}. Found: {detected_name}"
232
+ }
233
+ ),
234
+ 400,
235
+ )
236
+ logger.info(f"Patient name validation passed for '{filename}'")
237
+ return None # No error, validation passed
238
+
239
+
240
+ # Check if the image is blurred using the Laplacian method
241
+ def is_blurred(image_path, variance_threshold=150):
242
+ try:
243
+ image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
244
+ if image is None:
245
+ logger.error(f"Unable to read image: {image_path}")
246
+ return True # Assume it's blurry if not readable
247
+
248
+ # Compute Laplacian variance
249
+ laplacian_var = cv2.Laplacian(image, cv2.CV_64F).var()
250
+ logger.info(
251
+ f"Blur Check on '{image_path}': Laplacian Variance = {laplacian_var:.2f} (Threshold = {variance_threshold})"
252
+ )
253
+
254
+ # Compute Edge Density (Additional Check)
255
+ edges = cv2.Canny(image, 50, 150)
256
+ edge_density = np.mean(edges)
257
+ logger.info(f"Edge Density for '{image_path}': {edge_density:.2f}")
258
+
259
+ is_blurry = laplacian_var < variance_threshold and edge_density < 10
260
+ if is_blurry:
261
+ logger.warning(f"Image '{image_path}' flagged as blurry.")
262
+ return is_blurry
263
+ except Exception as e:
264
+ logger.exception(f"Exception during blur detection for '{image_path}': {e}")
265
+ return True # Assume it's blurry on failure
266
+
267
+
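# A quick way to calibrate the variance_threshold above on your own scans:
# print the Laplacian variance for a few known-sharp and known-blurry images
# and pick a cut-off between the two clusters (150 is a reasonable default).
# The file names below are hypothetical.
#
#   for p in ["sharp_scan.png", "blurry_scan.png"]:
#       img = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
#       print(p, cv2.Laplacian(img, cv2.CV_64F).var())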
268
+ # Helper Function: Extract Text from Images (OCR) with Blur Detection
269
+ @log_execution_time()
270
+ def extract_text_from_image(filepath):
271
+ try:
272
+ # Check if the image is blurry
273
+ if is_blurred(filepath):
274
+ logger.warning(f"OCR skipped: '{filepath}' is too blurry.")
275
+ return "Image is too blurry, OCR failed."
276
+
277
+ image = cv2.imread(filepath)
278
+ if image is None:
279
+ logger.error(f"OCR failed: Unable to read image '{filepath}'.")
280
+ return "Image could not be read."
281
+
282
+ # Convert to Grayscale and Apply Thresholding
283
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
284
+ gray = cv2.GaussianBlur(gray, (5, 5), 0)
285
+ gray = cv2.adaptiveThreshold(
286
+ gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
287
+ )
288
+
289
+ # Apply dilation (bolds the text) for better OCR accuracy
290
+ kernel = np.ones((2, 2), np.uint8)
291
+ gray = cv2.dilate(gray, kernel, iterations=1)
292
+ processed_path = f"{filepath}_processed.png"
293
+ cv2.imwrite(processed_path, gray)
294
+ logger.info(f"Image preprocessed and saved: {processed_path}")
295
+ text = pytesseract.image_to_string(
296
+ Image.open(processed_path), lang="eng"
297
+ ).strip()
298
+ # Validate OCR output (Reject if too little text is extracted)
299
+ word_count = len(text.split())
300
+ logger.info(
301
+ f"OCR completed for '{filepath}' with {word_count} words extracted."
302
+ )
303
+
304
+ if word_count < 5:
305
+ logger.warning(f"OCR output too small for '{filepath}'. Might be junk.")
306
+ return "OCR failed to extract meaningful text."
307
+
308
+ return text
309
+
310
+ except Exception as e:
311
+ logger.exception(f"Error extracting text from image '{filepath}': {e}")
312
+ return "Failed to extract text"
313
+
314
+
315
+ # Helper Function: Extract Text from PDF
316
+ @log_execution_time()
317
+ def extract_text_from_pdf(filepath, password=None):
318
+ """Extract text from PDFs using pdfplumber (faster) or OCR (if needed)."""
319
+ text = ""
320
+
321
+ try:
322
+ logger.info(f"Starting PDF extraction: {filepath}")
323
+ reader = PdfReader(filepath)
324
+
325
+ if reader.is_encrypted:
326
+ if not password:
327
+ logger.warning("Encrypted PDF without password.")
328
+ return {
329
+ "error": "File is password-protected. Please provide a password."
330
+ }, 401
331
+
332
+ # ✅ Attempt to decrypt
333
+ decryption_result = reader.decrypt(password)
334
+ if decryption_result == 0: # Decryption failed
335
+ logger.error("Incorrect password provided.")
336
+ return {"error": "Invalid password provided."}, 403
337
+ else:
338
+ logger.info("PDF decryption successful.")
339
+
340
+ text = "\n".join([page.extract_text() or "" for page in reader.pages])
341
+ if text.strip():
342
+ logger.info("Text extracted from decrypted PDF.")
343
+ return text.strip(), 200
344
+
345
+ # ✅ Now, use pdfplumber for text extraction
346
+ with pdfplumber.open(filepath) as pdf:
347
+ for page in pdf.pages:
348
+ page_text = page.extract_text()
349
+ if page_text:
350
+ text += page_text + "\n"
351
+
352
+ if text.strip():
353
+ logger.info(
354
+ f"PDF text extracted using pdfplumber: {len(text.split())} words."
355
+ )
356
+ return text.strip(), 200 # ✅ Always return a tuple (text, status)
357
+
358
+ logger.info("No text found via pdfplumber. Falling back to OCR.")
359
+ # ✅ Use OCR if the PDF has no selectable text
360
+ images = convert_from_path(filepath)
361
+ with ThreadPoolExecutor(max_workers=5) as pool:
362
+ ocr_text = list(
363
+ pool.map(
364
+ lambda img: pytesseract.image_to_string(img, lang="eng"), images
365
+ )
366
+ )
367
+
368
+ full_ocr_text = "\n".join(ocr_text).strip()
369
+ logger.info(
370
+ f"OCR fallback complete for PDF: {len(full_ocr_text.split())} words extracted."
371
+ )
372
+
373
+ return (full_ocr_text, 200) if full_ocr_text else ("No text found", 415)
374
+
375
+ except Exception as e:
376
+ logger.exception(f"Error during PDF processing: {filepath}")
377
+ return "Failed to extract text"
378
+
379
+
380
+ # Helper Function: Extract Text from DOCX
381
+ @log_execution_time()
382
+ def extract_text_from_docx(filepath):
383
+ try:
384
+ doc = Document(filepath)
385
+ text = "\n".join([para.text for para in doc.paragraphs])
386
+ word_count = len(text.split())
387
+ logger.info(f"DOCX extracted from '{filepath}': {word_count} words.")
388
+ return text.strip() or None
389
+ except Exception as e:
390
+ logger.exception(f"Failed to extract text from DOCX: {filepath}")
391
+ return None
392
+
393
+
394
+ # Masking function to hide sensitive data
395
+ def mask_sensitive_info(text):
396
+ text = re.sub(r"(?<=\b\w{2})\w+(?=\s\w{2,})", "*", text) # Mask names
397
+ text = re.sub(
398
+ r"\b(\d{2})\d{2}-(\d{2})\d{2}-(\d{2})\d{2}\b", r"\2-\3-", text
399
+ ) # Mask DOB
400
+ text = re.sub(r"\b(\d{8})(\d{2})\b", r"\2", text) # Mask phone numbers
401
+ return text
402
+
403
+
404
+ # ------------------Upload Documents ------------------ #
405
+ # API Route: Upload File & Extract Text
406
+ @app.route("/upload", methods=["POST"])
407
+ @log_execution_time()
408
+ def upload_file():
409
+ logger.info("📥 Upload request received")
410
+ files = request.files.getlist("file")
411
+ patient_name = request.form.get("patient_name", "").strip()
412
+ password = request.form.get("password") # Get password if provided
413
+ # Dynamic model info from form
414
+ qa_model_name = request.form.get("qa_model_name")
415
+ qa_model_type = request.form.get("qa_model_type")
416
+
417
+ ner_model_name = request.form.get("ner_model_name")
418
+ ner_model_type = request.form.get("ner_model_type")
419
+
420
+ summarizer_model_name = request.form.get("summarizer_model_name")
421
+ summarizer_model_type = request.form.get("summarizer_model_type")
422
+
423
+ if not files:
424
+ logger.warning("No file uploaded")
425
+ return jsonify({"error": "No file uploaded"}), 400
426
+
427
+ # 🔌 Load models dynamically
428
+ try:
429
+ qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
430
+ logger.info(f"✅ QA model loaded: {qa_model_name}")
431
+ except Exception as e:
432
+ logger.error(f"❌ QA model load failed: {e}")
433
+ return jsonify({"error": f"QA model load failed: {str(e)}"}), 500
434
+
435
+ try:
436
+ ner_pipeline = pipeline(task=ner_model_type, model=ner_model_name)
437
+ logger.info(f"✅ NER model loaded: {ner_model_name}")
438
+ except Exception as e:
439
+ logger.error(f"❌ NER model load failed: {e}")
440
+ return jsonify({"error": f"NER model load failed: {str(e)}"}), 500
441
+
442
+ try:
443
+ summarizer_pipeline = pipeline(
444
+ task=summarizer_model_type, model=summarizer_model_name
445
+ )
446
+ logger.info(f"✅ Summarizer model loaded: {summarizer_model_name}")
447
+ except Exception as e:
448
+ logger.error(f"❌ Summarizer model load failed: {e}")
449
+ return jsonify({"error": f"Summarizer model load failed: {str(e)}"}), 500
450
+
451
+ extracted_data = []
452
+ logger.info("Patient name received: %s", patient_name)
453
+
454
+ for file in files:
455
+ logger.info(f"📂 Processing file: {file.filename}")
456
+
457
+ if file.filename == "":
458
+ logger.warning("Skipping unnamed file")
459
+ continue # Skip empty file names
460
+
461
+ if not allowed_file(file.filename):
462
+ logger.warning(f"Unsupported file type: {file.filename}")
463
+ return (
464
+ jsonify(
465
+ {
466
+ "error": f"Unsupported file type: {file.filename}. Supported file types are: {', '.join(ALLOWED_EXTENSIONS)}"
467
+ }
468
+ ),
469
+ 400,
470
+ )
471
+
472
+ if not patient_name:
473
+ logger.warning("Patient name missing")
474
+ return jsonify({"error": "Patient name is missing"}), 400
475
+
476
+ # *Check file size*
477
+ valid_size, error_message = check_file_size(file)
478
+ if not valid_size:
479
+ logger.warning(f"❌ File size validation failed: {error_message}")
480
+ return jsonify({"error": error_message}), 400
481
+
482
+ filename = secure_filename(file.filename)
483
+ filepath = os.path.join(UPLOAD_FOLDER, filename)
484
+ file.save(filepath)
485
+ logger.info(f"✅ File saved: {filepath}")
486
+
487
+ extracted_text = None
488
+
489
+ # ✅ *Extract text based on file type*
490
+ if filename.endswith(".pdf"):
491
+ logger.info("🧾 Extracting text from PDF")
492
+ result = extract_text_from_pdf(filepath, password)
493
+
494
+ # ✅ If PDF requires a password, return 401
495
+ if isinstance(result, tuple):
496
+ extracted_text, status_code = result
497
+ else:
498
+ extracted_text = result
499
+ status_code = 200
500
+
501
+ if isinstance(extracted_text, dict) and "error" in extracted_text:
502
+ logger.warning(f"⚠️ PDF extraction error: {extracted_text}")
503
+ return jsonify(extracted_text), status_code
504
+ elif filename.endswith(".docx"):
505
+ extracted_text = extract_text_from_docx(filepath)
506
+ elif filename.endswith((".jpg", ".jpeg", ".png", ".svg")):
507
+ logger.info("🖼️ Extracting text from image")
508
+ extracted_text = extract_text_from_image(filepath)
509
+
510
+ if not extracted_text or extracted_text == "No text found":
511
+ logger.warning(f"⚠️ No text extracted from {filename}")
512
+ return (
513
+ jsonify({"error": f"Failed to extract text from {filename}"}),
514
+ 415,
515
+ ) # Unsupported Media Type
516
+
517
+ # reject blurred images
518
+ if extracted_text in [
519
+ "Image is too blurry, OCR failed.",
520
+ "OCR failed to extract meaningful text.",
521
+ ]:
522
+ logger.warning(f"🔍 OCR failed or image too blurry: {filename}")
523
+ return (
524
+ jsonify(
525
+ {"error": f"'{filename}' is too blurry or text is unreadable."}
526
+ ),
527
+ 422,
528
+ ) # Unprocessable Entity
529
+
530
+ # ✅ Medical Validation using NER
531
+ skip_medical_check = (
532
+ request.form.get("skip_medical_check", "false").lower() == "true"
533
+ )
534
+ if not skip_medical_check:
535
+ logger.info("🧠 Running NER medical validation")
536
+ start_time = time.time()
537
+ ner_results = ner_pipeline(extracted_text)
538
+ medical_entities = list(
539
+ set(
540
+ [
541
+ r["word"]
542
+ for r in ner_results
543
+ if r["entity"].startswith("B-") or r["entity"].startswith("I-")
544
+ ]
545
+ )
546
+ )
547
+ elapsed_time = time.time() - start_time
548
+ logger.info(f"⏱️ Medical entity validation took {elapsed_time:.2f}s")
549
+
550
+ logger.info(f"🩺 Medical entities found: {medical_entities}")
551
+ if not medical_entities:
552
+ logger.warning(f"❌ No medical relevance in {filename}")
553
+ return (
554
+ jsonify({"error": f"'{filename}' is not medically relevant"}),
555
+ 406,
556
+ )
557
+ else:
558
+ logger.info(f"⏭️ Skipping medical validation for {filename}")
559
+
560
+ # # ✅ Patient Name Validation using QA
561
+ # skip_patient_check = request.form.get("skip_patient_check", "false").lower() == "true"
562
+ # if not skip_patient_check:
563
+ # try:
564
+ # logger.info("🧍 Validating patient name")
565
+ # start_time = time.time()
566
+ # error_response = validate_patient_name(extracted_text, patient_name, filename,qa_pipeline)
567
+ # elapsed_time = time.time() - start_time
568
+ # logger.info(f"⏱️ Patient name validation took {elapsed_time:.2f}s")
569
+
570
+ # if error_response:
571
+ # return error_response
572
+ # except Exception as e:
573
+ # logger.error(f"❌ Patient name validation failed: {e}")
574
+ # return jsonify({"error": f"Patient name validation failed: {str(e)}"}), 500
575
+ # else:
576
+ # logger.info(f"⏭️ Skipping patient name validation for {filename}")
577
+
578
+ # ✨ Generate Summary using Summarizer
579
+ try:
580
+ logger.info("📝 Generating summary: %s", extracted_text)
581
+
582
+ start_time = time.time()
583
+ summary = summarizer_pipeline(
584
+ extracted_text, max_length=350, min_length=50, do_sample=False
585
+ )[0]["summary_text"]
586
+ elapsed_time = time.time() - start_time
587
+
588
+ logger.info(f"✅ Summary generated: {summary}")
589
+ logger.info(f"⏱️ Summary generation took {elapsed_time:.2f} seconds")
590
+ except Exception as e:
591
+ summary = "Summary failed"
592
+ logger.warning(f"⚠ Summary generation failed: {e}")
593
+ # # Classify report type
594
+ # report_type = classify_medical_document(extracted_text)
595
+ # print(report_type)
596
+ # ✅ Summarize extracted text
597
+ extracted_data.append(
598
+ {
599
+ "file": filename,
600
+ # "document_type": report_type,
601
+ "extracted_text": extracted_text,
602
+ "summary": summary,
603
+ "message": "Successful",
604
+ }
605
+ )
606
+ logger.info(f"✅ Finished processing file: {filename}")
607
+
608
+ if not extracted_data:
609
+ logger.warning("❌ No valid medical files processed")
610
+ return jsonify({"error": "No valid medical files processed"}), 400
611
+
612
+ logger.info("📦 Upload processing completed successfully")
613
+ return jsonify({"extracted_data": extracted_data}), 200
614
+
615
+
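# --- Standalone client sketch (separate script, not part of this module) ----
# A minimal client for the /upload route above. Assumes the `requests` package
# and a server on localhost:5000; the model names are illustrative placeholders,
# not checkpoints this code endorses.
import requests

with open("report.pdf", "rb") as f:  # hypothetical input file
    resp = requests.post(
        "http://localhost:5000/upload",
        files={"file": f},
        data={
            "patient_name": "John Doe",
            "qa_model_name": "deepset/roberta-base-squad2",
            "qa_model_type": "question-answering",
            "ner_model_name": "d4data/biomedical-ner-all",
            "ner_model_type": "ner",
            "summarizer_model_name": "facebook/bart-large-cnn",
            "summarizer_model_type": "summarization",
        },
    )
print(resp.status_code, resp.json())
# -----------------------------------------------------------------------------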
616
+ # # API Route: Extract Medical Data Based on Predefined Questions
617
+ # @app.route('/extract_medical_data', methods=['POST'])
618
+ # def extract_medical_data():
619
+ # data = request.json
620
+ # print(f"📥 Incoming request data: {data}")
621
+
622
+ # qa_model_name = data.get("qa_model_name")
623
+ # qa_model_type = data.get("qa_model_type")
624
+
625
+ # if "extracted_data" not in data:
626
+ # return jsonify({"error": "Missing 'extracted_data' in request"}), 400
627
+
628
+ # if not qa_model_name or not qa_model_type:
629
+ # return jsonify({"error": "Missing 'model_name' or 'model_type'"}), 400
630
+
631
+ # try:
632
+ # print(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
633
+ # qa_pipeline = pipeline(task=qa_model_type, model=qa_model_name)
634
+ # print(f"✅ Model loaded: {qa_pipeline.model.config._name_or_path}")
635
+ # except Exception as e:
636
+ # print("❌ Error loading model:", str(e))
637
+ # return jsonify({"error": f"Could not load model: {str(e)}"}), 500
638
+
639
+ # questions = {
640
+ # "Patient Name": "What is the patient's name?",
641
+ # "Age": "What is the patient's age?",
642
+ # "Gender": "What is the patient's gender?",
643
+ # "Date of Birth": "What is the patient's date of birth?",
644
+ # "Patient ID": "What is the patient ID?",
645
+ # "Reason for Visit": "What is the reason for the patient's visit?",
646
+ # "Physician": "Who is the physician in charge of the patient?",
647
+ # "Test Date": "What is the test date?",
648
+ # "Hemoglobin": "What is the patient's hemoglobin level?",
649
+ # "Blood Glucose (Fasting)": "What is the patient's fasting blood glucose level?",
650
+ # "Total Cholesterol": "What is the total cholesterol level?",
651
+ # "LDL Cholesterol": "What is the LDL cholesterol level?",
652
+ # "HDL Cholesterol": "What is the HDL cholesterol level?",
653
+ # "Serum Creatinine": "What is the serum creatinine level?",
654
+ # "Vitamin D (25-OH)": "What is the patient's Vitamin D level?",
655
+ # "Height": "What is the patient's height?",
656
+ # "Weight": "What is the patient's weight?",
657
+ # "Blood Pressure (Systolic)": "What is the patient's systolic blood pressure?",
658
+ # "Blood Pressure (Diastolic)": "What is the patient's diastolic blood pressure?",
659
+ # "Recommendations": "What are the recommendations based on the test results?"
660
+ # }
661
+
662
+ # structured_response = {"extracted_data": []}
663
+
664
+ # for file_data in data["extracted_data"]:
665
+ # filename = file_data["file"]
666
+ # context = file_data["extracted_text"]
667
+
668
+ # if not context:
669
+ # structured_response["extracted_data"].append({
670
+ # "file": filename,
671
+ # "medical_terms": "No data extracted"
672
+ # })
673
+ # continue
674
+
675
+ # # Prepare batch QA input
676
+ # qa_inputs = [
677
+ # {"question": q, "context": context}
678
+ # for q in questions.values()
679
+ # ]
680
+
681
+ # try:
682
+ # qa_outputs = qa_pipeline(qa_inputs)
683
+ # print("📤 Batch QA outputs:", qa_outputs)
684
+ # except Exception as e:
685
+ # print("⚠️ Batch failed, falling back to loop:", str(e))
686
+ # qa_outputs = [qa_pipeline(q) for q in qa_inputs]
687
+
688
+ # # Map answers back to questions
689
+ # extracted_info = {}
690
+ # for i, key in enumerate(questions.keys()):
691
+ # answer = qa_outputs[i].get("answer", "").strip()
692
+ # score = qa_outputs[i].get("score", 0.0)
693
+
694
+ # # If the model returns an empty string or very low confidence, mark as "Not Mentioned"
695
+ # if not answer or score < 0.1:
696
+ # extracted_info[key] = "Not Mentioned"
697
+ # else:
698
+ # extracted_info[key] = answer
699
+
700
+ # # Optional: Clean results
701
+ # # extracted_info = {k: clean_result(v) for k, v in extracted_info.items()}
702
+
703
+ # categorized_data = [
704
+ # {
705
+ # "name": "Patient Information",
706
+ # "fields": [
707
+ # {"label": "Patient Name", "value": extracted_info.get("Patient Name", "")},
708
+ # {"label": "Date of Birth", "value": extracted_info.get("Date of Birth", "")},
709
+ # {"label": "Gender", "value": extracted_info.get("Gender", "")},
710
+ # {"label": "Patient ID", "value": extracted_info.get("Patient ID", "")}
711
+ # ]
712
+ # },
713
+ # {
714
+ # "name": "Vitals",
715
+ # "fields": [
716
+ # {"label": "Height", "value": extracted_info.get("Height", "")},
717
+ # {"label": "Weight", "value": extracted_info.get("Weight", "")},
718
+ # {"label": "Blood Pressure", "value": f"{extracted_info.get('Blood Pressure (Systolic)', '')}/{extracted_info.get('Blood Pressure (Diastolic)', '')} mmHg"},
719
+ # {"label": "Hemoglobin", "value": extracted_info.get("Hemoglobin", "")},
720
+ # {"label": "Serum Creatinine", "value": extracted_info.get("Serum Creatinine", "")}
721
+ # ]
722
+ # },
723
+ # {
724
+ # "name": "Lab Results",
725
+ # "fields": [
726
+ # {"label": "Blood Glucose (Fasting)", "value": extracted_info.get("Blood Glucose (Fasting)", "")},
727
+ # {"label": "Total Cholesterol", "value": extracted_info.get("Total Cholesterol", "")},
728
+ # {"label": "LDL Cholesterol", "value": extracted_info.get("LDL Cholesterol", "")},
729
+ # {"label": "HDL Cholesterol", "value": extracted_info.get("HDL Cholesterol", "")},
730
+ # {"label": "Vitamin D (25-OH)", "value": extracted_info.get("Vitamin D (25-OH)", "")}
731
+ # ]
732
+ # },
733
+ # {
734
+ # "name": "Medical Notes",
735
+ # "fields": [
736
+ # {"label": "Reason for Visit", "value": extracted_info.get("Reason for Visit", "")},
737
+ # {"label": "Physician", "value": extracted_info.get("Physician", "")},
738
+ # {"label": "Test Date", "value": extracted_info.get("Test Date", "")},
739
+ # {"label": "Recommendations", "value": extracted_info.get("Recommendations", "")}
740
+ # ]
741
+ # }
742
+ # ]
743
+
744
+ # structured_response["extracted_data"].append({
745
+ # "file": filename,
746
+ # "medical_terms": extracted_info,
747
+ # "categorized_data": categorized_data
748
+ # })
749
+
750
+ # save_data_to_storage(filename, structured_response)
751
+ # print(f"✅ Extracted data saved to: {os.path.join(UPLOAD_FOLDER, f'{filename}.json')}")
752
+
753
+ # return jsonify(structured_response)
754
+
755
+
756
+ # ------------------ CLEAN FUNCTION ------------------ #
757
+ @log_execution_time()
758
+ def clean_result(value):
759
+ logger.debug("Cleaning value: %s", value)
760
+ if isinstance(value, str):
761
+ value = re.sub(r"\s+", " ", value)
762
+ value = re.sub(r"[-_:]+", " ", value)
763
+ value = re.sub(r"[^\x00-\x7F]+", " ", value)
764
+ value = re.sub(
765
+ r"(?<=\d),(?=\d)", "", value
766
+ ) # Remove commas in numbers like 250,000
767
+ return value.strip() if value.strip() else "Not Available"
768
+ elif isinstance(value, list):
769
+ cleaned = [clean_result(v) for v in value if v is not None]
770
+ return cleaned if cleaned else ["Not Available"]
771
+ elif isinstance(value, dict):
772
+ return {k: clean_result(v) for k, v in value.items()}
773
+ return value
774
+
775
+ # ------------------Group by Category ------------------ #
776
+ @log_execution_time()
777
+ def group_by_category(data):
778
+ logger.info("Grouping extracted items by category")
779
+ grouped = defaultdict(list)
780
+ category_times = {}
781
+
782
+ for item in data:
783
+ cat = item.get("category", "General")
784
+ start_time = time.time()
785
+ grouped[cat].append(
786
+ {
787
+ "question": item.get("question", "Not Created"),
788
+ "label": item.get("label", "Unknown"),
789
+ "answer": item.get("answer", "Not Available"),
790
+ }
791
+ )
792
+ elapsed = time.time() - start_time
793
+ category_times[cat] = category_times.get(cat, 0) + elapsed
794
+
795
+ for cat, details in grouped.items():
796
+ logger.info(f"📂 Category '{cat}': {len(details)} items, time taken: {category_times[cat]:.4f}s")
797
+
798
+ return [{"category": k, "detail": v} for k, v in grouped.items()]
799
+
800
+
801
+ # ------------------detect duplicate to remove it ------------------ #
802
+ @log_execution_time()
803
+ def deduplicate_extractions(data):
804
+ logger.info("Deduplicating extracted data")
805
+ seen = set()
806
+ unique = []
807
+ for item in data:
808
+ # Deduplicate on the label alone; widen this to a tuple of fields
809
+ # (e.g. label and answer) if distinct values must be preserved.
+ key = item.get("label")
810
+ if key not in seen:
811
+ seen.add(key)
812
+ unique.append(item)
813
+ return unique
814
+
815
+
816
+ # Load tokenizer outside the route for performance
817
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
818
+
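# Note: this Qwen tokenizer is also handed to the text-generation pipeline in
# /extract_medical_data below. If qa_model_name is not a Qwen3-4B-compatible
# checkpoint, load the tokenizer from qa_model_name instead to avoid a
# vocabulary mismatch between tokenizer and model.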
819
+
820
+ # -----------------------------Split text into overlapping chunks---------------#
821
+ @log_execution_time()
822
+ def chunk_text(text, tokenizer, max_tokens=512, overlap=50):
823
+ """
824
+ Splits text into overlapping token-based chunks without using NLTK.
825
+
826
+ Args:
827
+ text (str): Raw input text.
828
+ tokenizer (transformers tokenizer): Hugging Face tokenizer instance.
829
+ max_tokens (int): Max tokens per chunk.
830
+ overlap (int): Number of overlapping tokens between chunks.
831
+
832
+ Returns:
833
+ List[str]: List of decoded text chunks.
834
+ """
835
+ # Tokenize the full text
836
+ logger.info("Splitting text into chunks")
837
+ input_ids = tokenizer.encode(text, add_special_tokens=False)
838
+ chunks = []
839
+ start = 0
840
+ while start < len(input_ids):
841
+ end = start + max_tokens
842
+ chunk_ids = input_ids[start:end]
843
+ decoded_chunk = tokenizer.decode(chunk_ids, skip_special_tokens=True)
844
+ # Ensure a partial continuation isn't cut off mid-sentence. (The local is
+ # named decoded_chunk so it does not shadow this function's own name.)
845
+ if not decoded_chunk.endswith(('.', '?', '!', ':')):
846
+ decoded_chunk += "..."
847
+
848
+ chunks.append(decoded_chunk)
849
+ start += max_tokens - overlap
850
+ logger.info("Created %d chunks", len(chunks))
851
+ return chunks
852
+
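# Sanity check of the window arithmetic above: with max_tokens=512 and
# overlap=50, each window starts 512 - 50 = 462 tokens after the previous one,
# so a 1000-token document yields chunks covering tokens [0:512], [462:974]
# and [924:1000]. A sketch, assuming `report_text` holds the document:
#
#   chunks = chunk_text(report_text, tokenizer)
#   print(len(chunks),
#         [len(tokenizer.encode(c, add_special_tokens=False)) for c in chunks])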
853
+ # ------------------ PARSE JSON OBJECTS FROM OUTPUT ------------------ #
854
+ @log_execution_time()
855
+ def extract_json_objects(text):
856
+ logger.info("Extracting JSON objects from text")
857
+ extracted = []
858
+ try:
859
+ json_start = text.index('[')
860
+ json_text = text[json_start:]
861
+ except ValueError:
862
+ logger.warning("⚠ '[' not found in output")
863
+ return []
864
+
865
+ # Try parsing full array first
866
+ try:
867
+ parsed = json.loads(json_text)
868
+ if isinstance(parsed, list):
869
+ return parsed
870
+ except Exception:
871
+ pass # fallback to manual parsing
872
+
873
+ # Manual recovery via brace matching
874
+ stack = 0
875
+ obj_start = None
876
+ for i, char in enumerate(json_text):
877
+ if char == '{':
878
+ if stack == 0:
879
+ obj_start = i
880
+ stack += 1
881
+ elif char == '}':
882
+ stack -= 1
883
+ if stack == 0 and obj_start is not None:
884
+ obj_str = json_text[obj_start:i+1]
885
+ try:
886
+ obj = json.loads(obj_str)
887
+ extracted.append(obj)
888
+ except Exception as e:
889
+ logger.error(f"❌ Invalid JSON object: {e}")
890
+ obj_start = None
891
+
892
+ return extracted
893
+
894
+
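# Recovery behaviour of extract_json_objects on a truncated model output: the
# full-array parse fails, but brace matching still salvages complete objects.
#
#   extract_json_objects('noise [ {"label": "bp", "answer": "120/80"}, {"label": "hr", "ans')
#   -> [{'label': 'bp', 'answer': '120/80'}]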
895
+ # ------------------ PROCESS A SINGLE CHUNK ------------------ #
896
+ @log_execution_time()
897
+ def process_chunk(generator, chunk, idx):
898
+ logger.info("Processing chunk %d", idx + 1)
899
+ prompt = f"""
900
+ [INST] <<SYS>>
901
+ You are a clinical data extraction assistant.
902
+
903
+ Your job is to:
904
+ 1. Read the following medical report.
905
+ 2. Extract all medically relevant facts as a list of JSON objects.
906
+ 3. Each object must include:
907
+ - "label": a short field name (e.g., "blood pressure", "diagnosis")
908
+ - "question": a question related to that field
909
+ - "answer": the answer from the text
910
+ 4. After extracting the list, categorize each object under one of the following fixed categories:
911
+
912
+ - Patient Info
913
+ - Vitals
914
+ - Symptoms
915
+ - Allergies
916
+ - Habits
917
+ - Comorbidities
918
+ - Diagnosis
919
+ - Medication
920
+ - Laboratory
921
+ - Radiology
922
+ - Doctor Note
923
+
924
+ Example format for structure only — do not include in output:
925
+ [
926
+ {{
927
+ "label": "patient name",
928
+ "question": "What is the patient's name?",
929
+ "answer": "John Doe",
930
+ "category": "Patient Info"
931
+ }},
932
+ {{
933
+ "label": "heart rate",
934
+ "question": "What is the heart rate?",
935
+ "answer": "78 bpm",
936
+ "category": "Vitals"
937
+ }}
938
+ ]
939
+
940
+ ⚠ Use the categories listed above. If an item does not fit any of these categories, create a new category for it.
941
+
942
+ Text:
943
+ {chunk}
944
+
945
+ Return a single valid JSON array of all extracted objects.
946
+ Do not include any explanations or commentary.
947
+ Only output the JSON array
948
+ <</SYS>> [/INST]
949
+ """
950
+
951
+ try:
952
+ output = generator(
953
+ prompt,
954
+ max_new_tokens=1024,
955
+ do_sample=True,
956
+ temperature=0.3
957
+ )[0]["generated_text"]
958
+ print("----------------------------------")
959
+ logger.info(f"📤 Output from chunk {idx}: {output}...")
960
+ return idx, output
961
+ except Exception as e:
962
+ logger.error("Error processing chunk %d: %s", idx, e)
963
+ return idx, None
964
+
965
+
966
+ # ------------------Extract Medical Data ------------------ #
967
+ @app.route("/extract_medical_data", methods=["POST"])
968
+ @log_execution_time()
969
+ def extract_medical_data():
970
+ data = request.json
971
+ logger.info("Received request: %s", json.dumps(data, indent=2))
972
+
973
+ qa_model_name = data.get("qa_model_name")
974
+ qa_model_type = data.get("qa_model_type")
975
+ extracted_files = data.get("extracted_data")
976
+
977
+ if not qa_model_name or not qa_model_type:
978
+ return jsonify({"error": "Missing 'qa_model_name' or 'qa_model_type'"}), 400
979
+
980
+ if not extracted_files:
981
+ return jsonify({"error": "Missing 'extracted_data' in request"}), 400
982
+
983
+ try:
984
+ logger.info(f"🌀 Loading model: {qa_model_name} ({qa_model_type})")
985
+ model = AutoModelForCausalLM.from_pretrained(qa_model_name, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
986
+ generator = pipeline(task=qa_model_type, model=model, tokenizer=tokenizer)
987
+ logger.info(f"✅ Model loaded successfully: {generator.model.config._name_or_path}")
988
+ except Exception as e:
989
+ logger.error("❌ Model load failure")
990
+ return jsonify({"error": f"Could not load model: {str(e)}"}), 500
991
+
992
+ structured_response = {"extracted_data": []}
993
+
994
+ for file_data in extracted_files:
995
+ filename = file_data.get("file", "unknown_file")
996
+ context = file_data.get("extracted_text", "").strip()
997
+ logger.info("Processing file: %s", filename)
998
+
999
+ if not context:
1000
+ logger.warning("No text found in file: %s", filename)
1001
+ structured_response["extracted_data"].append(
1002
+ {"file": filename, "medical_fields": "No data extracted"}
1003
+ )
1004
+ continue
1005
+
1006
+ chunks = chunk_text(context, tokenizer)
1007
+ logger.info(f"📚 Chunked into {len(chunks)} parts for {filename}")
1008
+
1009
+ all_extracted = []
1010
+ # for idx,chunk in enumerate(chunks):
1011
+ # print(f"Processing chunk {idx+1}/{len(chunks)}")
1012
+
1013
+ with ThreadPoolExecutor(max_workers=4) as executor:
1014
+ futures = {
1015
+ executor.submit(process_chunk, generator, chunk, idx): idx
1016
+ for idx, chunk in enumerate(chunks)
1017
+ }
1018
+ for future in as_completed(futures):
1019
+ idx = futures[future]
1020
+ _, output = future.result()
1021
+
1022
+ if not output:
1023
+ continue
1024
+
1025
+ try:
1026
+ objs = extract_json_objects(output)
1027
+ if objs:
1028
+ all_extracted.extend(objs)
1029
+ else:
1030
+ logger.error(f"⚠ Chunk {idx+1} yielded no valid JSON.")
1031
+ except Exception as e:
1032
+ logger.error(f"❌ Error extracting JSON from chunk {idx+1}")
1033
+
1034
+ # Clean and group results for this file
1035
+ if all_extracted:
1036
+ deduped = deduplicate_extractions(all_extracted)
1037
+ # cleaned_json = clean_result()
1038
+ grouped_data = group_by_category(deduped)
1039
+ else:
1040
+ grouped_data = {"error": "No valid data extracted"}
1041
+
1042
+ structured_response["extracted_data"].append(
1043
+ {"file": filename, "medical_fields": grouped_data}
1044
+ )
1045
+
1046
+ try:
1047
+ save_data_to_storage(filename, grouped_data)
1048
+ except Exception as e:
1049
+ logger.error(f"⚠ Failed to save data for {filename}: {e}")
1050
+
1051
+ logger.info("✅ Extraction complete.")
1052
+ return jsonify(structured_response)
1053
+
1054
+
1055
+
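# --- Standalone client sketch (separate script, not part of this module) ----
# A minimal request for the /extract_medical_data route above; the field names
# match the handler, the model name is illustrative. Assumes `requests`.
import requests

payload = {
    "qa_model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "qa_model_type": "text-generation",
    "extracted_data": [
        {"file": "report.pdf", "extracted_text": "BP 120/80 mmHg. Dx: hypertension."}
    ],
}
resp = requests.post("http://localhost:5000/extract_medical_data", json=payload)
print(resp.json())
# -----------------------------------------------------------------------------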
1056
+ # -------------------------- save data to a JSON file----------------------#
1057
+ @log_execution_time()
1058
+ def save_data_to_storage(filename, data):
1059
+ try:
1060
+ filename = filename.rsplit(".", 1)[0] # Remove extension
1061
+ filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
1062
+ logger.info(f"💾 Saving to: {filepath}")
1063
+ with open(filepath, "w") as file:
1064
+ json.dump(data, file)
1065
+ logger.info(f"✅ Data saved successfully to {filepath}")
1066
+ except Exception as e:
1067
+ logger.error(f"🚨 Exception during save: {e}")
1068
+
1069
+
1070
+ # Function to get data from a JSON file
1071
+ # 🔍 Get data from storage
1072
+ @log_execution_time()
1073
+ def get_data_from_storage(filename):
1074
+ try:
1075
+ filepath = os.path.join(UPLOAD_FOLDER, f"{filename}.json")
1076
+ logger.info(f"🔍 Looking for file at: {filepath}")
1077
+ if not os.path.exists(filepath):
1078
+ logger.warning(f"🚫 File not found at: {filepath}")
1079
+ return None
1080
+ with open(filepath, "r") as file:
1081
+ data = json.load(file)
1082
+ logger.info(f"✅ File found and loaded: {filepath}")
1083
+ return data
1084
+ except Exception as e:
1085
+ logger.error(f"🚨 Error loading data: {e}")
1086
+ return None
1087
+
1088
+
1089
+ # 🔹 Fetch updated medical data
1090
+ @app.route("/get_updated_medical_data", methods=["GET"])
1091
+ @log_execution_time()
1092
+ def get_updated_data():
1093
+ file_name = request.args.get("file")
1094
+
1095
+ if not file_name:
1096
+ return jsonify({"error": "File name is required"}), 400
1097
+
1098
+ # 🔥 Strip extension if present
1099
+ file_name = file_name.rsplit(".", 1)[0]
1100
+
1101
+ # ✅ Load updated JSON data from storage
1102
+ updated_data = get_data_from_storage(file_name)
1103
+
1104
+ if updated_data:
1105
+ return jsonify({"file": file_name, "data": updated_data}), 200
1106
+ else:
1107
+ return jsonify({"error": f"File '{file_name}' not found"}), 404
1108
+
1109
+
1110
+
1111
+ @app.route("/update_medical_data", methods=["PUT"])
1112
+ @log_execution_time()
1113
+ def update_medical_data():
1114
+ try:
1115
+ data = request.json
1116
+ logger.info("Received update: %s", json.dumps(data, indent=2))
1117
+
1118
+ filename = data.get("file", "").rsplit(".", 1)[0] # Strip extension like .pdf
1119
+ updates = data.get("updates", [])
1120
+
1121
+ if not filename or not updates:
1122
+ return jsonify({"error": "File name or updates missing"}), 400
1123
+
1124
+ # Load current stored data
1125
+ existing_data = get_data_from_storage(filename)
1126
+ if not existing_data:
1127
+ return jsonify({"error": f"File '{filename}' not found"}), 404
1128
+
1129
+ # Loop through updates and modify categorized_data
1130
+ for update in updates:
1131
+ category = update.get("category")
1132
+ field = update.get("field")
1133
+ new_value = update.get("value")
1134
+ updated = False
1135
+
1136
+ for extracted in existing_data.get("extracted_data", []):
1137
+ for cat in extracted.get("categorized_data", []):
1138
+ if cat.get("name") == category:
1139
+ for fld in cat.get("fields", []):
1140
+ if fld.get("label") == field:
1141
+ logger.info("Updating [%s] %s → %s", category, field, new_value)
1142
+ fld["value"] = new_value
1143
+ updated = True
1144
+ break
1145
+ if updated:
1146
+ break
1147
+ if updated:
1148
+ break
1149
+
1150
+ # 🧠 Sync medical_terms with categorized_data
1151
+ for extracted in existing_data.get("extracted_data", []):
1152
+ if "categorized_data" in extracted:
1153
+ new_terms = {}
1154
+ for category in extracted["categorized_data"]:
1155
+ for field in category.get("fields", []):
1156
+ label = field.get("label")
1157
+ value = field.get("value", "")
1158
+ new_terms[label] = value
1159
+ extracted["medical_terms"] = new_terms
1160
+ logger.info("Synced 'medical_terms' with 'categorized_data'")
1161
+
1162
+ # Save updated data to file
1163
+ save_data_to_storage(filename, existing_data)
1164
+ logger.info("✅ Updated data saved successfully")
1165
+
1166
+ return (
1167
+ jsonify(
1168
+ {"message": "Data updated successfully", "updated_data": existing_data}
1169
+ ),
1170
+ 200,
1171
+ )
1172
+
1173
+ except Exception as e:
1174
+ logger.error("Update error: %s", e)
1175
+ return jsonify({"error": str(e)}), 500
1176
+
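# Example payload for PUT /update_medical_data above; "category" and "field"
# must match the stored categorized_data names and field labels:
#
#   {"file": "report.pdf",
#    "updates": [{"category": "Vitals", "field": "Blood Pressure", "value": "118/76 mmHg"}]}
#
# Note: /extract_medical_data in this file saves group_by_category() output
# (a list of {"category": ..., "detail": [...]}), which has neither the
# "extracted_data" wrapper nor the "categorized_data" entries this handler
# walks, so updates will fail on that stored shape until the two are reconciled.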
1177
+ # Test Route
1178
+ @app.route("/")
1179
+ def home():
1180
+ return "Medical Data Extraction API is running!"
1181
+
1182
+
1183
+ if __name__ == "__main__":
1184
+ app.run(host="0.0.0.0", port=5000, debug=True)
1185
+ # if __name__ == '__main__':
1186
+ # from gevent.pywsgi import WSGIServer # type: ignore
1187
+ # http_server = WSGIServer(('0.0.0.0', 5000), app)
1188
+ # http_server.serve_forever()
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ flask
2
+ flask-cors
3
+ werkzeug
4
+ transformers
5
+ openai-whisper  # the code does "import whisper"; openai-whisper is the PyPI package that provides it
6
+ python-dotenv
7
+ torch==2.6.0
8
+ pillow
9
+ pdf2image
10
+ python-docx
11
+ openpyxl
12
+ pytesseract
13
+ scikit-learn
14
+ scipy
15
+ pandas
16
+ numpy
+ # Also imported by the extraction code above; without these the app fails at import time:
+ pdfplumber
+ PyPDF2
+ opencv-python-headless
+ huggingface_hub
+ accelerate  # required by transformers when loading models with device_map="auto"
speech_to_chart.py ADDED
@@ -0,0 +1,638 @@
1
+
2
+
3
+ import json
4
+ import os
5
+ import re
6
+ import logging
7
+ import shutil
8
+ from flask import Flask, request, jsonify, abort
9
+ from werkzeug.utils import secure_filename
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ import torch
12
+ import whisper
13
+ from dotenv import load_dotenv
14
+ import pytesseract
15
+ import cv2
16
+ import pdfplumber
17
+ import pandas as pd
18
+ from PIL import Image
19
+ from docx import Document
20
+ from flask_cors import CORS
21
+
22
+ # Load environment variables
23
+ load_dotenv()
24
+
25
+ # Initialize Flask app
26
+ app = Flask(__name__)
27
+ CORS(app)
28
+
29
+ # Configure logging
30
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
31
+
32
+ # Configure upload directory and max file size
33
+ UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
34
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
35
+ app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
36
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16 MB max file size
37
+
38
+ # Allowed file extensions
39
+ ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
40
+ ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}
41
+
42
+ # Ensure ffmpeg is in PATH
43
+ ffmpeg_path = shutil.which("ffmpeg") or "C:\\ffmpeg\\bin\\ffmpeg.exe"
44
+ if not os.path.exists(ffmpeg_path):
45
+ raise RuntimeError("FFmpeg not found! Please install FFmpeg and set the correct path.")
46
+ os.environ["PATH"] += os.pathsep + os.path.dirname(ffmpeg_path)\
47
+
48
+ def allowed_file(filename, allowed_extensions):
49
+ """Check if the file extension is allowed."""
50
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
51
+
52
+ class LazyModelLoader:
53
+ def __init__(self, model_name, task, tokenizer=None, apply_quantization=False):
54
+ self.model_name = model_name
55
+ self.task = task
56
+ self.tokenizer = tokenizer
57
+ self.apply_quantization = apply_quantization
58
+ self._pipeline = None
59
+
60
+ def load(self):
61
+ if self._pipeline is None:
62
+ logging.info(f"Loading pipeline for task: {self.task} | model: {self.model_name}")
63
+ if self.task == "question-answering":
64
+ model = AutoModelForCausalLM.from_pretrained(self.model_name)
65
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
66
+ if self.apply_quantization:
67
+ logging.info("Applying quantization...")
68
+ model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
69
+ self._pipeline = pipeline(self.task, model=model, tokenizer=tokenizer)
70
+ else:
71
+ self._pipeline = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
72
+ return self._pipeline
73
+
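# Usage sketch for LazyModelLoader: nothing is downloaded until load() is first
# called, and the built pipeline is cached on the instance afterwards (the
# model name below is illustrative).
#
#   summarizer = LazyModelLoader("sshleifer/distilbart-cnn-12-6", "summarization")
#   summary = summarizer.load()("Some long clinical note ...", max_length=60)
#
# Note: the "question-answering" branch above builds an AutoModelForCausalLM,
# which is a text-generation architecture; a true extractive-QA pipeline would
# use AutoModelForQuestionAnswering instead.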
74
+ # PHI scrubbing agent
75
+ class PHIScrubberAgent:
76
+ @staticmethod
77
+ def scrub_phi(text):
78
+ try:
79
+ text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
80
+ text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
81
+ text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
82
+ text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]',
83
+ text, flags=re.IGNORECASE)
84
+ text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
85
+ text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
86
+ except Exception as e:
87
+ logging.error(f"PHI scrubbing failed: {e}")
88
+ return text
89
+
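# Worked example of the scrubbing order above: the phone rule runs first, and
# the "Dr."-specific rule fires before the generic two-capitalised-words rule,
# so titles survive while names are masked:
#
#   PHIScrubberAgent.scrub_phi("Dr. Jane Smith called John Doe at 555-123-4567")
#   -> 'Dr. [NAME] called [NAME] at [PHONE]'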
90
+ # Summarization Agent
91
+ class SummarizerAgent:
92
+ def __init__(self, summarization_model_loader):
93
+ self.summarization_model_loader = summarization_model_loader
94
+
95
+ def generate_summary(self, text):
96
+ model = self.summarization_model_loader.load()
97
+ try:
98
+ summary_result = model(text, max_length=150, min_length=30, do_sample=False)
99
+ return summary_result[0]['summary_text'].strip()
100
+ except Exception as e:
101
+ logging.error(f"Summary generation failed: {e}")
102
+ return "Summary generation failed."
103
+
104
+ # Medical Data Extraction Agent
105
+ class MedicalDataExtractorAgent:
106
+ def __init__(self, gen_model_loader):
107
+ self.gen_model_loader = gen_model_loader
108
+
109
+ def extract_medical_data(self, text):
110
+ try:
111
+ generator = self.gen_model_loader.load()
112
+ prompt = (
113
+ "Extract structured medical information from the following clinical note.\n\n"
114
+ "Return the result in JSON format with the following fields:\n"
115
+ "patient_condition, symptoms, current_problems, allergies, dr_notes, "
116
+ "prescription, investigations, follow_up_instructions.\n\n"
117
+ f"Clinical Note:\n{text}\n\n"
118
+ "Structured JSON Output:\n"
119
+ )
120
+ response = generator(prompt, max_new_tokens=256)[0]["generated_text"]
121
+ logging.debug(f"Raw model output: {response}")
122
+
123
+ json_start = response.find("{")
124
+ json_end = response.rfind("}") + 1
125
+ if json_start == -1 or json_end == -1:
126
+ raise ValueError("No JSON found in the model response.")
127
+
128
+ json_str = response[json_start:json_end]
129
+ return json.loads(json_str)
130
+
131
+ except Exception as e:
132
+ logging.error(f"Error extracting medical data: {e}")
133
+ return {"error": f"Failed to extract medical data: {str(e)}"}
134
+
135
+ # Initialize lazy loaders
136
+ gen_model_loader = LazyModelLoader(
137
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
138
+ "text-generation",
139
+ )
140
+ summarization_model_loader = LazyModelLoader("google-t5/t5-large", "summarization", apply_quantization=True)
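# Note: LazyModelLoader only applies dynamic quantization inside its
# "question-answering" branch, so apply_quantization=True is currently a
# no-op for this summarization loader.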
141
+ whisper_model = whisper.load_model("base")
142
+
143
+ # Initialize agents
144
+ phi_scrubber_agent = PHIScrubberAgent()
145
+ medical_data_extractor_agent = MedicalDataExtractorAgent(gen_model_loader)
146
+ summarizer_agent = SummarizerAgent(summarization_model_loader)
147
+
148
+ # API Endpoints
149
+ @app.route('/api/extract_medical_data', methods=['POST'])
150
+ def extract_medical_data():
151
+ try:
152
+ data = request.json
153
+ if "text" not in data or not data["text"].strip():
154
+ return jsonify({"error": "No valid text provided"}), 400
155
+ raw_text = data["text"]
156
+ clean_text = phi_scrubber_agent.scrub_phi(raw_text)
157
+ structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
158
+ return jsonify(structured_data), 200
159
+ except Exception as e:
160
+ logging.error(f"Failed to extract medical data: {e}")
161
+ return jsonify({"error": f"Extraction Error: {str(e)}"}), 500
162
+
163
+ @app.route('/api/transcribe', methods=['POST'])
164
+ def transcribe_audio():
165
+ if 'audio' not in request.files:
166
+ abort(400, description="No audio file provided")
167
+ audio_file = request.files['audio']
168
+ if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
169
+ abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
170
+ filename = secure_filename(audio_file.filename)
171
+ audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
172
+ audio_file.save(audio_path)
173
+ try:
174
+ result = whisper_model.transcribe(audio_path)
175
+ transcribed_text = result["text"]
176
+ os.remove(audio_path)
177
+ return jsonify({"transcribed_text": transcribed_text}), 200
178
+ except Exception as e:
179
+ logging.error(f"Transcription failed: {str(e)}")
180
+ return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
181
+
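# A minimal client sketch for /api/transcribe (assumes the `requests` package;
# the form field name "audio" and the mp3/wav/flac whitelist come from the
# handler above):
#
#   import requests
#   with open("visit.wav", "rb") as f:  # hypothetical recording
#       r = requests.post("http://localhost:5000/api/transcribe", files={"audio": f})
#   print(r.json()["transcribed_text"])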
182
+ @app.route('/api/generate_summary', methods=['POST'])
183
+ def generate_summary():
184
+ data = request.json
185
+ if "text" not in data or not data["text"].strip():
186
+ return jsonify({"error": "No valid text provided"}), 400
187
+ context = data["text"]
188
+ clean_text = phi_scrubber_agent.scrub_phi(context)
189
+ summary = summarizer_agent.generate_summary(clean_text)
190
+ return jsonify({"summary": summary}), 200
191
+
192
+ @app.route('/api/extract_medical_data_from_audio', methods=['POST'])
193
+ def extract_medical_data_from_audio():
194
+ if 'audio' not in request.files:
195
+ abort(400, description="No audio file provided")
196
+ audio_file = request.files['audio']
197
+ if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
198
+ abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
199
+ filename = secure_filename(audio_file.filename)
200
+ audio_path = os.path.join(UPLOAD_DIR, filename)
201
+ audio_file.save(audio_path)
202
+ try:
203
+ result = whisper_model.transcribe(audio_path)
204
+ transcribed_text = result["text"]
205
+ clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
206
+ summary = summarizer_agent.generate_summary(clean_text)
207
+ structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
208
+ response = {
209
+ "transcribed_text": clean_text,
210
+ "summary": summary,
211
+ "medical_chart": structured_data
212
+ }
213
+ os.remove(audio_path)
214
+ return jsonify(response), 200
215
+ except Exception as e:
216
+ logging.error(f"Processing failed: {str(e)}")
217
+ return jsonify({"error": f"Processing failed: {str(e)}"}), 500
218
+
219
+ if __name__ == '__main__':
220
+ app.run(host='0.0.0.0', port=5000, debug=False)
221
+
+ # import json
+ # import os
+ # import re
+ # import logging
+ # import shutil
+ # from dotenv import load_dotenv
+ # from flask import Flask, request, jsonify, abort
+ # from werkzeug.utils import secure_filename
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ # import pytesseract
+ # import cv2
+ # import pdfplumber
+ # import pandas as pd
+ # from PIL import Image
+ # from docx import Document
+ # from flask_cors import CORS
+ # from flask_executor import Executor
+ # from sentence_transformers import SentenceTransformer
+ # import faiss
+ # import whisper
+ # from PyPDF2 import PdfReader
+ # from pdf2image import convert_from_path
+ # from concurrent.futures import ThreadPoolExecutor
+ # import tempfile
+
+ # # Load environment variables
+ # load_dotenv()
+
+ # # Initialize Flask app
+ # app = Flask(__name__)
+ # CORS(app)
+
+ # # Configure logging
+ # logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+ # # Configure upload directory and max file size
+ # UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
+ # os.makedirs(UPLOAD_DIR, exist_ok=True)
+ # app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
+ # app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB max file size
+
+ # # Initialize Flask-Executor for asynchronous tasks
+ # executor = Executor(app)
+ # whisper_model = whisper.load_model("tiny")
+ # # Allowed file extensions
+ # ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'flac'}
+ # ALLOWED_DOCUMENT_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'docx', 'xlsx', 'xls'}
+
+ # # Ensure ffmpeg is in PATH
+ # ffmpeg_path = shutil.which("ffmpeg") or "C:\\ffmpeg\\bin\\ffmpeg.exe"
+ # if not os.path.exists(ffmpeg_path):
+ #     raise RuntimeError("FFmpeg not found! Please install FFmpeg and set the correct path.")
+ # os.environ["PATH"] += os.pathsep + os.path.dirname(ffmpeg_path)
+
+ # # Lazy model loading to save resources
+ # class LazyModelLoader:
+ #     def __init__(self, model_name, task, tokenizer=None):
+ #         self.model_name = model_name
+ #         self.task = task
+ #         self.tokenizer = tokenizer
+ #         self._model = None
+
+ #     def load(self):
+ #         """Load the model if not already loaded."""
+ #         if self._model is None:
+ #             logging.info(f"Loading model: {self.model_name}")
+ #             if self.task == "text-generation":
+ #                 self._model = AutoModelForCausalLM.from_pretrained(
+ #                     self.model_name, device_map="auto", torch_dtype="auto"
+ #                 )
+ #                 self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)
+ #                 # Set pad_token_id if it's not already set
+ #                 if self._model.generation_config.pad_token_id is None or self._model.generation_config.pad_token_id < 0:
+ #                     if self._tokenizer.eos_token_id is not None:
+ #                         self._model.generation_config.pad_token_id = self._tokenizer.eos_token_id
+ #                         logging.info(f"Set pad_token_id to {self._tokenizer.eos_token_id}")
+ #                     else:
+ #                         logging.warning("No valid eos_token_id found. Setting pad_token_id to 0 as a fallback.")
+ #                         self._model.generation_config.pad_token_id = 0
+ #             else:
+ #                 self._model = pipeline(self.task, model=self.model_name, tokenizer=self.tokenizer)
+ #         return self._model
+
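+ # # Example (illustrative sketch): lazy loading defers the expensive download to the
+ # # first request that needs the model, and reuses the cached instance afterwards:
+ # #
+ # #   summarizer_loader = LazyModelLoader("google-t5/t5-small", "summarization")
+ # #   model = summarizer_loader.load()   # downloads/loads on first call
+ # #   model = summarizer_loader.load()   # returns the cached pipeline
+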
+ # # Text extraction agents
+ # class TextExtractorAgent:
+ #     @staticmethod
+ #     def extract_text(filepath, ext):
+ #         """Extract text based on file type."""
+ #         try:
+ #             if ext == "pdf":
+ #                 return TextExtractorAgent.extract_text_from_pdf(filepath)
+ #             elif ext in {"jpg", "jpeg", "png"}:
+ #                 return TextExtractorAgent.extract_text_from_image(filepath)
+ #             elif ext == "docx":
+ #                 return TextExtractorAgent.extract_text_from_docx(filepath)
+ #             elif ext in {"xlsx", "xls"}:
+ #                 return TextExtractorAgent.extract_text_from_excel(filepath)
+ #             return None
+ #         except Exception as e:
+ #             logging.error(f"Text extraction failed: {e}")
+ #             return None
+
+ #     @staticmethod
+ #     def extract_text_from_pdf(filepath):
+ #         """Extract text from a PDF file."""
+ #         text = ""
+ #         with pdfplumber.open(filepath) as pdf:
+ #             for page in pdf.pages:
+ #                 page_text = page.extract_text()
+ #                 if page_text:
+ #                     text += page_text + "\n"
+ #         return text.strip() or None
+
+ #     @staticmethod
+ #     def extract_text_from_image(filepath):
+ #         """Extract text from an image using OCR."""
+ #         image = cv2.imread(filepath)
+ #         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ #         _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+ #         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+ #             processed_path = temp_file.name
+ #             cv2.imwrite(processed_path, processed)
+ #         text = pytesseract.image_to_string(Image.open(processed_path), lang='eng')
+ #         os.remove(processed_path)
+ #         return text.strip() or None
+
+ #     @staticmethod
+ #     def extract_text_from_docx(filepath):
+ #         """Extract text from a DOCX file."""
+ #         doc = Document(filepath)
+ #         text = "\n".join([para.text for para in doc.paragraphs])
+ #         return text.strip() or None
+
+ #     @staticmethod
+ #     def extract_text_from_excel(filepath):
+ #         """Extract text from an Excel file."""
+ #         dfs = pd.read_excel(filepath, sheet_name=None)
+ #         text = "\n".join([
+ #             "\n".join([
+ #                 " ".join(map(str, df[col].dropna()))
+ #                 for col in df.columns
+ #             ])
+ #             for df in dfs.values()
+ #         ])
+ #         return text.strip() or None
+
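+ # # Example (illustrative sketch; file names are assumptions): dispatch is driven by
+ # # the extension passed alongside the path:
+ # #
+ # #   text = TextExtractorAgent.extract_text("report.pdf", "pdf")     # pdfplumber
+ # #   text = TextExtractorAgent.extract_text("scan.png", "png")      # OCR via Tesseract
+ # #   text = TextExtractorAgent.extract_text("chart.docx", "docx")   # python-docx
+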
+ # # PHI scrubbing agent
+ # class PHIScrubberAgent:
+ #     @staticmethod
+ #     def scrub_phi(text):
+ #         """Remove sensitive personal health information (PHI)."""
+ #         try:
+ #             text = re.sub(r'\b(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b', '[PHONE]', text)
+ #             text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w{2,4}\b', '[EMAIL]', text)
+ #             text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
+ #             text = re.sub(r'\b\d{1,5}\s+\w+\s+(Street|St|Avenue|Ave|Boulevard|Blvd|Road|Rd|Lane|Ln)\b', '[ADDRESS]', text, flags=re.IGNORECASE)
+ #             text = re.sub(r'\bDr\.?\s+[A-Z][a-z]+\s+[A-Z][a-z]+\b', 'Dr. [NAME]', text)
+ #             text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
+ #         except Exception as e:
+ #             logging.error(f"PHI scrubbing failed: {e}")
+ #         return text
+
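+ # # Example (illustrative sketch showing what the regexes above would replace):
+ # #
+ # #   scrubbed = PHIScrubberAgent.scrub_phi(
+ # #       "John Smith, SSN 123-45-6789, lives at 42 Elm Street, call 555-123-4567."
+ # #   )
+ # #   # -> "[NAME], SSN [SSN], lives at [ADDRESS], call [PHONE]."
+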
+ # # Summarization agent
+ # class SummarizerAgent:
+ #     def __init__(self, summarization_model_loader):
+ #         self.summarization_model_loader = summarization_model_loader
+
+ #     def generate_summary(self, text):
+ #         """Generate a summary of the provided text."""
+ #         model = self.summarization_model_loader.load()
+ #         try:
+ #             summary_result = model(text, do_sample=False)
+ #             return summary_result[0]['summary_text'].strip()
+ #         except Exception as e:
+ #             logging.error(f"Summary generation failed: {e}")
+ #             return "Summary generation failed."
+
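+ # # Example (illustrative sketch): the loader yields a transformers summarization
+ # # pipeline, which returns a list of dicts keyed by 'summary_text':
+ # #
+ # #   agent = SummarizerAgent(LazyModelLoader("google-t5/t5-small", "summarization"))
+ # #   print(agent.generate_summary("Patient presents with elevated blood pressure ..."))
+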
+ # def allowed_file(filename, allowed_extensions):
+ #     """Check if the file extension is allowed."""
+ #     return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
+
+ # # Knowledge Base
+ # class KnowledgeBase:
+ #     def __init__(self, documents):
+ #         self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+ #         self.documents = documents
+ #         self.embeddings = self.embedding_model.encode(documents)
+ #         self.dimension = self.embedding_model.get_sentence_embedding_dimension()
+ #         self.index = faiss.IndexFlatL2(self.dimension)
+ #         self.index.add(self.embeddings)
+
+ #     def retrieve_relevant_info(self, query, top_k=3):
+ #         """Retrieve relevant medical information from the knowledge base."""
+ #         query_embedding = self.embedding_model.encode([query])
+ #         distances, indices = self.index.search(query_embedding, top_k)
+ #         relevant_texts = [self.documents[i] for i in indices[0]]
+ #         return relevant_texts
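+ # # Example (illustrative sketch): queries are embedded with the same model and
+ # # matched against the FAISS L2 index built in __init__:
+ # #
+ # #   kb = KnowledgeBase(["Hypertension is a chronic condition ...",
+ # #                       "Diabetes is a metabolic disorder ..."])
+ # #   print(kb.retrieve_relevant_info("high blood pressure", top_k=1))
+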
+ # # Medical data extraction agent
+ # class MedicalDataExtractorAgent:
+ #     def __init__(self, model_loader, knowledge_base):
+ #         self.model_loader = model_loader
+ #         self.knowledge_base = knowledge_base
+
+ #     def retrieve_relevant_info(self, query, top_k=3):
+ #         """Retrieve relevant medical information from the knowledge base."""
+ #         query_embedding = self.knowledge_base.embedding_model.encode([query])
+ #         distances, indices = self.knowledge_base.index.search(query_embedding, top_k)
+ #         relevant_texts = [self.knowledge_base.documents[i] for i in indices[0]]
+ #         return relevant_texts
+
+ #     def extract_medical_data(self, text):
+ #         """Extract structured medical data from text using Agentic RAG."""
+ #         try:
+ #             # Define the default JSON schema
+ #             default_schema = {
+ #                 "patient_name": "[NAME]",
+ #                 "age": None,
+ #                 "gender": None,
+ #                 "diagnosis": [],
+ #                 "symptoms": [],
+ #                 "medications": [],
+ #                 "allergies": [],
+ #                 "vitals": {
+ #                     "blood_pressure": None,
+ #                     "heart_rate": None,
+ #                     "temperature": None
+ #                 },
+ #                 "notes": ""
+ #             }
+ #             # Construct the prompt with the input text
+ #             prompt = f"""
+ #             ### Instruction:
+ #             Extract structured medical data from the following text as JSON, with all keys enclosed in double quotes and no backslash escapes.
+ #             The JSON should include patient_name, age, gender, medications, allergies, diagnosis, symptoms, vitals, and notes.
+ #             ### Text:
+ #             {text}
+ #             ### Response:
+ #             """
+ #             # Tokenize and generate the response
+ #             model = self.model_loader.load()
+ #             tokenizer = self.model_loader._tokenizer
+ #             inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+ #             outputs = model.generate(
+ #                 inputs.input_ids,
+ #                 num_return_sequences=1,
+ #                 temperature=0.7,
+ #                 top_p=0.9,
+ #                 do_sample=True
+ #             )
+ #             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ #             logging.info(f"Model response: {response}")
+ #             # Parse and normalize the JSON output
+ #             json_start = response.find("{")
+ #             json_end = response.rfind("}") + 1
+ #             # rfind returns -1 when no "}" is present, so json_end is 0 on failure
+ #             if json_start == -1 or json_end == 0:
+ #                 raise ValueError("No JSON found in the model response.")
+ #             # Extract the JSON substring
+ #             structured_data = json.loads(response[json_start:json_end])
+ #             # Normalize the JSON output
+ #             normalized_data = self.normalize_json_output(structured_data, default_schema)
+ #             # Ensure blood pressure is a string
+ #             if normalized_data["vitals"]["blood_pressure"] and isinstance(normalized_data["vitals"]["blood_pressure"], str):
+ #                 normalized_data["vitals"]["blood_pressure"] = normalized_data["vitals"]["blood_pressure"].strip('"')
+ #             return json.dumps(normalized_data)
+ #         except json.JSONDecodeError as e:
+ #             logging.error(f"JSON parsing failed: {e}")
+ #             return json.dumps({"error": f"Failed to parse JSON: {str(e)}"})
+ #         except Exception as e:
+ #             logging.error(f"Error extracting medical data: {e}")
+ #             return json.dumps({"error": f"Failed to extract medical data: {str(e)}"})
+
+ #     @staticmethod
+ #     def normalize_json_output(model_output, default_schema):
+ #         """
+ #         Normalize the model's JSON output to match the default schema.
+ #         """
+ #         try:
+ #             normalized_output = default_schema.copy()
+ #             for key in normalized_output:
+ #                 if key in model_output:
+ #                     normalized_output[key] = model_output[key]
+ #             return normalized_output
+ #         except Exception as e:
+ #             logging.error(f"Failed to normalize JSON: {e}")
+ #             return default_schema  # Return the default schema in case of errors
+
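+ # # Example (illustrative sketch): keys missing from the model output keep their
+ # # schema defaults, and keys outside the schema are dropped:
+ # #
+ # #   MedicalDataExtractorAgent.normalize_json_output({"age": 54, "extra": "ignored"}, default_schema)
+ # #   # -> {"patient_name": "[NAME]", "age": 54, "gender": None, ..., "notes": ""}
+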
+ # # Initialize lazy loaders
+ # medalpaca_model_loader = LazyModelLoader("lmsys/vicuna-7b-v1.5", "text-generation")
+ # summarization_model_loader = LazyModelLoader("google-t5/t5-small", "summarization")
+ # whisper_model = whisper.load_model("tiny")
+
+ # # Initialize knowledge base
+ # medical_documents = [
+ #     "Hypertension is a chronic condition characterized by elevated blood pressure.",
+ #     "Diabetes is a metabolic disorder that affects blood sugar levels.",
+ #     "Common symptoms of chest pain include pressure, tightness, or discomfort in the chest."
+ # ]
+ # knowledge_base = KnowledgeBase(medical_documents)
+
+ # # Initialize agents
+ # text_extractor_agent = TextExtractorAgent()
+ # phi_scrubber_agent = PHIScrubberAgent()
+ # medical_data_extractor_agent = MedicalDataExtractorAgent(medalpaca_model_loader, knowledge_base)
+ # summarizer_agent = SummarizerAgent(summarization_model_loader)
+
+ # # API Endpoints
+ # @app.route('/api/extract_medical_data', methods=['POST'])
+ # def extract_medical_data():
+ #     """Extract structured medical data from raw text."""
+ #     try:
+ #         data = request.json
+ #         if "text" not in data or not data["text"].strip():
+ #             return jsonify({"error": "No valid text provided"}), 400
+ #         raw_text = data["text"]
+ #         clean_text = phi_scrubber_agent.scrub_phi(raw_text)
+ #         structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
+ #         return jsonify(json.loads(structured_data)), 200
+ #     except Exception as e:
+ #         logging.error(f"Failed to extract medical data: {e}")
+ #         return jsonify({"error": f"Extraction Error: {str(e)}"}), 500
+
+ # @app.route('/api/transcribe', methods=['POST'])
+ # def transcribe_audio():
+ #     """Transcribe audio files into text."""
+ #     if 'audio' not in request.files:
+ #         abort(400, description="No audio file provided")
+ #     audio_file = request.files['audio']
+ #     if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
+ #         abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
+ #     filename = secure_filename(audio_file.filename)
+ #     audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+ #     audio_file.save(audio_path)
+ #     try:
+ #         result = whisper_model.transcribe(audio_path)
+ #         transcribed_text = result["text"]
+ #         os.remove(audio_path)
+ #         return jsonify({"transcribed_text": transcribed_text}), 200
+ #     except Exception as e:
+ #         logging.error(f"Transcription failed: {str(e)}")
+ #         return jsonify({"error": f"Transcription failed: {str(e)}"}), 500
+
+ # @app.route('/api/generate_summary', methods=['POST'])
+ # def generate_summary():
+ #     """Generate a summary from the provided text."""
+ #     data = request.json
+ #     if "text" not in data or not data["text"].strip():
+ #         return jsonify({"error": "No valid text provided"}), 400
+ #     context = data["text"]
+ #     clean_text = phi_scrubber_agent.scrub_phi(context)
+ #     summary = summarizer_agent.generate_summary(clean_text)
+ #     return jsonify({"summary": summary}), 200
+
+ # @app.route('/api/extract_medical_data_from_audio', methods=['POST'])
+ # def extract_medical_data_from_audio():
+ #     """Extract medical data from transcribed audio."""
+ #     if 'audio' not in request.files:
+ #         abort(400, description="No audio file provided")
+ #     audio_file = request.files['audio']
+ #     if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
+ #         abort(400, description="Invalid file type. Only mp3, wav, and flac files are allowed.")
+ #     filename = secure_filename(audio_file.filename)
+ #     audio_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+ #     audio_file.save(audio_path)
+ #     try:
+ #         result = whisper_model.transcribe(audio_path)
+ #         transcribed_text = result["text"]
+ #         clean_text = phi_scrubber_agent.scrub_phi(transcribed_text)
+ #         # Use the PHI-scrubbed text downstream so no raw identifiers leak into outputs
+ #         summary = summarizer_agent.generate_summary(clean_text)
+ #         structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
+ #         response = {
+ #             "transcribed_text": clean_text,
+ #             "summary": summary,
+ #             "medical_chart": json.loads(structured_data)
+ #         }
+ #         os.remove(audio_path)
+ #         return jsonify(response), 200
+ #     except Exception as e:
+ #         logging.error(f"Processing failed: {str(e)}")
+ #         return jsonify({"error": f"Processing failed: {str(e)}"}), 500
+
+ # @app.route('/upload_document', methods=['POST'])
+ # def upload_document():
+ #     """Upload and extract text from documents."""
+ #     if 'file' not in request.files:
+ #         return jsonify({"error": "No file uploaded"}), 400
+ #     file = request.files['file']
+ #     if file.filename == '':
+ #         return jsonify({"error": "No file selected"}), 400
+ #     if file and allowed_file(file.filename, ALLOWED_DOCUMENT_EXTENSIONS):
+ #         filename = secure_filename(file.filename)
+ #         filepath = os.path.join(UPLOAD_DIR, filename)
+ #         file.save(filepath)
+ #         ext = filename.rsplit('.', 1)[1].lower()
+ #         extracted_text = text_extractor_agent.extract_text(filepath, ext)
+ #         if not extracted_text:
+ #             return jsonify({"error": "No text found in file."}), 400
+ #         response_data = {
+ #             "file": filename,
+ #             "extracted_text": extracted_text[:500],
+ #             "message": "Click to extract medical terms"
+ #         }
+ #         os.remove(filepath)
+ #         return jsonify(response_data), 200
+ #     return jsonify({"error": "Invalid file type"}), 400
+
+ # @app.route('/extract_medical_data_from_document', methods=['POST'])
+ # def extract_medical_data_from_document():
+ #     """Extract medical data from document text."""
+ #     data = request.json
+ #     if "text" not in data or not data["text"].strip():
+ #         return jsonify({"error": "No valid text provided"}), 400
+ #     context = data["text"]
+ #     clean_text = phi_scrubber_agent.scrub_phi(context)
+ #     structured_data = medical_data_extractor_agent.extract_medical_data(clean_text)
+ #     return jsonify(json.loads(structured_data)), 200
+
+ # if __name__ == '__main__':
+ #     app.run(host='0.0.0.0', port=5000, debug=True)