import os
import logging
from flask import Flask, jsonify
from flask_cors import CORS
import whisper
from dotenv import load_dotenv
import torch
from .agents.text_extractor import TextExtractorAgent
from .agents.phi_scrubber import PHIScrubberAgent, MedicalTextUtils
from .agents.summarizer import SummarizerAgent
from .agents.medical_data_extractor import MedicalDataExtractorAgent, MedicalDocDataExtractorAgent
from .agents.patient_summary_agent import PatientSummarizerAgent
# Load environment variables
load_dotenv()
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('/tmp/app.log')
    ]
)
app = Flask(__name__)
CORS(app)
# Configure upload directory with safe fallbacks (avoid creating /data at import time)
def _resolve_upload_dir() -> str:
    try:
        # Prefer /data/uploads if it already exists and is writable
        data_dir = '/data/uploads'
        if os.path.isdir('/data') and (os.path.isdir(data_dir) or os.access('/data', os.W_OK)):
            os.makedirs(data_dir, exist_ok=True)
            return data_dir
    except Exception:
        pass
    # Fall back to /tmp/uploads, which is always writable on Spaces
    tmp_dir = '/tmp/uploads'
    os.makedirs(tmp_dir, exist_ok=True)
    return tmp_dir
app.config['UPLOAD_FOLDER'] = _resolve_upload_dir()
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100 MB max file size
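# Flask rejects request bodies larger than MAX_CONTENT_LENGTH with an HTTP 413 response.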
# Set cache directories
CACHE_DIRS = {
    'HF_HOME': '/tmp/huggingface',
    'XDG_CACHE_HOME': '/tmp',
    'TORCH_HOME': '/tmp/torch',
    'WHISPER_CACHE': '/tmp/whisper'
}
for env_var, path in CACHE_DIRS.items():
    os.environ[env_var] = path
    os.makedirs(path, exist_ok=True)
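# Note: HF_HOME, XDG_CACHE_HOME and TORCH_HOME are picked up by transformers/torch when
# models are downloaded; WHISPER_CACHE is not a library convention and is only read
# explicitly in WhisperModelLoader.load() below. Creating the directories here ensures
# every cache location exists before any model is loaded.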


# Model loaders
class LazyModelLoader:
    """Loads a Hugging Face pipeline on first use, retrying and falling back if needed."""

    def __init__(self, model_name, model_type, fallback_model=None, max_retries=2):
        self.model_name = model_name
        self.model_type = model_type
        self.fallback_model = fallback_model
        self._model = None
        self._tokenizer = None
        self._pipeline = None
        self._retries = 0
        self.max_retries = max_retries

    def load(self):
        # Import transformers lazily so that importing this module stays cheap
        from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        if self._pipeline is None:
            try:
                logging.info(f"Loading model: {self.model_name} (attempt {self._retries + 1})")
                torch.cuda.empty_cache()
                cache_dir = os.environ.get('HF_HOME', '/tmp/huggingface')
                self._tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    cache_dir=cache_dir
                )
                # Use half precision only when a GPU is available; fp16 ops are
                # poorly supported on CPU-only Spaces hardware
                dtype = torch.float16 if torch.cuda.is_available() else torch.float32
                model_cls = (AutoModelForCausalLM if self.model_type == "text-generation"
                             else AutoModelForSeq2SeqLM)
                self._model = model_cls.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                    torch_dtype=dtype,
                    cache_dir=cache_dir
                )
                self._pipeline = pipeline(
                    task=self.model_type,
                    model=self._model,
                    tokenizer=self._tokenizer,
                )
                logging.info(f"Model loaded successfully: {self.model_name}")
                return self._pipeline
            except Exception as e:
                logging.error(f"Error loading model '{self.model_name}': {e}", exc_info=True)
                self._retries += 1
                if self._retries >= self.max_retries:
                    raise RuntimeError(f"Exceeded retry limit for model: {self.model_name}")
                # Attempt the fallback model if it differs from the current one
                if self.fallback_model and self.fallback_model != self.model_name:
                    logging.warning(f"Falling back to model: {self.fallback_model}")
                    self.model_name = self.fallback_model
                    return self.load()
                else:
                    raise RuntimeError(f"Fallback failed or not set for model: {self.model_name}")
        return self._pipeline
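
# Illustrative usage sketch (not executed at import time). The example text is a
# placeholder; the call pattern follows the standard transformers pipeline API that
# load() returns for a summarization task:
#
#     loader = LazyModelLoader("Falconsai/medical_summarization", "summarization")
#     summarizer = loader.load()
#     summary = summarizer("Patient presents with ...", max_length=60, min_length=10)[0]["summary_text"]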


class WhisperModelLoader:
    _instance = None

    def __init__(self):
        self._model = None

    @staticmethod
    def get_instance():
        if WhisperModelLoader._instance is None:
            WhisperModelLoader._instance = WhisperModelLoader()
        return WhisperModelLoader._instance

    def load(self):
        if self._model is None:
            try:
                logging.info("Loading Whisper model...")
                self._model = whisper.load_model(
                    "tiny",  # Using the tiny model for better memory usage
                    download_root=os.environ.get('WHISPER_CACHE', '/tmp/whisper')
                )
                logging.info("Whisper model loaded successfully")
            except Exception as e:
                logging.error(f"Failed to load Whisper model: {str(e)}", exc_info=True)
                raise
        return self._model

    def transcribe(self, audio_path):
        model = self.load()
        return model.transcribe(audio_path)
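
# Illustrative usage sketch (the audio path is a placeholder); Whisper's transcribe()
# returns a dict whose "text" field holds the transcript:
#
#     result = WhisperModelLoader.get_instance().transcribe("/tmp/uploads/visit.wav")
#     transcript = result["text"]
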
# Initialize agents
try:
    # Use smaller models for Hugging Face Spaces
    medical_data_extractor_model_loader = LazyModelLoader(
        "facebook/bart-base",  # Start with a smaller model
        "text-generation",
        fallback_model="facebook/bart-large-cnn"
    )
    summarization_model_loader = LazyModelLoader(
        "Falconsai/medical_summarization",  # Known working
        "summarization",
        fallback_model="Falconsai/medical_summarization"
    )
    # Initialize agents with lazy loading
    text_extractor_agent = TextExtractorAgent()
    phi_scrubber_agent = PHIScrubberAgent()
    medical_data_extractor_agent = MedicalDataExtractorAgent(medical_data_extractor_model_loader)
    summarizer_agent = SummarizerAgent(summarization_model_loader)
    # Pass all agents and models to the routes
    agents = {
        "text_extractor": text_extractor_agent,
        "phi_scrubber": phi_scrubber_agent,
        "summarizer": summarizer_agent,
        "medical_data_extractor": medical_data_extractor_agent,
        "whisper_model": WhisperModelLoader.get_instance(),
        "patient_summarizer": PatientSummarizerAgent(model_name="falconsai/medical_summarization"),
    }
    from .api.routes import register_routes
    register_routes(app, agents)
except Exception as e:
    logging.error(f"Failed to initialize application: {str(e)}", exc_info=True)
    raise
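
# Global error handler: any unhandled exception is returned to the client as JSON of
# the form {"error": "<message>", "status": "error"} with an HTTP 500 status.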
@app.errorhandler(Exception)
def handle_error(error):
    logging.error(f"Unhandled error: {str(error)}", exc_info=True)
    return jsonify({
        "error": str(error),
        "status": "error"
    }), 500

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False)