Spaces:
Paused
Paused
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import logging | |
| import json | |
| import torch | |
| from ai_med_extract.agents.phi_scrubber import MedicalTextUtils | |
class MedicalDataExtractorAgent:
    """Extract structured medical fields from a single clinical note via a
    text-generation model.

    All failures are reported as ``{"error": ...}`` dicts rather than raised,
    so callers can handle results uniformly.
    """

    def __init__(self, generator):
        # generator: a callable (e.g. a HuggingFace text-generation pipeline)
        # that accepts a prompt plus generation kwargs and returns
        # [{"generated_text": ...}].
        self.generator = generator

    def extract_medical_data(self, text):
        """Return a dict of structured medical data parsed from ``text``.

        :param text: raw clinical note text; falsy input short-circuits to an
            error dict.
        :return: the parsed JSON object on success, else ``{"error": ...}``.
        """
        if not text:
            return {"error": "No text provided for extraction"}
        try:
            # Select CUDA device 0 when a GPU is available; otherwise disable
            # autograd for the CPU inference path. NOTE(review): neither branch
            # forces float32 precision — the original comment claiming so was
            # incorrect and has been removed.
            ctx = torch.cuda.device(0) if torch.cuda.is_available() else torch.no_grad()
            with ctx:
                prompt = (
                    "Extract structured medical information from the following clinical note.\n\n"
                    "Return the result in JSON format with the following fields:\n"
                    "patient_condition, symptoms, current_problems, allergies, dr_notes, "
                    "prescription, investigations, follow_up_instructions.\n\n"
                    f"Clinical Note:\n{text}\n\n"
                    "Structured JSON Output:\n"
                )
                try:
                    response = self.generator(
                        prompt,
                        max_new_tokens=256,
                        do_sample=False,  # greedy decoding for reproducible output
                        temperature=0.1,  # no-op while do_sample=False; kept for call-signature stability
                        num_return_sequences=1,
                    )[0]["generated_text"]
                except Exception as e:
                    logging.error(f"Model generation failed: {str(e)}")
                    return {"error": f"Failed to generate medical data: {str(e)}"}

                logging.debug(f"Raw model output: {response}")

                # Take the first "{" through the last "}" as the JSON payload.
                # Bug fix: the original computed rfind("}") + 1 and then tested
                # the result against -1, which can never match after the +1, so
                # a missing closing brace was silently turned into an empty
                # slice instead of the intended "No JSON found" error.
                json_start = response.find("{")
                json_end = response.rfind("}")
                if json_start == -1 or json_end == -1 or json_end < json_start:
                    logging.error("No JSON found in the model response.")
                    return {"error": "Failed to extract medical data: No JSON found in the model response."}
                try:
                    return json.loads(response[json_start:json_end + 1])
                except json.JSONDecodeError as e:
                    logging.error(f"JSON parsing failed: {str(e)}")
                    return {"error": f"Failed to parse medical data: {str(e)}"}
        except Exception as e:
            logging.error(f"Error extracting medical data: {str(e)}")
            return {"error": f"Failed to extract medical data: {str(e)}"}
class MedicalDocDataExtractorAgent:
    """Extract structured medical data from long documents by chunking the
    text and processing chunks concurrently, then deduplicating, grouping and
    cleaning the combined results.
    """

    def __init__(self, generator):
        # generator: model callable forwarded to
        # MedicalTextUtils.Processor.process_chunk for each chunk.
        self.generator = generator

    def extract_from_text(self, text, tokenizer):
        """Chunk ``text`` with ``tokenizer``, extract JSON objects per chunk,
        and return the deduplicated/grouped/cleaned result.

        :param text: full document text; falsy input short-circuits to an
            error dict.
        :param tokenizer: tokenizer handed to the chunker (presumably used to
            bound chunk length in tokens — defined in phi_scrubber).
        :return: cleaned mapping on success, else ``{"error": ...}``.
        """
        if not text:
            return {"error": "No text provided for extraction"}
        try:
            chunks = MedicalTextUtils.Chunker.chunk_text(text, tokenizer)
            all_extracted = []
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = {
                    executor.submit(MedicalTextUtils.Processor.process_chunk, self.generator, chunk, idx): idx
                    for idx, chunk in enumerate(chunks)
                }
                for future in as_completed(futures):
                    chunk_idx = futures[future]
                    # Bug fix: a failure inside one chunk's processing used to
                    # raise out of future.result(), abort the whole loop, and
                    # discard every other chunk's results. Now the failed chunk
                    # is logged and skipped, matching the per-chunk handling
                    # already applied to extract_objects below.
                    try:
                        idx, output = future.result()
                    except Exception as e:
                        logging.error(f"Chunk {chunk_idx} processing failed: {str(e)}")
                        continue
                    if not output:
                        continue
                    try:
                        objs = MedicalTextUtils.JSONExtractor.extract_objects(output)
                        if objs:
                            all_extracted.extend(objs)
                    except Exception as e:
                        logging.error(f"Failed to extract objects from chunk {idx}: {str(e)}")
                        continue
            if all_extracted:
                deduped = MedicalTextUtils.Deduplicator.deduplicate(all_extracted)
                grouped = MedicalTextUtils.Grouper.group_by_category(deduped)
                return MedicalTextUtils.Cleaner.clean(grouped)
            return {"error": "No valid data extracted"}
        except Exception as e:
            logging.error(f"Error in extract_from_text: {str(e)}")
            return {"error": f"Failed to extract medical data: {str(e)}"}