Spaces:
Running
Running
Improve model preloading in PoetryGenerationService with meaningful return value and enhanced error handling
abc61cb
| import asyncio | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI | |
| from app.api.endpoints.poetry import router as poetry_router | |
| import os | |
| import logging | |
| from typing import Tuple | |
| from starlette.responses import Response | |
| from starlette.staticfiles import StaticFiles | |
| from huggingface_hub import login | |
| from functools import lru_cache | |
| from app.services.poetry_generation import PoetryGenerationService | |
# One-time logging setup performed at import: root logger at INFO level.
logging.basicConfig(level="INFO")
# Logger scoped to this module's name.
logger = logging.getLogger(__name__)
def get_hf_token() -> str:
    """Return the Hugging Face access token read from ``HF_TOKEN``.

    Raises:
        EnvironmentError: if ``HF_TOKEN`` is unset or set to an empty string.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        return hf_token
    raise EnvironmentError(
        "HF_TOKEN environment variable not found. "
        "Please set your Hugging Face access token."
    )
def init_huggingface():
    """Log this process in to the Hugging Face Hub.

    The token comes from ``get_hf_token`` and is handed to
    ``huggingface_hub.login``. Any exception (including a missing token)
    is logged and then re-raised to the caller.
    """
    try:
        login(token=get_hf_token())
        logger.info("Successfully logged in to Hugging Face")
    except Exception as err:
        logger.error(f"Failed to login to Hugging Face: {err!s}")
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: authenticate and warm up models, then serve.

    Startup logs in to the Hugging Face Hub and asks
    ``PoetryGenerationService`` to preload its models. Any failure during
    preloading is logged with its traceback and re-raised, so startup fails
    instead of serving without models. Control passes to the application at
    ``yield``; nothing is done after it (no shutdown work).

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    import inspect  # local import: only needed to detect async preloaders

    # Authenticate first; presumably model downloads in preload_models
    # need the token (TODO confirm).
    init_huggingface()

    poetry_service = PoetryGenerationService()
    try:
        preload_result = poetry_service.preload_models()
        # preload_models may be sync or async. Await whatever awaitable it
        # returned (coroutine, Task, Future) and keep the awaited value.
        if inspect.isawaitable(preload_result):
            preload_result = await preload_result
    except Exception as e:
        logger.exception("Error preloading models: %s", e)
        raise

    # Surface the preload outcome instead of silently discarding it.
    logger.info("Model preloading finished: %r", preload_result)
    # NOTE(review): poetry_service is not kept after startup; this assumes
    # preload_models caches models somewhere shared with request handlers.
    # Confirm.
    yield
# ASGI application; startup work is delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan)
# Expose the poetry endpoints under the versioned API prefix.
app.include_router(router=poetry_router, prefix="/api/v1/poetry")
async def lifecheck():
    """Plain-text liveness response whose body is always ``OK``.

    NOTE(review): no route decorator or ``app.add_api_route`` call is
    visible for this handler, so it may not be reachable. Confirm it is
    registered.
    """
    body = "OK"
    return Response(content=body, media_type="text/plain")
def get_port(default: int = 8000) -> int:
    """Return the TCP port to listen on, taken from the ``PORT`` env var.

    Args:
        default: Port used when ``PORT`` is unset or blank (default 8000).

    Returns:
        The parsed port number.

    Raises:
        ValueError: if ``PORT`` is set to something that is not an integer.
    """
    raw = (os.getenv("PORT") or "").strip()
    # Treat an empty/whitespace PORT like an unset one. Previously this
    # crashed with "invalid literal for int()".
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"PORT must be an integer, got {raw!r}") from None
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn.
    import uvicorn

    server_port = get_port()
    # NOTE(review): static files are mounted only on this script path, not
    # when another server imports ``app``. Confirm that is intended.
    static_app = StaticFiles(directory="static")
    app.mount("/static", static_app, name="static")
    logger.info(f"Starting FastAPI server on port {server_port}")
    uvicorn.run(app, host="0.0.0.0", port=server_port)