from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from typing import List
from ctransformers import AutoModelForCausalLM
import time
import logging

from app.config import MODEL_PATH

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Poetry Generator API",
    description="An API for generating poetry using a local LLM",
    version="1.0.0"
)

# Global model handle, populated on startup
model = None


class PoetryRequest(BaseModel):
    prompt: str = Field(..., description="The topic or theme for the poem", min_length=1)
    style: str = Field(
        default="free verse",
        description="Style of the poem to generate"
    )
    max_length: int = Field(
        default=200,
        description="Maximum length of the generated poem",
        ge=50,
        le=500
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for text generation",
        ge=0.1,
        le=2.0
    )


class PoetryResponse(BaseModel):
    poem: str
    generation_time: float
    prompt: str
    style: str


class ModelInfo(BaseModel):
    status: str
    model_name: str
    model_path: str
    supported_styles: List[str]
    max_context_length: int


@app.on_event("startup")
async def startup_event():
    """Initialize the model during startup"""
    global model
    try:
        if not MODEL_PATH.exists():
            raise FileNotFoundError(
                f"Model file not found at {MODEL_PATH}. "
                "Please run download_model.py first."
            )

        logger.info(f"Loading model from {MODEL_PATH}")
        model = AutoModelForCausalLM.from_pretrained(
            str(MODEL_PATH.parent),
            model_file=MODEL_PATH.name,
            model_type="llama",
            max_new_tokens=512,
            context_length=512,
            gpu_layers=0  # CPU only
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise RuntimeError("Failed to initialize model")


@app.get(
    "/health",
    response_model=ModelInfo,
    status_code=status.HTTP_200_OK,
    tags=["Health Check"]
)
async def health_check():
    """Check if the model is loaded and get basic information"""
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded"
        )

    return ModelInfo(
        status="ready",
        model_name="Llama-2-7B-Chat",
        model_path=str(MODEL_PATH),
        supported_styles=[
            "free verse",
            "haiku",
            "sonnet",
            "limerick",
            "tanka"
        ],
        max_context_length=512
    )


@app.post(
    "/generate",
    response_model=PoetryResponse,
    status_code=status.HTTP_200_OK,
    tags=["Generation"]
)
async def generate_poem(request: PoetryRequest):
    """Generate a poem based on the provided prompt and parameters"""
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded"
        )

    try:
        start_time = time.time()

        prompt_templates = {
            "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
            "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
            "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
            "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
            "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n"
        }

        # Fall back to free verse for unrecognized styles
        template = prompt_templates.get(request.style.lower(), prompt_templates["free verse"])
        full_prompt = template.format(prompt=request.prompt)

        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=0.95,
            repetition_penalty=1.2
        )

        generation_time = time.time() - start_time

        return PoetryResponse(
            poem=output.strip(),
            generation_time=generation_time,
            prompt=request.prompt,
            style=request.style
        )
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate poem: {str(e)}"
        )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
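
# Quick smoke test, as a sketch: assumes the server is running locally on
# port 8000 and the model has finished loading; the prompt/style values are
# illustrative. The routes and request fields match those defined above.
#
#   curl http://localhost:8000/health
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "autumn rain", "style": "haiku"}'
#
# Unknown styles fall back to "free verse"; max_length (50-500) and
# temperature (0.1-2.0) are optional and default to 200 and 0.7.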