Jakaria commited on
Commit
298ba53
·
1 Parent(s): ad7a9e2

Add Bangla model API

Browse files
Files changed (1) hide show
  1. app.py +117 -44
app.py CHANGED
@@ -2,8 +2,9 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import joblib
4
  import os
5
- import pickle
6
  import logging
 
 
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
@@ -24,46 +25,32 @@ def load_model_safe(filename):
24
  if not os.path.exists(filename):
25
  raise FileNotFoundError(f"{filename} not found")
26
 
27
- # Try joblib first
28
  try:
29
  return joblib.load(filename)
30
  except Exception as e1:
31
  logger.warning(f"Joblib failed for {filename}: {e1}")
32
 
33
- # Try pickle as fallback
34
  try:
35
  with open(filename, 'rb') as f:
36
  return pickle.load(f)
37
  except Exception as e2:
38
  logger.error(f"Pickle also failed for {filename}: {e2}")
39
- raise e1 # Raise original joblib error
40
 
41
  @app.on_event("startup")
42
  async def startup_event():
43
  global model, vectorizer, label_encoder, models_loaded
44
 
45
  try:
46
- logger.info("Starting model loading...")
47
-
48
- # Load each model individually with error handling
49
- logger.info("Loading bangla_model.pkl...")
50
  model = load_model_safe("bangla_model.pkl")
51
- logger.info(f"Model type: {type(model)}")
52
-
53
- logger.info("Loading bangla_vectorizer.pkl...")
54
  vectorizer = load_model_safe("bangla_vectorizer.pkl")
55
- logger.info(f"Vectorizer type: {type(vectorizer)}")
56
-
57
- logger.info("Loading bangla_label_encoder.pkl...")
58
  label_encoder = load_model_safe("bangla_label_encoder.pkl")
59
- logger.info(f"Label encoder type: {type(label_encoder)}")
60
 
61
- # Test pipeline with dummy data
62
- logger.info("Testing pipeline...")
63
  test_vect = vectorizer.transform(["test"])
64
  test_pred = model.predict(test_vect)
65
  test_label = label_encoder.inverse_transform(test_pred)
66
- logger.info(f"Pipeline test successful: {test_label[0]}")
67
 
68
  models_loaded = True
69
  logger.info("All models loaded successfully!")
@@ -71,7 +58,6 @@ async def startup_event():
71
  except Exception as e:
72
  logger.error(f"Failed to load models: {str(e)}")
73
  models_loaded = False
74
- # Don't raise here - let the app start and handle errors in endpoints
75
 
76
  @app.get("/")
77
  def root():
@@ -83,7 +69,6 @@ def root():
83
 
84
  @app.get("/status")
85
  def status():
86
- """Detailed status endpoint"""
87
  return {
88
  "models_loaded": models_loaded,
89
  "model_available": model is not None,
@@ -93,45 +78,133 @@ def status():
93
  "available_files": [f for f in os.listdir('.') if f.endswith('.pkl')]
94
  }
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  @app.post("/predict")
97
  def predict(request: PredictRequest):
 
98
  if not models_loaded:
99
- raise HTTPException(
100
- status_code=503,
101
- detail="Models not loaded. Check /status endpoint for details."
102
- )
103
-
104
- if not request.text or not request.text.strip():
105
- raise HTTPException(status_code=400, detail="Text cannot be empty")
106
 
107
  try:
108
- logger.info(f"Processing text: {request.text[:50]}...")
 
 
 
 
 
109
 
110
- # Transform text
111
- vect = vectorizer.transform([request.text])
112
- logger.info(f"Vectorization successful, shape: {vect.shape}")
 
 
 
 
113
 
114
- # Make prediction
115
- pred = model.predict(vect)
116
- logger.info(f"Prediction successful: {pred}")
 
 
 
 
117
 
118
- # Transform label
119
- label = label_encoder.inverse_transform(pred)
120
- logger.info(f"Label transformation successful: {label[0]}")
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  return {"prediction": label[0]}
123
 
 
 
124
  except Exception as e:
