Commit 0d8e806
Parent(s): ad71c98
Implement model initialization and health check endpoints; add model download logic, error handling, and request validation
main.py CHANGED
@@ -1,8 +1,11 @@
 import os
+from typing import Optional
 from fastapi import FastAPI, HTTPException, status
 from pathlib import Path
 import logging
 import sys
+
+from pydantic import BaseModel
 from ctransformers import AutoModelForCausalLM
 BASE_DIR = Path("/app")
 MODEL_DIR = BASE_DIR / "models"
@@ -169,4 +172,86 @@ async def startup_event():
             "Check model_loading.log for details."
         )
     else:
-        logger.info("Application started successfully with model loaded")
+        logger.info("Application started successfully with model loaded")
+
+@app.on_event("startup")
+async def startup():
+    """Initialize the model during startup"""
+    global model
+
+    logger.info("Starting application...")
+
+    if not MODEL_PATH:
+        logger.error("MODEL_PATH is not defined in config")
+        return
+
+    # Try to download model if it doesn't exist
+    if not MODEL_PATH.exists():
+        try:
+            logger.info("Model not found, attempting download...")
+            download_model()
+        except Exception as e:
+            logger.error(f"Failed to download model: {str(e)}")
+            return
+
+    # Initialize model
+    model = initialize_model(MODEL_PATH)
+
+    if model is None:
+        logger.warning(
+            "Model failed to load. Service will start but /generate endpoint will be unavailable. "
+            "Check model_loading.log for details."
+        )
+    else:
+        logger.info("Application started successfully with model loaded")
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "model_loaded": model is not None
+    }
+class GenerateRequest(BaseModel):
+    prompt: str
+    max_tokens: Optional[int] = 512
+    temperature: Optional[float] = 0.7
+@app.post("/generate")
+async def generate_text(request: GenerateRequest):
+    """Generate text from the model"""
+    if model is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Model is not loaded. Please check server logs."
+        )
+
+    try:
+        # Generate response from the model
+        response = model(
+            request.prompt,
+            max_new_tokens=request.max_tokens,
+            temperature=request.temperature
+        )
+
+        return {
+            "generated_text": response,
+            "prompt": request.prompt,
+            "max_tokens": request.max_tokens,
+            "temperature": request.temperature
+        }
+    except Exception as e:
+        logger.error(f"Error generating text: {str(e)}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Error generating text: {str(e)}"
+        )
+
+# Optional shutdown event
+@app.on_event("shutdown")
+async def shutdown():
+    """Cleanup on shutdown"""
+    global model
+    if model is not None:
+        del model
+        model = None
+        logger.info("Model unloaded during shutdown")
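The startup handler calls download_model() and initialize_model(MODEL_PATH), which live elsewhere in main.py and are not part of this diff. Below is a minimal, purely illustrative sketch of how such helpers could fetch and load a GGUF model with huggingface_hub and ctransformers; the repo id, filename, and model_type are placeholder assumptions, not values taken from this Space.

# Illustrative sketch only: not the Space's actual download_model()/initialize_model().
import logging
from pathlib import Path

from ctransformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

MODEL_DIR = Path("/app") / "models"
MODEL_REPO = "TheBloke/Llama-2-7B-Chat-GGUF"   # placeholder repo id, not from this Space
MODEL_FILE = "llama-2-7b-chat.Q4_K_M.gguf"     # placeholder filename, not from this Space
MODEL_PATH = MODEL_DIR / MODEL_FILE

def download_model() -> None:
    """Fetch the GGUF weights into MODEL_DIR if they are missing."""
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, local_dir=MODEL_DIR)

def initialize_model(model_path: Path):
    """Load the GGUF file with ctransformers; return None so startup can degrade gracefully."""
    try:
        return AutoModelForCausalLM.from_pretrained(
            str(model_path),
            model_type="llama",        # assumption: a Llama-family GGUF model
            context_length=2048,
        )
    except Exception:
        logger.exception("Model initialization failed")
        return None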
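Once the Space is up, the new endpoints can be exercised with any HTTP client. A small sketch using the requests library, assuming the API is reachable at http://localhost:7860 (adjust the base URL and port to match your deployment):

import requests

BASE_URL = "http://localhost:7860"   # assumption: adjust to your Space's URL/port

# Health probe: reports whether the model finished loading.
health = requests.get(f"{BASE_URL}/health", timeout=10)
print(health.json())   # e.g. {'status': 'healthy', 'model_loaded': True}

# Text generation; the payload mirrors the GenerateRequest schema in the diff above.
payload = {"prompt": "Write a haiku about FastAPI.", "max_tokens": 128, "temperature": 0.7}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120)
if resp.ok:
    print(resp.json()["generated_text"])
else:
    # 503 means the model is not loaded; 500 wraps generation errors.
    print(resp.status_code, resp.json().get("detail"))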