HNTAI / ai_med_extract /agents /medical_data_extractor.py
Joyna-Joy
Revert "changes in med_data_extraction"
0cc963e
raw
history blame
4.62 kB
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import json
import torch
from ai_med_extract.agents.phi_scrubber import MedicalTextUtils
class MedicalDataExtractorAgent:
    """Extract structured medical fields from a clinical note via a text generator.

    ``generator`` is any callable that accepts a prompt string plus generation
    keyword arguments and returns a sequence whose first element is a mapping
    containing a ``"generated_text"`` entry.
    """

    def __init__(self, generator):
        """Store the text-generation callable used by ``extract_medical_data``."""
        self.generator = generator

    def extract_medical_data(self, text):
        """Ask the generator for a JSON summary of ``text`` and parse it.

        Args:
            text: The clinical note to analyse.

        Returns:
            The decoded JSON object produced by the model on success, otherwise
            a dict with a single ``"error"`` key describing what went wrong
            (empty input, generation failure, missing JSON, or invalid JSON).
            This method never raises; every failure is logged and reported
            through the ``"error"`` dict.
        """
        try:
            if not text:
                return {"error": "No text provided for extraction"}

            prompt = (
                "Extract structured medical information from the following clinical note.\n\n"
                "Return the result in JSON format with the following fields:\n"
                "patient_condition, symptoms, current_problems, allergies, dr_notes, "
                "prescription, investigations, follow_up_instructions.\n\n"
                f"Clinical Note:\n{text}\n\n"
                "Structured JSON Output:\n"
            )

            # Inference only: gradients are disabled on every code path, and
            # CUDA device 0 is additionally selected when a GPU is available.
            # (Previously gradients were disabled only on the CPU path.)
            device_ctx = (
                torch.cuda.device(0) if torch.cuda.is_available() else torch.no_grad()
            )
            with torch.no_grad(), device_ctx:
                try:
                    # NOTE(review): with do_sample=False the temperature value
                    # is presumably ignored by the underlying generator; it is
                    # still passed so the call contract stays unchanged.
                    response = self.generator(
                        prompt,
                        max_new_tokens=256,
                        do_sample=False,
                        temperature=0.1,
                        num_return_sequences=1,
                    )[0]["generated_text"]
                except Exception as e:
                    logging.error(f"Model generation failed: {str(e)}")
                    return {"error": f"Failed to generate medical data: {str(e)}"}

            logging.debug(f"Raw model output: {response}")

            try:
                # Generators may echo the prompt in front of the completion.
                # Drop it so braces inside the clinical note are never
                # mistaken for the start of the model's JSON answer.
                if response.startswith(prompt):
                    completion = response[len(prompt):]
                else:
                    completion = response

                json_start = completion.find("{")
                if json_start == -1:
                    raise ValueError("No JSON found in the model response.")

                # raw_decode stops at the end of the first complete JSON value,
                # so text the model keeps generating after the object (even
                # text containing "}") no longer breaks parsing.
                parsed, _ = json.JSONDecoder().raw_decode(completion, json_start)
                return parsed
            except json.JSONDecodeError as e:
                logging.error(f"JSON parsing failed: {str(e)}")
                return {"error": f"Failed to parse medical data: {str(e)}"}
            except Exception as e:
                logging.error(f"JSON extraction failed: {str(e)}")
                return {"error": f"Failed to extract medical data: {str(e)}"}
        except Exception as e:
            logging.error(f"Error extracting medical data: {str(e)}")
            return {"error": f"Failed to extract medical data: {str(e)}"}
class MedicalDocDataExtractorAgent:
    """Extract medical data from a long document by processing it in chunks.

    The text is split with ``MedicalTextUtils.Chunker``, each chunk is handed
    to ``MedicalTextUtils.Processor.process_chunk`` on a small thread pool, and
    the JSON objects found in the outputs are deduplicated, grouped and cleaned
    by the corresponding ``MedicalTextUtils`` helpers.
    """

    def __init__(self, generator):
        """Store the generator that is forwarded to the chunk processor."""
        self.generator = generator

    def extract_from_text(self, text, tokenizer):
        """Run chunked extraction over ``text`` and merge the results.

        Args:
            text: The document text to analyse.
            tokenizer: Passed through to ``MedicalTextUtils.Chunker.chunk_text``.

        Returns:
            The value returned by ``MedicalTextUtils.Cleaner.clean`` on
            success, otherwise a dict with a single ``"error"`` key. This
            method never raises; failures are logged and reported through the
            ``"error"`` dict.
        """
        if not text:
            return {"error": "No text provided for extraction"}
        try:
            chunks = MedicalTextUtils.Chunker.chunk_text(text, tokenizer)
            all_extracted = []
            first_failure = None

            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = [
                    executor.submit(
                        MedicalTextUtils.Processor.process_chunk,
                        self.generator,
                        chunk,
                        idx,
                    )
                    for idx, chunk in enumerate(chunks)
                ]
                # Consume results in submission order (not completion order)
                # so the merged list, and anything order-sensitive downstream,
                # is the same from run to run. The chunks still execute
                # concurrently; only the collection order is fixed.
                for idx, future in enumerate(futures):
                    try:
                        _, output = future.result()
                    except Exception as e:
                        # One failing chunk should not discard the rest of the
                        # document, matching how extraction errors are treated.
                        logging.error(f"Failed to process chunk {idx}: {str(e)}")
                        if first_failure is None:
                            first_failure = e
                        continue
                    if not output:
                        continue
                    try:
                        objs = MedicalTextUtils.JSONExtractor.extract_objects(output)
                    except Exception as e:
                        logging.error(
                            f"Failed to extract objects from chunk {idx}: {str(e)}"
                        )
                        continue
                    if objs:
                        all_extracted.extend(objs)

            if not all_extracted:
                if first_failure is not None:
                    # Nothing usable was produced and a chunk failed: report
                    # the underlying error instead of a generic message.
                    return {
                        "error": f"Failed to extract medical data: {str(first_failure)}"
                    }
                return {"error": "No valid data extracted"}

            deduped = MedicalTextUtils.Deduplicator.deduplicate(all_extracted)
            grouped = MedicalTextUtils.Grouper.group_by_category(deduped)
            return MedicalTextUtils.Cleaner.clean(grouped)
        except Exception as e:
            logging.error(f"Error in extract_from_text: {str(e)}")
            return {"error": f"Failed to extract medical data: {str(e)}"}