125
- logger.error(f"Prediction error: {str(e)}")
126
- raise HTTPException(
127
- status_code=500,
128
- detail=f"Prediction failed: {str(e)}"
129
- )
130
 
131
- # Add a manual model reload endpoint for debugging
132
  @app.post("/reload-models")
133
  def reload_models():
134
- """Manually reload models - useful for debugging"""
135
  global model, vectorizer, label_encoder, models_loaded
136
 
137
  try:
 
2
  from pydantic import BaseModel
3
  import joblib
4
  import os
 
5
  import logging
6
+ import numpy as np
7
+ import traceback
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
 
25
  if not os.path.exists(filename):
26
  raise FileNotFoundError(f"{filename} not found")
27
 
 
28
  try:
29
  return joblib.load(filename)
30
  except Exception as e1:
31
  logger.warning(f"Joblib failed for {filename}: {e1}")
32
 
 
33
  try:
34
  with open(filename, 'rb') as f:
35
  return pickle.load(f)
36
  except Exception as e2:
37
  logger.error(f"Pickle also failed for {filename}: {e2}")
38
+ raise e1
39
 
40
  @app.on_event("startup")
41
  async def startup_event():
42
  global model, vectorizer, label_encoder, models_loaded
43
 
44
  try:
45
+ logger.info("Loading models...")
 
 
 
46
  model = load_model_safe("bangla_model.pkl")
 
 
 
47
  vectorizer = load_model_safe("bangla_vectorizer.pkl")
 
 
 
48
  label_encoder = load_model_safe("bangla_label_encoder.pkl")
 
49
 
50
+ # Test pipeline
 
51
  test_vect = vectorizer.transform(["test"])
52
  test_pred = model.predict(test_vect)
53
  test_label = label_encoder.inverse_transform(test_pred)
 
54
 
55
  models_loaded = True
56
  logger.info("All models loaded successfully!")
 
58
  except Exception as e:
59
  logger.error(f"Failed to load models: {str(e)}")
60
  models_loaded = False
 
61
 
62
  @app.get("/")
63
  def root():
 
69
 
70
  @app.get("/status")
71
  def status():
 
72
  return {
73
  "models_loaded": models_loaded,
74
  "model_available": model is not None,
 
78
  "available_files": [f for f in os.listdir('.') if f.endswith('.pkl')]
79
  }
80
 
81
+ @app.post("/debug-predict")
82
+ def debug_predict(request: PredictRequest):
83
+ """Debug version of predict with detailed logging"""
84
+ if not models_loaded:
85
+ raise HTTPException(status_code=503, detail="Models not loaded")
86
+
87
+ debug_info = {"steps": []}
88
+
89
+ try:
90
+ # Step 1: Input validation
91
+ debug_info["steps"].append("1. Input validation")
92
+ if not request.text or not request.text.strip():
93
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
94
+
95
+ debug_info["input_text"] = request.text
96
+ debug_info["input_length"] = len(request.text)
97
+
98
+ # Step 2: Text preprocessing
99
+ debug_info["steps"].append("2. Text preprocessing")
100
+ text_to_process = request.text.strip()
101
+ debug_info["processed_text_length"] = len(text_to_process)
102
+
103
+ # Step 3: Vectorization
104
+ debug_info["steps"].append("3. Vectorization")
105
+ try:
106
+ vect = vectorizer.transform([text_to_process])
107
+ debug_info["vectorized_shape"] = vect.shape
108
+ debug_info["vectorized_nnz"] = vect.nnz
109
+ debug_info["vectorized_dtype"] = str(vect.dtype)
110
+ except Exception as e:
111
+ debug_info["vectorization_error"] = str(e)
112
+ raise HTTPException(status_code=500, detail=f"Vectorization failed: {str(e)}")
113
+
114
+ # Step 4: Model prediction
115
+ debug_info["steps"].append("4. Model prediction")
116
+ try:
117
+ pred = model.predict(vect)
118
+ debug_info["raw_prediction"] = pred.tolist() if hasattr(pred, 'tolist') else str(pred)
119
+ debug_info["prediction_type"] = str(type(pred))
120
+ debug_info["prediction_shape"] = pred.shape if hasattr(pred, 'shape') else "no shape"
121
+ except Exception as e:
122
+ debug_info["prediction_error"] = str(e)
123
+ raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")
124
+
125
+ # Step 5: Label transformation
126
+ debug_info["steps"].append("5. Label transformation")
127
+ try:
128
+ # Check if prediction is in valid range
129
+ if hasattr(label_encoder, 'classes_'):
130
+ debug_info["available_classes"] = label_encoder.classes_.tolist()
131
+ debug_info["num_classes"] = len(label_encoder.classes_)
132
+
133
+ label = label_encoder.inverse_transform(pred)
134
+ debug_info["final_label"] = label[0] if len(label) > 0 else "no label"
135
+ debug_info["label_type"] = str(type(label[0])) if len(label) > 0 else "no label"
136
+ except Exception as e:
137
+ debug_info["label_transform_error"] = str(e)
138
+ raise HTTPException(status_code=500, detail=f"Label transformation failed: {str(e)}")
139
+
140
+ debug_info["steps"].append("6. Success!")
141
+ debug_info["final_prediction"] = label[0]
142
+
143
+ return debug_info
144
+
145
+ except HTTPException:
146
+ raise
147
+ except Exception as e:
148
+ debug_info["unexpected_error"] = str(e)
149
+ debug_info["traceback"] = traceback.format_exc()
150
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
151
+
152
  @app.post("/predict")
