"""FastAPI service that serves text generation from SmolLM2 plus fine-tuned adapter weights."""

import logging
import os

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelInput(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    prompt: str = Field(..., description="The input prompt for text generation")
    max_new_tokens: int = Field(
        default=2048,
        gt=0,
        le=4096,
        description="Maximum number of tokens to generate",
    )


app = FastAPI()

# Add CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests; narrow allow_origins if this API is exposed
# beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
ADAPTER_WEIGHTS_FILENAME = "adapter_model.safetensors"


def _apply_adapter_weights(model, state_dict):
    """Copy ``state_dict`` tensors into ``model`` and report keys that did not match.

    ``load_state_dict(strict=False)`` silently skips every key the model does
    not recognise, so without this report a naming mismatch would leave the
    base model unchanged while the service still claims success.

    Returns:
        The ``(missing_keys, unexpected_keys)`` result from ``load_state_dict``.
    """
    result = model.load_state_dict(state_dict, strict=False)
    applied = len(state_dict) - len(result.unexpected_keys)
    logger.info("Applied %d of %d adapter tensors.", applied, len(state_dict))
    if result.unexpected_keys:
        logger.warning(
            "%d adapter tensors matched no model parameter and were ignored "
            "(e.g. %s). If this is a PEFT/LoRA adapter it likely needs to be "
            "loaded with peft.PeftModel.from_pretrained instead.",
            len(result.unexpected_keys),
            list(result.unexpected_keys)[:3],
        )
    return result


def load_model_and_tokenizer():
    """Load the base model and tokenizer, then overlay the adapter weights.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: any download/loading error is logged and re-raised unchanged.
    """
    try:
        # Half precision is only a sensible default on CUDA; on CPU many kernels
        # are slow or unimplemented for float16.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        logger.info("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto",
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = os.path.join(adapter_path_local, ADAPTER_WEIGHTS_FILENAME)
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        _apply_adapter_weights(model, state_dict)

        logger.info("Model and adapter loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.exception("Error during model loading: %s", e)
        raise


# Load model and tokenizer at import time. On failure the app still starts;
# /generate then answers 503 and /health reports model_loaded=False.
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.exception("Failed to load model at startup: %s", e)
    model = None
    tokenizer = None


def generate_response(
    model,
    tokenizer,
    instruction,
    max_new_tokens=2048,
    *,
    temperature=0.7,
    top_p=0.9,
):
    """Generate a sampled continuation of ``instruction``.

    Args:
        model: causal LM (as returned by ``load_model_and_tokenizer``).
        tokenizer: tokenizer matching ``model``.
        instruction: prompt text; truncated to ``tokenizer.model_max_length`` tokens.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature passed to ``model.generate``.
        top_p: nucleus-sampling probability mass passed to ``model.generate``.

    Returns:
        Only the newly generated text (prompt excluded), whitespace-stripped.

    Raises:
        ValueError: wrapping (and chained to) any error from encoding,
            generation, or decoding.
    """
    try:
        logger.info("Generating response for instruction: %s...", instruction[:100])

        # Encode input with truncation
        input_ids = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).to(model.device)
        logger.info("Input shape: %s", input_ids.shape)

        # A single unpadded sequence: every position is attended to.
        attention_mask = torch.ones_like(input_ids)

        # Fall back to EOS when the tokenizer defines no pad token.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        logger.info("Output shape: %s", outputs.shape)

        # Decoder-only generate() returns prompt + continuation. Drop the prompt
        # at the token level: slicing the decoded string by len(instruction)
        # breaks whenever decoding does not reproduce the prompt exactly
        # (whitespace normalisation, skipped special tokens, truncation).
        new_tokens = outputs[0][input_ids.shape[-1]:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        logger.info("Generated text length: %d", len(generated_text))
        return generated_text
    except Exception as e:
        logger.exception("Error generating response: %s", e)
        raise ValueError(f"Error generating response: {e}") from e


@app.post("/generate")
def generate_text(input: ModelInput, request: Request):  # `input` kept for compatibility
    """Generate text based on the input prompt.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool: the
    blocking ``model.generate`` call would otherwise stall the event loop and
    every other request (including ``/health``) for the whole generation.

    Raises:
        HTTPException: 503 when the model failed to load; 500 when generation fails.
    """
    # Checked outside the try below so the 503 is not re-wrapped as a 500.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # request.client can be None (e.g. some ASGI servers / test clients).
    client_host = request.client.host if request.client is not None else "unknown"
    logger.info("Received request from %s", client_host)
    logger.info("Prompt: %s...", input.prompt[:100])

    try:
        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens,
        )
    except Exception as e:
        logger.exception("Error in generate_text endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not response:
        logger.warning("Generated empty response")
        return {"generated_text": "", "warning": "Empty response generated"}

    logger.info("Generated response length: %d", len(response))
    return {"generated_text": response}


@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!", "status": "running"}


@app.get("/health")
async def health_check():
    """Report liveness, whether the model loaded, and the device of its first parameter."""
    model_loaded = model is not None and tokenizer is not None
    model_device = str(next(model.parameters()).device) if model is not None else None
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "model_device": model_device,
    }