from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import json

import torch

from ai_med_extract.agents.phi_scrubber import MedicalTextUtils


class MedicalDataExtractorAgent:
    """Extract structured medical fields from a single clinical note via an LLM."""

    def __init__(self, generator):
        # `generator` is assumed to be a HF-pipeline-style callable:
        # generator(prompt, **gen_kwargs) -> [{"generated_text": str}] — confirm with caller.
        self.generator = generator

    def extract_medical_data(self, text):
        """Return structured medical data parsed from `text` as a dict.

        The model is prompted to emit JSON with fixed keys (patient_condition,
        symptoms, current_problems, allergies, dr_notes, prescription,
        investigations, follow_up_instructions). Never raises: every failure
        path returns ``{"error": <message>}`` instead.
        """
        if not text:
            return {"error": "No text provided for extraction"}
        try:
            # Inference only — disable autograd. (The original wrapped this in
            # `torch.cuda.device(0) if cuda else torch.no_grad()` with a comment
            # claiming float32 precision; neither context affects precision, and
            # the generator's device is fixed when the pipeline is created.)
            with torch.no_grad():
                prompt = (
                    "Extract structured medical information from the following clinical note.\n\n"
                    "Return the result in JSON format with the following fields:\n"
                    "patient_condition, symptoms, current_problems, allergies, dr_notes, "
                    "prescription, investigations, follow_up_instructions.\n\n"
                    f"Clinical Note:\n{text}\n\n"
                    "Structured JSON Output:\n"
                )

                # Generation errors get their own message so callers can tell
                # model failure apart from parse failure.
                try:
                    response = self.generator(
                        prompt,
                        max_new_tokens=256,
                        do_sample=False,  # greedy decoding for consistent results
                        # NOTE: temperature removed — it is ignored (and warned
                        # about by transformers) when do_sample=False.
                        num_return_sequences=1,
                    )[0]["generated_text"]
                except Exception as e:
                    logging.error(f"Model generation failed: {str(e)}")
                    return {"error": f"Failed to generate medical data: {str(e)}"}

                logging.debug(f"Raw model output: {response}")

                # Pull the outermost {...} span out of the raw completion.
                try:
                    json_start = response.find("{")
                    json_end = response.rfind("}")
                    # BUG FIX: the original computed `rfind("}") + 1` and then
                    # tested it against -1 — a condition that can never be true
                    # (rfind's -1 becomes 0), so a completion with no closing
                    # brace slipped through to json.loads(""). Test the raw
                    # rfind result, and also reject a "}" before the "{".
                    if json_start == -1 or json_end == -1 or json_end < json_start:
                        raise ValueError("No JSON found in the model response.")
                    return json.loads(response[json_start:json_end + 1])
                except json.JSONDecodeError as e:
                    logging.error(f"JSON parsing failed: {str(e)}")
                    return {"error": f"Failed to parse medical data: {str(e)}"}
                except Exception as e:
                    logging.error(f"JSON extraction failed: {str(e)}")
                    return {"error": f"Failed to extract medical data: {str(e)}"}
        except Exception as e:
            logging.error(f"Error extracting medical data: {str(e)}")
            return {"error": f"Failed to extract medical data: {str(e)}"}


class MedicalDocDataExtractorAgent:
    """Extract medical data from a long document by chunking it, processing
    chunks in parallel, then deduplicating / grouping / cleaning the results."""

    def __init__(self, generator):
        # Same generator contract as MedicalDataExtractorAgent — assumed;
        # it is only ever forwarded to MedicalTextUtils.Processor.process_chunk.
        self.generator = generator

    def extract_from_text(self, text, tokenizer):
        """Return cleaned, grouped medical data extracted from `text`.

        `tokenizer` is forwarded to the chunker so chunks fit the model's
        context window. Never raises: failures and the no-data case return
        ``{"error": <message>}``.
        """
        if not text:
            return {"error": "No text provided for extraction"}
        try:
            chunks = MedicalTextUtils.Chunker.chunk_text(text, tokenizer)
            all_extracted = []
            # Chunk processing is model-I/O bound, so a small thread pool
            # overlaps the waits; results are collected as they complete.
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = {
                    executor.submit(
                        MedicalTextUtils.Processor.process_chunk,
                        self.generator,
                        chunk,
                        idx,
                    ): idx
                    for idx, chunk in enumerate(chunks)
                }
                for future in as_completed(futures):
                    # process_chunk returns (idx, output); a worker exception
                    # propagates here and is handled by the outer except.
                    idx, output = future.result()
                    if not output:
                        continue
                    try:
                        objs = MedicalTextUtils.JSONExtractor.extract_objects(output)
                        if objs:
                            all_extracted.extend(objs)
                    except Exception as e:
                        # One bad chunk must not sink the whole document.
                        logging.error(f"Failed to extract objects from chunk {idx}: {str(e)}")
                        continue

            if not all_extracted:
                return {"error": "No valid data extracted"}

            deduped = MedicalTextUtils.Deduplicator.deduplicate(all_extracted)
            grouped = MedicalTextUtils.Grouper.group_by_category(deduped)
            return MedicalTextUtils.Cleaner.clean(grouped)
        except Exception as e:
            logging.error(f"Error in extract_from_text: {str(e)}")
            return {"error": f"Failed to extract medical data: {str(e)}"}