153
  def predict(request: PredictRequest):
154
+ """Production predict endpoint with better error handling"""
155
  if not models_loaded:
156
+ raise HTTPException(status_code=503, detail="Models not loaded")
 
 
 
 
 
 
157
 
158
  try:
159
+ # Input validation
160
+ if not request.text or not request.text.strip():
161
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
162
+
163
+ text_to_process = request.text.strip()
164
+ logger.info(f"Processing text of length: {len(text_to_process)}")
165
 
166
+ # Vectorization with error handling
167
+ try:
168
+ vect = vectorizer.transform([text_to_process])
169
+ logger.info(f"Vectorization successful: shape={vect.shape}, nnz={vect.nnz}")
170
+ except Exception as e:
171
+ logger.error(f"Vectorization error: {str(e)}")
172
+ raise HTTPException(status_code=500, detail="Text vectorization failed")
173
 
174
+ # Prediction with error handling
175
+ try:
176
+ pred = model.predict(vect)
177
+ logger.info(f"Prediction successful: {pred}")
178
+ except Exception as e:
179
+ logger.error(f"Model prediction error: {str(e)}")
180
+ raise HTTPException(status_code=500, detail="Model prediction failed")
181
 
182
+ # Label transformation with error handling
183
+ try:
184
+ # Validate prediction is in expected range
185
+ if hasattr(label_encoder, 'classes_'):
186
+ max_class = len(label_encoder.classes_) - 1
187
+ if np.any(pred < 0) or np.any(pred > max_class):
188
+ logger.error(f"Prediction {pred} out of range [0, {max_class}]")
189
+ raise ValueError(f"Prediction out of range")
190
+
191
+ label = label_encoder.inverse_transform(pred)
192
+ logger.info(f"Label transformation successful: {label[0]}")
193
+ except Exception as e:
194
+ logger.error(f"Label transformation error: {str(e)}")
195
+ raise HTTPException(status_code=500, detail="Label transformation failed")
196
 
197
  return {"prediction": label[0]}
198
 
199
+ except HTTPException:
200
+ raise
201
  except Exception as e:
202
+ logger.error(f"Unexpected error in predict: {str(e)}")
203
+ logger.error(traceback.format_exc())
204
+ raise HTTPException(status_code=500, detail="Internal server error")
 
 
205
 
 
206
  @app.post("/reload-models")
207
  def reload_models():
 
208
  global model, vectorizer, label_encoder, models_loaded
209
 
210
  try: