# poetica / main.py
# abhisheksan, commit abc61cb:
#   Improve model preloading in PoetryGenerationService with meaningful
#   return value and enhanced error handling
import asyncio
import inspect
import logging
import os
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Tuple

from fastapi import FastAPI
from huggingface_hub import login
from starlette.responses import Response
from starlette.staticfiles import StaticFiles

from app.api.endpoints.poetry import router as poetry_router
from app.services.poetry_generation import PoetryGenerationService
# Configure logging once at module level: the root logger emits INFO and above.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the startup helpers and the script entry point.
logger = logging.getLogger(__name__)
@lru_cache()
def get_hf_token() -> str:
    """Return the Hugging Face access token stored in ``HF_TOKEN``.

    A successful lookup is memoized by ``lru_cache``, so later changes to the
    environment variable are not observed until ``cache_clear()`` is called.

    Returns:
        The non-empty value of the ``HF_TOKEN`` environment variable.

    Raises:
        EnvironmentError: If ``HF_TOKEN`` is unset or set to an empty string.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        return hf_token

    # Unset and empty are treated alike: neither is a usable token.
    raise EnvironmentError(
        "HF_TOKEN environment variable not found. "
        "Please set your Hugging Face access token."
    )
def init_huggingface():
    """Authenticate this process with Hugging Face.

    Reads the token through ``get_hf_token()`` and passes it to
    ``huggingface_hub.login``.

    Raises:
        Exception: Whatever the token lookup or the login call raised; the
            failure is logged at ERROR level and then re-raised unchanged.
    """
    try:
        login(token=get_hf_token())
        logger.info("Successfully logged in to Hugging Face")
    except Exception as exc:
        failure_message = f"Failed to login to Hugging Face: {str(exc)}"
        logger.error(failure_message)
        # Bare raise keeps the original exception and traceback for the caller.
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the application begins serving requests.

    Startup sequence:
      1. Authenticate with Hugging Face via ``init_huggingface()``.
      2. Create a ``PoetryGenerationService`` and call ``preload_models()``,
         awaiting the result when it is awaitable.

    Control is then yielded back to the framework. No shutdown work is
    performed after the yield.

    Args:
        app: The FastAPI application this lifespan is attached to (not used
            by the body).

    Raises:
        Exception: Any error from authentication or model preloading is
            re-raised; preloading errors are logged at ERROR level first.
    """
    # Authentication must come first so model loading can use the credentials.
    init_huggingface()

    poetry_service = PoetryGenerationService()
    try:
        preload_result = poetry_service.preload_models()
        # preload_models() may be synchronous or asynchronous. A synchronous
        # implementation has finished its work by the time it returns, so only
        # an awaitable result (coroutine, Task, Future, ...) needs more action.
        if inspect.isawaitable(preload_result):
            await preload_result
        # NOTE(review): the value returned by preload_models() is not
        # inspected here - confirm whether a failure-indicating return value
        # should abort startup.
    except Exception as e:
        logger.error(f"Error preloading models: {str(e)}")
        raise

    yield  # Hand control to the framework; the app starts serving now.
# Create the application; `lifespan` performs the startup work (Hugging Face
# login and model preloading) before the app starts serving.
app = FastAPI(lifespan=lifespan)
# Expose the poetry endpoints under the /api/v1/poetry prefix.
app.include_router(poetry_router, prefix="/api/v1/poetry")
@app.get("/healthz")
async def lifecheck():
    """Health-check endpoint: reply with the plain-text body ``OK``."""
    health_response = Response(content="OK", media_type="text/plain")
    return health_response
def get_port() -> int:
    """Return the port the server should listen on.

    The value is read from the ``PORT`` environment variable.

    Returns:
        The integer value of ``PORT``, or ``8000`` when the variable is
        unset or blank.

    Raises:
        ValueError: If ``PORT`` is set to something that is not an integer.
    """
    raw_port = os.getenv("PORT", "").strip()
    if not raw_port:
        # A blank PORT= (e.g. an empty entry in a container config) is treated
        # as "not configured" instead of failing on int("").
        return 8000
    try:
        return int(raw_port)
    except ValueError:
        # Same exception type as a bare int() call, but naming the variable
        # makes a misconfiguration easy to diagnose.
        raise ValueError(
            f"PORT environment variable must be an integer, got {raw_port!r}"
        ) from None
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    listen_port = get_port()

    # NOTE(review): this mount sits inside the __main__ guard, so /static is
    # only registered when the file is run as a script, not when the module
    # is imported by an external server process - confirm that is intended.
    static_files = StaticFiles(directory="static")
    app.mount("/static", static_files, name="static")

    logger.info(f"Starting FastAPI server on port {listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)