dev-3 committed · Commit 93dd654 · 1 Parent(s): 7814a5f
ai_med_extract/api/routes.py CHANGED
@@ -107,15 +107,36 @@ def run_qa_pipeline(qa_pipeline, question, context):
         raise
 
 def get_ner_pipeline(ner_model_type, ner_model_name):
+    if not ner_model_type or not ner_model_name:
+        raise ValueError("Both ner_model_type and ner_model_name must be provided")
+
     if not hasattr(get_ner_pipeline, "cache"):
         get_ner_pipeline.cache = {}
+
     key = (ner_model_type, ner_model_name)
     if key not in get_ner_pipeline.cache:
-        from transformers import pipeline
-        get_ner_pipeline.cache[key] = pipeline(
-            task=ner_model_type, model=ner_model_name, trust_remote_code=True
-        )
+        try:
+            from transformers import pipeline
+            logging.info(f"Loading NER pipeline - Type: {ner_model_type}, Model: {ner_model_name}")
+
+            get_ner_pipeline.cache[key] = pipeline(
+                task=ner_model_type,
+                model=ner_model_name,
+                trust_remote_code=True,
+                device_map="auto"
+            )
+            logging.info(f"Successfully loaded NER pipeline for {ner_model_name}")
+        except Exception as e:
+            logging.error(f"Failed to load NER pipeline: {str(e)}", exc_info=True)
+            if "Connection" in str(e):
+                raise RuntimeError(f"Network error while loading model: {str(e)}")
+            elif "CUDA" in str(e):
+                raise RuntimeError(f"GPU error while loading model: {str(e)}")
+            elif "disk space" in str(e):
+                raise RuntimeError(f"Insufficient disk space: {str(e)}")
+            else:
+                raise RuntimeError(f"Error loading model: {str(e)}")
+
     return get_ner_pipeline.cache[key]
 
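Note on the caching above: get_ner_pipeline memoizes loaded pipelines on a function attribute, so each (task, model) pair is constructed at most once per process and later requests reuse the cached object. A minimal sketch of the pattern, with a stub dict standing in for the real transformers.pipeline call (illustrative only):

    def get_cached(task, model_name):
        # The cache dict lives on the function object; created on first call.
        if not hasattr(get_cached, "cache"):
            get_cached.cache = {}
        key = (task, model_name)
        if key not in get_cached.cache:
            print(f"loading {key}")              # expensive load runs on cache miss only
            get_cached.cache[key] = {"task": task, "model": model_name}
        return get_cached.cache[key]

    first = get_cached("ner", "model-x")         # prints "loading ..."
    second = get_cached("ner", "model-x")        # cache hit, nothing printed
    assert first is second
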
@@ -201,34 +222,45 @@ def register_routes(app, agents):
         try:
             file.save(filepath)
         except Exception as e:
             return jsonify({"error": f"Failed to save file: {str(e)}"}), 500
         ext = filename.rsplit(".", 1)[-1].lower()
-        extracted_text = TextExtractorAgent.extract_text(filepath, ext)
-        if not extracted_text or extracted_text == "No text found":
-            return (
-                jsonify({"error": f"Failed to extract text from {filename}"}),
-                415,
-            )
+        try:
+            extracted_text = TextExtractorAgent.extract_text(filepath, ext)
+            if not extracted_text or extracted_text == "No text found":
+                os.remove(filepath)  # Clean up on failure
+                return (
+                    jsonify({"error": f"Failed to extract text from {filename}"}),
+                    415,
+                )
+        except Exception as e:
+            logging.error(f"Text extraction failed for {filename}: {str(e)}", exc_info=True)
+            os.remove(filepath)  # Clean up on failure
+            return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
+
         skip_medical_check = (
             request.form.get("skip_medical_check", "false").lower() == "true"
         )
         if not skip_medical_check:
-            ner_results = ner_pipeline(extracted_text)
-            medical_entities = list(
-                set(
-                    [
-                        r["word"]
-                        for r in ner_results
-                        if r["entity"].startswith("B-")
-                        or r["entity"].startswith("I-")
-                    ]
-                )
-            )
-            if not medical_entities:
-                return (
-                    jsonify({"error": f"'{filename}' is not medically relevant"}),
-                    406,
-                )
+            try:
+                ner_results = ner_pipeline(extracted_text)
+                medical_entities = list(
+                    set(
+                        [
+                            r["word"]
+                            for r in ner_results
+                            if r["entity"].startswith("B-")
+                            or r["entity"].startswith("I-")
+                        ]
+                    )
+                )
+                if not medical_entities:
+                    return (
+                        jsonify({"error": f"'{filename}' is not medically relevant"}),
+                        406,
+                    )
+            except Exception as e:
+                logging.error(f"NER processing failed for {filename}: {str(e)}", exc_info=True)
+                return jsonify({"error": f"NER processing failed: {str(e)}"}), 500
         skip_patient_check = (
             request.form.get("skip_patient_check", "false").lower() == "true"
         )
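The extraction block above removes the uploaded file on every failure path. The same guarantee can also be expressed with a success flag and try/finally; a sketch under the assumption that the file should persist only when extraction succeeds (extract_or_cleanup and the stub extractor are illustrative, not part of the codebase):

    import os, tempfile

    def extract_or_cleanup(filepath, extract):
        success = False
        try:
            text = extract(filepath)
            if not text or text == "No text found":
                raise ValueError(f"no text extracted from {filepath}")
            success = True
            return text
        finally:
            if not success and os.path.exists(filepath):
                os.remove(filepath)              # clean up on any failure path

    path = tempfile.NamedTemporaryFile(delete=False).name
    try:
        extract_or_cleanup(path, lambda p: "")   # stub extractor that "fails"
    except ValueError:
        assert not os.path.exists(path)          # file was cleaned up
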
@@ -260,7 +292,8 @@ def register_routes(app, agents):
             else:
                 summary = str(summary_result)
         except Exception as e:
-            summary = "Summary failed"
+            logging.error(f"Summary generation failed for {filename}: {str(e)}", exc_info=True)
+            summary = f"Summary generation failed: {str(e)}"
         extracted_data.append(
             {
                 "file": filename,
@@ -378,37 +411,98 @@ def register_routes(app, agents):
 
     @app.route("/api/extract_medical_data_from_audio", methods=["POST"])
     def extract_medical_data_from_audio():
-        if "audio" not in request.files:
-            return jsonify({"error": "No audio file provided"}), 400
-        audio_file = request.files["audio"]
-        if audio_file.filename == "":
-            return jsonify({"error": "No selected audio file"}), 400
-        temp_path = os.path.join("/tmp", audio_file.filename)
-        audio_file.save(temp_path)
+        temp_path = None
         try:
-            result = whisper_model.transcribe(temp_path)
-            transcribed_text = result["text"]
+            # Validate request
+            if "audio" not in request.files:
+                return jsonify({"error": "No audio file provided"}), 400
+
+            audio_file = request.files["audio"]
+            if audio_file.filename == "":
+                return jsonify({"error": "No selected audio file"}), 400
+
+            # Validate file extension
+            if not allowed_file(audio_file.filename):
+                return jsonify({"error": "Unsupported audio format. Allowed formats: wav, mp3, m4a, ogg"}), 400
+
+            # Check file size
+            valid_size, error_message = check_file_size(audio_file)
+            if not valid_size:
+                return jsonify({"error": error_message}), 400
+
+            # Generate a secure temporary path
+            import uuid
+            from werkzeug.utils import secure_filename
+            temp_filename = f"{uuid.uuid4()}_{secure_filename(audio_file.filename)}"
+            temp_path = os.path.join("/tmp", temp_filename)
+
             try:
-                clean_text = PHIScrubberAgent.scrub_phi(transcribed_text)
-            except Exception:
-                clean_text = transcribed_text
-            summary = SummarizerAgent.generate_summary(clean_text)
-            medical_data = MedicalDataExtractorAgent.extract_medical_data(clean_text)
-            os.remove(temp_path)
-            return (
-                jsonify(
-                    {
-                        "transcribed_text": clean_text,
-                        "summary": summary,
-                        "medical_chart": medical_data,
-                    }
-                ),
-                200,
-            )
+                logging.info(f"Saving audio file to temporary path: {temp_path}")
+                audio_file.save(temp_path)
+
+                # Transcribe, retrying transient failures
+                max_retries = 3
+                for attempt in range(max_retries):
+                    try:
+                        logging.info(f"Initializing Whisper model (attempt {attempt + 1}/{max_retries})")
+                        transcribed_text = whisper_model.transcribe(temp_path)["text"]
+                        if not transcribed_text:
+                            raise ValueError("No text output from transcription")
+                        logging.info("Audio transcription successful")
+                        break
+                    except Exception as e:
+                        if attempt == max_retries - 1:  # Last attempt
+                            raise
+                        logging.warning(f"Transcription attempt {attempt + 1} failed: {str(e)}")
+                        continue
+
+                # Clean and process text
+                try:
+                    logging.info("Scrubbing PHI from transcribed text")
+                    clean_text = PHIScrubberAgent.scrub_phi(transcribed_text)
+                except Exception as e:
+                    logging.warning(f"PHI scrubbing failed, using raw text: {str(e)}")
+                    clean_text = transcribed_text
+
+                try:
+                    logging.info("Generating summary")
+                    summary = SummarizerAgent.generate_summary(clean_text)
+                except Exception as e:
+                    logging.error(f"Summary generation failed: {str(e)}")
+                    summary = "Summary generation failed"
+
+                try:
+                    logging.info("Extracting medical data")
+                    medical_data = MedicalDataExtractorAgent.extract_medical_data(clean_text)
+                except Exception as e:
+                    logging.error(f"Medical data extraction failed: {str(e)}")
+                    medical_data = {"error": f"Medical data extraction failed: {str(e)}"}
+
+                # Clean up temporary file
+                if os.path.exists(temp_path):
+                    os.remove(temp_path)
+
+                return jsonify({
+                    "transcribed_text": clean_text,
+                    "summary": summary,
+                    "medical_chart": medical_data,
+                }), 200
+
+            except Exception as e:
+                logging.error(f"Audio processing failed: {str(e)}", exc_info=True)
+                if os.path.exists(temp_path):
+                    os.remove(temp_path)
+                return jsonify({
+                    "error": f"Audio processing failed: {str(e)}",
+                    "details": "Error occurred during audio transcription or text processing"
+                }), 500
+
         except Exception as e:
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-            return jsonify({"error": f"Processing failed: {str(e)}"}), 500
+            logging.error(f"Request handling failed: {str(e)}", exc_info=True)
+            return jsonify({
+                "error": "Internal server error",
+                "details": str(e)
+            }), 500
 
     @app.route("/extract_medical_data_questions", methods=["POST"])
     def extract_medical_data_questions():
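The transcription loop above is an inline retry with a fixed attempt budget: transient failures are logged and retried, and only the final failure propagates. The same logic as a reusable helper, sketched without backoff (with_retries and flaky are illustrative names; sleeping between attempts would be a natural extension):

    import logging

    def with_retries(fn, max_retries=3):
        for attempt in range(max_retries):
            try:
                return fn()
            except Exception as e:
                if attempt == max_retries - 1:   # last attempt: give up
                    raise
                logging.warning("attempt %d failed: %s", attempt + 1, e)

    calls = {"n": 0}
    def flaky():
        calls["n"] += 1
        if calls["n"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    assert with_retries(flaky) == "ok"           # succeeds on the third call
    assert calls["n"] == 3
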
ai_med_extract/app.py CHANGED
@@ -20,22 +20,67 @@ load_dotenv()
 app = Flask(__name__)
 CORS(app)
 
+# Configure upload directory
 UPLOAD_DIR = os.getenv('UPLOAD_DIR', os.path.join(os.getcwd(), 'uploads'))
-os.makedirs(UPLOAD_DIR, exist_ok=True)
+try:
+    os.makedirs(UPLOAD_DIR, exist_ok=True)
+    os.chmod(UPLOAD_DIR, 0o777)  # Ensure directory is writable
+except Exception as e:
+    logging.error(f"Failed to create/configure upload directory: {e}", exc_info=True)
+    UPLOAD_DIR = '/tmp/uploads'  # Fall back to /tmp if main directory creation fails
+    os.makedirs(UPLOAD_DIR, exist_ok=True)
+
 app.config['UPLOAD_FOLDER'] = UPLOAD_DIR
-app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB max file size
+app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 * 1024  # 16 GB max request size to handle large medical files
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
-# Model loaders (example, adjust as needed)
-medalpaca_model_loader = None  # TODO: Implement LazyModelLoader if needed
-summarization_model_loader = None  # TODO: Implement LazyModelLoader if needed
-whisper_model = None
-def get_whisper_model():
-    global whisper_model
-    if whisper_model is None:
-        whisper_model = whisper.load_model("tiny")
-    return whisper_model
+# Model loaders
+class LazyModelLoader:
+    def __init__(self, model_name, model_type):
+        self.model_name = model_name
+        self.model_type = model_type
+        self._model = None
+
+    def load(self):
+        from transformers import pipeline
+        if self._model is None:
+            self._model = pipeline(
+                task=self.model_type,
+                model=self.model_name,
+                trust_remote_code=True,
+                device_map="auto"
+            )
+        return self._model
+
+medalpaca_model_loader = LazyModelLoader("medalpaca/medalpaca-13b", "text-generation")
+summarization_model_loader = LazyModelLoader("facebook/bart-large-cnn", "summarization")
+
+class WhisperModelLoader:
+    _instance = None
+
+    def __init__(self):
+        self._model = None
+
+    @staticmethod
+    def get_instance():
+        if WhisperModelLoader._instance is None:
+            WhisperModelLoader._instance = WhisperModelLoader()
+        return WhisperModelLoader._instance
+
+    def load(self):
+        if self._model is None:
+            try:
+                logging.info("Loading Whisper model...")
+                self._model = whisper.load_model("base")
+                logging.info("Whisper model loaded successfully")
+            except Exception as e:
+                logging.error(f"Failed to load Whisper model: {str(e)}", exc_info=True)
+                raise RuntimeError(f"Failed to load Whisper model: {str(e)}")
+        return self._model
+
+    def transcribe(self, audio_path):
+        model = self.load()
+        return model.transcribe(audio_path)
 
 # Initialize agents
 text_extractor_agent = TextExtractorAgent()
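The point of LazyModelLoader and WhisperModelLoader above is that nothing heavyweight happens at import time; the pipeline or the Whisper weights are materialized on the first load() call and cached for the rest of the process. The deferral can be exercised without any ML dependency; a stub sketch (the builds counter is illustrative):

    class LazyLoader:
        def __init__(self, factory):
            self._factory = factory              # deferred, possibly expensive constructor
            self._obj = None

        def load(self):
            if self._obj is None:                # construct on first use only
                self._obj = self._factory()
            return self._obj

    builds = []
    loader = LazyLoader(lambda: builds.append(1) or object())
    assert not builds                            # nothing built at definition time
    a, b = loader.load(), loader.load()
    assert a is b and len(builds) == 1           # built exactly once, then reused
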
@@ -49,7 +94,7 @@ agents = {
     "phi_scrubber": phi_scrubber_agent,
     "summarizer": summarizer_agent,
     "medical_data_extractor": medical_data_extractor_agent,
-    "whisper_model": get_whisper_model
+    "whisper_model": WhisperModelLoader.get_instance()
 }
 
 from .api.routes import register_routes
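The agents entry now holds an object exposing transcribe(path) rather than a factory function, so route code can call it exactly like a raw Whisper model while loading stays deferred to first use. A minimal sketch of that adapter shape (StubModel stands in for whisper.load_model; all names here are illustrative):

    class StubModel:
        def transcribe(self, path):
            return {"text": f"transcript of {path}"}

    class DeferredTranscriber:
        def __init__(self):
            self._model = None

        def transcribe(self, path):
            if self._model is None:              # load on first use
                self._model = StubModel()
            return self._model.transcribe(path)

    agents = {"whisper_model": DeferredTranscriber()}
    assert agents["whisper_model"].transcribe("a.wav")["text"] == "transcript of a.wav"
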
ai_med_extract/utils/file_utils.py CHANGED
@@ -5,7 +5,7 @@ import logging
 from werkzeug.utils import secure_filename
 from flask import current_app
 
-ALLOWED_EXTENSIONS = {"pdf", "jpg", "jpeg", "png", "svg", "docx", "doc", "xlsx", "xls"}
+ALLOWED_EXTENSIONS = {"pdf", "jpg", "jpeg", "png", "svg", "docx", "doc", "xlsx", "xls", "wav", "mp3", "m4a", "ogg"}
 MAX_SIZE_PDF_DOCS = 1 * 1024 * 1024 * 1024  # 1GB
 MAX_SIZE_IMAGES = 500 * 1024 * 1024  # 500MB
 
@@ -15,15 +15,26 @@ def allowed_file(filename):
 
 
 def check_file_size(file):
-    file.seek(0, os.SEEK_END)
-    size = file.tell()
-    file.seek(0)
-    extension = file.filename.rsplit('.', 1)[-1].lower()
-    if extension in {"pdf", "docx"} and size > MAX_SIZE_PDF_DOCS:
-        return False, f"File {file.filename} exceeds 1GB size limit"
-    elif extension in {"jpg", "jpeg", "png"} and size > MAX_SIZE_IMAGES:
-        return False, f"Image {file.filename} exceeds 500MB size limit"
-    return True, None
+    try:
+        # Store the current position
+        current_pos = file.tell()
+
+        # Check size
+        file.seek(0, os.SEEK_END)
+        size = file.tell()
+
+        # Return to the original position
+        file.seek(current_pos)
+
+        extension = file.filename.rsplit('.', 1)[-1].lower()
+        if extension in {"pdf", "docx"} and size > MAX_SIZE_PDF_DOCS:
+            return False, f"File {file.filename} exceeds 1GB size limit"
+        elif extension in {"jpg", "jpeg", "png"} and size > MAX_SIZE_IMAGES:
+            return False, f"Image {file.filename} exceeds 500MB size limit"
+        return True, None
+    except Exception as e:
+        logging.error(f"Error checking file size: {e}", exc_info=True)
+        return False, f"Error checking file size: {str(e)}"
 
 
 def save_data_to_storage(filename, data):
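The tell/seek bookkeeping above matters because the incoming stream may already have been partially read when the size check runs; restoring the original position keeps a subsequent save() from writing a truncated file. A quick self-contained exercise, assuming check_file_size above is importable (NamedStream is a stand-in for werkzeug's FileStorage, which pairs a stream with a filename):

    import io

    class NamedStream(io.BytesIO):
        def __init__(self, data, filename):
            super().__init__(data)
            self.filename = filename             # mimic FileStorage.filename

    stream = NamedStream(b"fake pdf bytes", "report.pdf")
    stream.read(4)                               # simulate a partial read before the check
    pos = stream.tell()

    ok, err = check_file_size(stream)
    assert ok and err is None                    # 14 bytes is well under the 1GB limit
    assert stream.tell() == pos                  # cursor restored for a later .save()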