import os
import logging
from flask import Flask, jsonify
from flask_cors import CORS
import whisper
from dotenv import load_dotenv
import torch

from .agents.text_extractor import TextExtractorAgent
from .agents.phi_scrubber import PHIScrubberAgent, MedicalTextUtils
from .agents.summarizer import SummarizerAgent
from .agents.medical_data_extractor import MedicalDataExtractorAgent, MedicalDocDataExtractorAgent
from .agents.patient_summary_agent import PatientSummarizerAgent
# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('/tmp/app.log')
    ]
)
app = Flask(__name__)
CORS(app)

# Configure upload directory with safe fallbacks (avoid creating /data at import time)
def _resolve_upload_dir() -> str:
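    """Pick a writable upload directory.

    Prefers the persistent /data volume when it already exists and is
    writable (e.g. a Space with persistent storage), otherwise falls
    back to /tmp/uploads.
    """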
    try:
        # Prefer /data/uploads if it already exists and is writable
        data_dir = '/data/uploads'
        if os.path.isdir('/data') and (os.path.isdir(data_dir) or os.access('/data', os.W_OK)):
            os.makedirs(data_dir, exist_ok=True)
            return data_dir
    except Exception:
        pass
    # Fall back to /tmp/uploads, which is always writable on Spaces
    tmp_dir = '/tmp/uploads'
    os.makedirs(tmp_dir, exist_ok=True)
    return tmp_dir
app.config['UPLOAD_FOLDER'] = _resolve_upload_dir()
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024  # 100 MB max file size

# Set cache directories to writable locations
CACHE_DIRS = {
    'HF_HOME': '/tmp/huggingface',
    'XDG_CACHE_HOME': '/tmp',
    'TORCH_HOME': '/tmp/torch',
    'WHISPER_CACHE': '/tmp/whisper'
}
for env_var, path in CACHE_DIRS.items():
    os.environ[env_var] = path
    os.makedirs(path, exist_ok=True)

# Model loaders
class LazyModelLoader:
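    """Defers loading a Hugging Face model/tokenizer/pipeline until first
    use, with a bounded retry count and an optional fallback model."""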
    def __init__(self, model_name, model_type, fallback_model=None, max_retries=2):
        self.model_name = model_name
        self.model_type = model_type
        self.fallback_model = fallback_model
        self._model = None
        self._tokenizer = None
        self._pipeline = None
        self._retries = 0
        self.max_retries = max_retries

    def load(self):
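        """Build (or return the cached) transformers pipeline for this model."""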
        # Import transformers lazily so the app can start without loading it
        from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        if self._pipeline is None:
            try:
                logging.info(f"Loading model: {self.model_name} (attempt {self._retries + 1})")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                self._tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
                )
                # Use half precision only when a GPU is available
                dtype = torch.float16 if torch.cuda.is_available() else torch.float32
                if self.model_type == "text-generation":
                    self._model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        trust_remote_code=True,
                        device_map="auto",
                        low_cpu_mem_usage=True,
                        torch_dtype=dtype,
                        cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
                    )
                else:
                    self._model = AutoModelForSeq2SeqLM.from_pretrained(
                        self.model_name,
                        trust_remote_code=True,
                        device_map="auto",
                        low_cpu_mem_usage=True,
                        torch_dtype=dtype,
                        cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
                    )
                # device_map="auto" already placed the model, so no explicit
                # device argument is passed to the pipeline
                self._pipeline = pipeline(
                    task=self.model_type,
                    model=self._model,
                    tokenizer=self._tokenizer,
                )
                logging.info(f"Model loaded successfully: {self.model_name}")
                return self._pipeline
            except Exception as e:
                logging.error(f"Error loading model '{self.model_name}': {e}", exc_info=True)
                self._retries += 1
                if self._retries >= self.max_retries:
                    raise RuntimeError(f"Exceeded retry limit for model: {self.model_name}")
                # Attempt the fallback if it differs from the current model
                if self.fallback_model and self.fallback_model != self.model_name:
                    logging.warning(f"Falling back to model: {self.fallback_model}")
                    self.model_name = self.fallback_model
                    return self.load()
                else:
                    raise RuntimeError(f"Fallback failed or not set for model: {self.model_name}")
        return self._pipeline
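
# Hypothetical usage sketch (not part of the app's startup path; the model id
# here is an example): construction is cheap, and the pipeline is only built
# on the first load() call.
#   loader = LazyModelLoader("t5-small", "summarization")
#   pipe = loader.load()
#   result = pipe("Patient presents with ...", max_length=60)  # [{'summary_text': ...}]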

class WhisperModelLoader:
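    """Process-wide singleton that loads the Whisper "tiny" checkpoint once
    and reuses it for every transcription request."""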

    _instance = None

    def __init__(self):
        self._model = None

    @staticmethod
    def get_instance():
        if WhisperModelLoader._instance is None:
            WhisperModelLoader._instance = WhisperModelLoader()
        return WhisperModelLoader._instance
    def load(self):
        if self._model is None:
            try:
                logging.info("Loading Whisper model...")
                self._model = whisper.load_model(
                    "tiny",  # using the tiny model for better memory usage
                    download_root=os.environ.get('WHISPER_CACHE', '/tmp/whisper')
                )
                logging.info("Whisper model loaded successfully")
            except Exception as e:
                logging.error(f"Failed to load Whisper model: {str(e)}", exc_info=True)
                raise
        return self._model

    def transcribe(self, audio_path):
        model = self.load()
        return model.transcribe(audio_path)
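
# Hypothetical usage sketch (the file path is an example): the singleton defers
# the download/load until the first transcription, then reuses the model.
#   result = WhisperModelLoader.get_instance().transcribe("/tmp/uploads/visit.wav")
#   text = result["text"]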

# Initialize agents
try:
    # Use smaller models for Hugging Face Spaces
    medical_data_extractor_model_loader = LazyModelLoader(
        "facebook/bart-base",  # start with a smaller model
        "text-generation",
        fallback_model="facebook/bart-large-cnn"
    )
    summarization_model_loader = LazyModelLoader(
        "Falconsai/medical_summarization",  # ✅ known working
        "summarization",
        # The fallback is identical to the primary model, so a load failure
        # raises instead of retrying with a different checkpoint.
        fallback_model="Falconsai/medical_summarization"
    )
    # Initialize agents with lazy loading
    text_extractor_agent = TextExtractorAgent()
    phi_scrubber_agent = PHIScrubberAgent()
    medical_data_extractor_agent = MedicalDataExtractorAgent(medical_data_extractor_model_loader)
    summarizer_agent = SummarizerAgent(summarization_model_loader)

    # Pass all agents and models to the routes
    agents = {
        "text_extractor": text_extractor_agent,
        "phi_scrubber": phi_scrubber_agent,
        "summarizer": summarizer_agent,
        "medical_data_extractor": medical_data_extractor_agent,
        "whisper_model": WhisperModelLoader.get_instance(),
        "patient_summarizer": PatientSummarizerAgent(model_name="Falconsai/medical_summarization"),
    }

    from .api.routes import register_routes
    register_routes(app, agents)
except Exception as e:
    logging.error(f"Failed to initialize application: {str(e)}", exc_info=True)
    raise

# Register a catch-all handler so unhandled exceptions return a JSON 500
@app.errorhandler(Exception)
def handle_error(error):
    logging.error(f"Unhandled error: {str(error)}", exc_info=True)
    return jsonify({
        "error": str(error),
        "status": "error"
    }), 500

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False)
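
# Note: because this module uses relative imports (from .agents..., from .api...),
# it must run as part of its package, e.g. `python -m <package>.<module>` or via a
# WSGI server such as `gunicorn <package>.<module>:app --bind 0.0.0.0:7860`;
# the actual package/module names depend on this repo's layout.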