abhisheksan committed on
Commit
0d8e806
·
1 Parent(s): ad71c98

Implement model initialization and health check endpoints; add model download logic, error handling, and request validation

Browse files
Files changed (1) hide show
  1. main.py +86 -1
main.py CHANGED
@@ -1,8 +1,11 @@
1
  import os
 
2
  from fastapi import FastAPI, HTTPException, status
3
  from pathlib import Path
4
  import logging
5
  import sys
 
 
6
  from ctransformers import AutoModelForCausalLM
7
  BASE_DIR = Path("/app")
8
  MODEL_DIR = BASE_DIR / "models"
@@ -169,4 +172,86 @@ async def startup_event():
169
  "Check model_loading.log for details."
170
  )
171
  else:
172
- logger.info("Application started successfully with model loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Optional
3
  from fastapi import FastAPI, HTTPException, status
4
  from pathlib import Path
5
  import logging
6
  import sys
7
+
8
+ from pydantic import BaseModel
9
  from ctransformers import AutoModelForCausalLM
10
  BASE_DIR = Path("/app")
11
  MODEL_DIR = BASE_DIR / "models"
 
172
  "Check model_loading.log for details."
173
  )
174
  else:
175
+ logger.info("Application started successfully with model loaded")
176
+
177
@app.on_event("startup")
async def startup():
    """Load the language model once while the application boots.

    Rebinds the module-level ``model``. If ``MODEL_PATH`` is falsy, or the
    model file is missing and ``download_model()`` raises, an error is logged
    and the handler returns without assigning ``model``. Otherwise ``model``
    becomes whatever ``initialize_model(MODEL_PATH)`` returns; a ``None``
    result is reported with a warning instead of an exception.

    NOTE(review): this file also appears to register ``startup_event`` for
    the same event -- confirm the model is not being initialized twice.
    """
    global model

    logger.info("Starting application...")

    if not MODEL_PATH:
        logger.error("MODEL_PATH is not defined in config")
        return

    # Fetch the model file only when it is not already on disk.
    model_file_present = MODEL_PATH.exists()
    if not model_file_present:
        try:
            logger.info("Model not found, attempting download...")
            download_model()
        except Exception as exc:
            logger.error(f"Failed to download model: {str(exc)}")
            return

    model = initialize_model(MODEL_PATH)

    if model is not None:
        logger.info("Application started successfully with model loaded")
        return

    # A failed load is logged but not raised, so this handler still returns
    # normally.
    logger.warning(
        "Model failed to load. Service will start but /generate endpoint will be unavailable. "
        "Check model_loading.log for details."
    )
207
+
208
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # "status" is always "healthy"; only "model_loaded" varies, reflecting
    # whether the module-level ``model`` is currently set.
    model_loaded = model is not None
    return {"status": "healthy", "model_loaded": model_loaded}
215
# Request body accepted by the /generate endpoint.
class GenerateRequest(BaseModel):
    # Text handed to the model as its first positional argument.
    prompt: str

    # Forwarded to the model as ``max_new_tokens``; 512 when omitted.
    max_tokens: Optional[int] = 512

    # Forwarded to the model as ``temperature``; 0.7 when omitted.
    temperature: Optional[float] = 0.7
219
@app.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate text from the model"""
    # Contract:
    #   * 503 when the module-level ``model`` is None.
    #   * Otherwise the model is called with the prompt plus the request's
    #     ``max_tokens`` (as ``max_new_tokens``) and ``temperature``, and the
    #     output is returned together with the echoed request parameters.
    #   * Any exception raised by the model call is logged with its traceback
    #     and converted into a 500 whose detail carries the error text.
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model is not loaded. Please check server logs."
        )

    # Keep the try body to the single call that can fail.
    # NOTE(review): this is a synchronous call inside a coroutine; if
    # generation is slow it will hold the event loop -- consider offloading
    # to a worker thread once the model's thread-safety is confirmed.
    try:
        response = model(
            request.prompt,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
        )
    except Exception as e:
        # logger.exception logs at ERROR level like logger.error did, but also
        # records the traceback so the failure can be diagnosed from the logs.
        logger.exception("Error generating text: %s", e)
        # Chain the original error so it remains visible as __cause__.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating text: {str(e)}"
        ) from e

    return {
        "generated_text": response,
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
    }
248
+
249
+ # Optional shutdown event
250
@app.on_event("shutdown")
async def shutdown():
    """Drop the module-level model reference when the application stops.

    Does nothing when ``model`` is already ``None``; otherwise the global is
    deleted, rebound to ``None``, and the unload is logged.
    """
    global model

    if model is None:
        return

    # Remove the global binding, then rebind it to None so the name still
    # exists afterwards.
    del model
    model = None
    logger.info("Model unloaded during shutdown")