Jakaria
commited on
Commit
·
ad7a9e2
1
Parent(s):
cdd0b85
Add Bangla model API
Browse files
app.py
CHANGED
@@ -1,24 +1,145 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
from pydantic import BaseModel
|
3 |
import joblib
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
app = FastAPI()
|
6 |
|
7 |
class PredictRequest(BaseModel):
|
8 |
text: str
|
9 |
|
10 |
-
#
|
11 |
-
model =
|
12 |
-
vectorizer =
|
13 |
-
label_encoder =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
@app.get("/")
|
16 |
def root():
|
17 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
@app.post("/predict")
|
20 |
def predict(request: PredictRequest):
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
import joblib
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
import logging
|
7 |
+
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
13 |
class PredictRequest(BaseModel):
|
14 |
text: str
|
15 |
|
16 |
+
# Global model variables
|
17 |
+
model = None
|
18 |
+
vectorizer = None
|
19 |
+
label_encoder = None
|
20 |
+
models_loaded = False
|
21 |
+
|
22 |
+
def load_model_safe(filename):
|
23 |
+
"""Safely load model with multiple methods"""
|
24 |
+
if not os.path.exists(filename):
|
25 |
+
raise FileNotFoundError(f"{filename} not found")
|
26 |
+
|
27 |
+
# Try joblib first
|
28 |
+
try:
|
29 |
+
return joblib.load(filename)
|
30 |
+
except Exception as e1:
|
31 |
+
logger.warning(f"Joblib failed for {filename}: {e1}")
|
32 |
+
|
33 |
+
# Try pickle as fallback
|
34 |
+
try:
|
35 |
+
with open(filename, 'rb') as f:
|
36 |
+
return pickle.load(f)
|
37 |
+
except Exception as e2:
|
38 |
+
logger.error(f"Pickle also failed for {filename}: {e2}")
|
39 |
+
raise e1 # Raise original joblib error
|
40 |
+
|
41 |
+
@app.on_event("startup")
|
42 |
+
async def startup_event():
|
43 |
+
global model, vectorizer, label_encoder, models_loaded
|
44 |
+
|
45 |
+
try:
|
46 |
+
logger.info("Starting model loading...")
|
47 |
+
|
48 |
+
# Load each model individually with error handling
|
49 |
+
logger.info("Loading bangla_model.pkl...")
|
50 |
+
model = load_model_safe("bangla_model.pkl")
|
51 |
+
logger.info(f"Model type: {type(model)}")
|
52 |
+
|
53 |
+
logger.info("Loading bangla_vectorizer.pkl...")
|
54 |
+
vectorizer = load_model_safe("bangla_vectorizer.pkl")
|
55 |
+
logger.info(f"Vectorizer type: {type(vectorizer)}")
|
56 |
+
|
57 |
+
logger.info("Loading bangla_label_encoder.pkl...")
|
58 |
+
label_encoder = load_model_safe("bangla_label_encoder.pkl")
|
59 |
+
logger.info(f"Label encoder type: {type(label_encoder)}")
|
60 |
+
|
61 |
+
# Test pipeline with dummy data
|
62 |
+
logger.info("Testing pipeline...")
|
63 |
+
test_vect = vectorizer.transform(["test"])
|
64 |
+
test_pred = model.predict(test_vect)
|
65 |
+
test_label = label_encoder.inverse_transform(test_pred)
|
66 |
+
logger.info(f"Pipeline test successful: {test_label[0]}")
|
67 |
+
|
68 |
+
models_loaded = True
|
69 |
+
logger.info("All models loaded successfully!")
|
70 |
+
|
71 |
+
except Exception as e:
|
72 |
+
logger.error(f"Failed to load models: {str(e)}")
|
73 |
+
models_loaded = False
|
74 |
+
# Don't raise here - let the app start and handle errors in endpoints
|
75 |
|
76 |
@app.get("/")
|
77 |
def root():
|
78 |
+
return {
|
79 |
+
"message": "Bangla model API is running!",
|
80 |
+
"models_loaded": models_loaded,
|
81 |
+
"status": "healthy" if models_loaded else "models_not_loaded"
|
82 |
+
}
|
83 |
+
|
84 |
+
@app.get("/status")
|
85 |
+
def status():
|
86 |
+
"""Detailed status endpoint"""
|
87 |
+
return {
|
88 |
+
"models_loaded": models_loaded,
|
89 |
+
"model_available": model is not None,
|
90 |
+
"vectorizer_available": vectorizer is not None,
|
91 |
+
"label_encoder_available": label_encoder is not None,
|
92 |
+
"current_directory": os.getcwd(),
|
93 |
+
"available_files": [f for f in os.listdir('.') if f.endswith('.pkl')]
|
94 |
+
}
|
95 |
|
96 |
@app.post("/predict")
|
97 |
def predict(request: PredictRequest):
|
98 |
+
if not models_loaded:
|
99 |
+
raise HTTPException(
|
100 |
+
status_code=503,
|
101 |
+
detail="Models not loaded. Check /status endpoint for details."
|
102 |
+
)
|
103 |
+
|
104 |
+
if not request.text or not request.text.strip():
|
105 |
+
raise HTTPException(status_code=400, detail="Text cannot be empty")
|
106 |
+
|
107 |
+
try:
|
108 |
+
logger.info(f"Processing text: {request.text[:50]}...")
|
109 |
+
|
110 |
+
# Transform text
|
111 |
+
vect = vectorizer.transform([request.text])
|
112 |
+
logger.info(f"Vectorization successful, shape: {vect.shape}")
|
113 |
+
|
114 |
+
# Make prediction
|
115 |
+
pred = model.predict(vect)
|
116 |
+
logger.info(f"Prediction successful: {pred}")
|
117 |
+
|
118 |
+
# Transform label
|
119 |
+
label = label_encoder.inverse_transform(pred)
|
120 |
+
logger.info(f"Label transformation successful: {label[0]}")
|
121 |
+
|
122 |
+
return {"prediction": label[0]}
|
123 |
+
|
124 |
+
except Exception as e:
|
125 |
+
logger.error(f"Prediction error: {str(e)}")
|
126 |
+
raise HTTPException(
|
127 |
+
status_code=500,
|
128 |
+
detail=f"Prediction failed: {str(e)}"
|
129 |
+
)
|
130 |
+
|
131 |
+
# Add a manual model reload endpoint for debugging
|
132 |
+
@app.post("/reload-models")
|
133 |
+
def reload_models():
|
134 |
+
"""Manually reload models - useful for debugging"""
|
135 |
+
global model, vectorizer, label_encoder, models_loaded
|
136 |
+
|
137 |
+
try:
|
138 |
+
model = load_model_safe("bangla_model.pkl")
|
139 |
+
vectorizer = load_model_safe("bangla_vectorizer.pkl")
|
140 |
+
label_encoder = load_model_safe("bangla_label_encoder.pkl")
|
141 |
+
models_loaded = True
|
142 |
+
return {"message": "Models reloaded successfully"}
|
143 |
+
except Exception as e:
|
144 |
+
models_loaded = False
|
145 |
+
raise HTTPException(status_code=500, detail=f"Failed to reload models: {str(e)}")
|