Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Test script for GGUF model loading in Hugging Face Spaces.

Run it before deploying to catch model-loading problems that would otherwise
surface as HTTP 500 errors in production.
"""
import os | |
import sys | |
import time | |
import logging | |
# Root logging for this script: INFO and above, each record prefixed with its
# timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def test_gguf_loading():
    """Load the production GGUF model and exercise its generation methods.

    Runs these stages in order; the first failing stage ends the run:

    1. import ``GGUFModelPipeline`` from ``ai_med_extract.utils.model_loader_gguf``;
    2. construct it for the Phi-3 mini GGUF file with ``timeout=300``;
    3. call ``generate`` on a short sample prompt with ``max_tokens=100``;
    4. call ``generate_full_summary`` on the same prompt with
       ``max_tokens=200`` and ``max_loops=1``.

    Side effects: sets ``HF_HOME``, ``GGUF_N_THREADS`` and ``GGUF_N_BATCH`` in
    ``os.environ`` for the rest of the process, overwriting any existing values.

    Returns:
        bool: True if every stage succeeded, False otherwise. Failures are
        logged together with their traceback rather than raised.
    """
    # Mirror the Hugging Face Spaces runtime settings. They are assigned before
    # the loader import below -- presumably so the loader (and huggingface_hub)
    # read them at import time; TODO confirm.
    os.environ['HF_HOME'] = '/tmp/huggingface'
    os.environ['GGUF_N_THREADS'] = '2'
    os.environ['GGUF_N_BATCH'] = '64'
    try:
        logger.info("Testing GGUF model loading...")
        # Same model and file as the production API call.
        model_name = "microsoft/Phi-3-mini-4k-instruct-gguf"
        filename = "Phi-3-mini-4k-instruct-q4.gguf"
        logger.info(f"Model: {model_name}")
        logger.info(f"Filename: {filename}")

        # Stage 1: import lazily so a missing or broken package is reported as a
        # failed check instead of crashing the whole script.
        try:
            from ai_med_extract.utils.model_loader_gguf import GGUFModelPipeline
        except ImportError as e:
            logger.exception(f"β Failed to import GGUFModelPipeline: {e}")
            return False
        logger.info("β GGUFModelPipeline import successful")

        # Stage 2: load the model. perf_counter is monotonic, so the measured
        # duration is not skewed by system clock adjustments.
        start_time = time.perf_counter()
        try:
            pipeline = GGUFModelPipeline(model_name, filename, timeout=300)
        except Exception as e:
            load_time = time.perf_counter() - start_time
            logger.exception(f"β Model loading failed after {load_time:.2f}s: {e}")
            return False
        load_time = time.perf_counter() - start_time
        logger.info(f"β Model loaded successfully in {load_time:.2f}s")

        # Stage 3: short generation. len()/slicing stay inside the try so a
        # result that does not support them (e.g. None) counts as a failure.
        test_prompt = "Generate a brief medical summary: Patient has fever and cough."
        sample_chars = 200
        try:
            logger.info("Testing basic generation...")
            start_gen = time.perf_counter()
            result = pipeline.generate(test_prompt, max_tokens=100)
            gen_time = time.perf_counter() - start_gen
            logger.info(f"β Generation successful in {gen_time:.2f}s")
            logger.info(f"Generated text length: {len(result)} characters")
            # Only mark the sample as truncated when it actually was.
            more = "..." if len(result) > sample_chars else ""
            logger.info(f"Sample output: {result[:sample_chars]}{more}")
        except Exception as e:
            logger.exception(f"β Generation failed: {e}")
            return False

        # Stage 4: full summary generation on the same prompt, single loop.
        try:
            logger.info("Testing full summary generation...")
            start_summary = time.perf_counter()
            summary = pipeline.generate_full_summary(test_prompt, max_tokens=200, max_loops=1)
            summary_time = time.perf_counter() - start_summary
            logger.info(f"β Full summary generation successful in {summary_time:.2f}s")
            logger.info(f"Summary length: {len(summary)} characters")
        except Exception as e:
            logger.exception(f"β Full summary generation failed: {e}")
            return False

        logger.info("π All tests passed! GGUF model is working correctly.")
        return True
    except Exception as e:
        # Safety net for anything outside the per-stage handlers (e.g. logging).
        logger.exception(f"β Test failed with unexpected error: {e}")
        return False
def test_fallback_pipeline():
    """Check that the fallback pipeline can be built and produce text.

    Imports ``create_fallback_pipeline`` from
    ``ai_med_extract.utils.model_loader_gguf``, builds the pipeline and calls
    ``generate("Test prompt")`` on it.

    Returns:
        bool: True if the import, construction and generation all succeed and
        the result supports ``len()``; False otherwise. Failures are logged
        with their traceback rather than raised.
    """
    try:
        logger.info("Testing fallback pipeline...")
        from ai_med_extract.utils.model_loader_gguf import create_fallback_pipeline
        fallback = create_fallback_pipeline()
        result = fallback.generate("Test prompt")
        logger.info(f"β Fallback pipeline working: {len(result)} characters generated")
    except Exception as e:
        # logger.exception records the traceback, so the log shows which step
        # (import, construction or generation) actually failed.
        logger.exception(f"β Fallback pipeline failed: {e}")
        return False
    return True
def main():
    """Run the GGUF and fallback checks and log a PASS/FAIL summary.

    Returns:
        bool: The result of ``test_gguf_loading()``. The fallback result is
        reported in the summary but does not affect the return value.
    """
    logger.info("Starting GGUF model tests...")
    # Test 1: GGUF model loading
    gguf_success = test_gguf_loading()
    # Test 2: Fallback pipeline
    fallback_success = test_fallback_pipeline()

    # Summary
    logger.info("\n" + "=" * 50)
    logger.info("TEST SUMMARY")
    logger.info("=" * 50)
    logger.info(f"GGUF Model Loading: {'β PASS' if gguf_success else 'β FAIL'}")
    logger.info(f"Fallback Pipeline: {'β PASS' if fallback_success else 'β FAIL'}")

    if gguf_success:
        logger.info("π GGUF model is working correctly!")
        logger.info("Your API should work without 500 errors.")
    elif fallback_success:
        logger.warning("β οΈ GGUF model has issues. The fallback will be used.")
        logger.info("Your API will still work but with reduced functionality.")
    else:
        # Neither path works, so do not claim the API will keep functioning.
        logger.error("Both the GGUF model and the fallback pipeline failed.")
        logger.error("Your API is likely to return 500 errors.")
    return gguf_success
if __name__ == "__main__":
    # Process exit status follows main(): 0 when the GGUF checks passed, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)