Jakaria committed
Commit ad7a9e2 · 1 Parent(s): cdd0b85

Add Bangla model API

Files changed (1):
  1. app.py +131 -10
app.py CHANGED
@@ -1,24 +1,145 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import joblib
+import os
+import pickle
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
 class PredictRequest(BaseModel):
     text: str
 
-# Load models
-model = joblib.load("bangla_model.pkl")
-vectorizer = joblib.load("bangla_vectorizer.pkl")
-label_encoder = joblib.load("bangla_label_encoder.pkl")
+# Global model variables
+model = None
+vectorizer = None
+label_encoder = None
+models_loaded = False
+
+def load_model_safe(filename):
+    """Safely load model with multiple methods"""
+    if not os.path.exists(filename):
+        raise FileNotFoundError(f"{filename} not found")
+
+    # Try joblib first
+    try:
+        return joblib.load(filename)
+    except Exception as e1:
+        logger.warning(f"Joblib failed for {filename}: {e1}")
+
+    # Try pickle as fallback
+    try:
+        with open(filename, 'rb') as f:
+            return pickle.load(f)
+    except Exception as e2:
+        logger.error(f"Pickle also failed for {filename}: {e2}")
+        raise e1  # Raise original joblib error
+
+@app.on_event("startup")
+async def startup_event():
+    global model, vectorizer, label_encoder, models_loaded
+
+    try:
+        logger.info("Starting model loading...")
+
+        # Load each model individually with error handling
+        logger.info("Loading bangla_model.pkl...")
+        model = load_model_safe("bangla_model.pkl")
+        logger.info(f"Model type: {type(model)}")
+
+        logger.info("Loading bangla_vectorizer.pkl...")
+        vectorizer = load_model_safe("bangla_vectorizer.pkl")
+        logger.info(f"Vectorizer type: {type(vectorizer)}")
+
+        logger.info("Loading bangla_label_encoder.pkl...")
+        label_encoder = load_model_safe("bangla_label_encoder.pkl")
+        logger.info(f"Label encoder type: {type(label_encoder)}")
+
+        # Test pipeline with dummy data
+        logger.info("Testing pipeline...")
+        test_vect = vectorizer.transform(["test"])
+        test_pred = model.predict(test_vect)
+        test_label = label_encoder.inverse_transform(test_pred)
+        logger.info(f"Pipeline test successful: {test_label[0]}")
+
+        models_loaded = True
+        logger.info("All models loaded successfully!")
+
+    except Exception as e:
+        logger.error(f"Failed to load models: {str(e)}")
+        models_loaded = False
+        # Don't raise here - let the app start and handle errors in endpoints
 
 @app.get("/")
 def root():
-    return {"message": "Bangla model API is running!"}
+    return {
+        "message": "Bangla model API is running!",
+        "models_loaded": models_loaded,
+        "status": "healthy" if models_loaded else "models_not_loaded"
+    }
+
+@app.get("/status")
+def status():
+    """Detailed status endpoint"""
+    return {
+        "models_loaded": models_loaded,
+        "model_available": model is not None,
+        "vectorizer_available": vectorizer is not None,
+        "label_encoder_available": label_encoder is not None,
+        "current_directory": os.getcwd(),
+        "available_files": [f for f in os.listdir('.') if f.endswith('.pkl')]
+    }
 
 @app.post("/predict")
 def predict(request: PredictRequest):
-    vect = vectorizer.transform([request.text])
-    pred = model.predict(vect)
-    label = label_encoder.inverse_transform(pred)
-    return {"prediction": label[0]}
+    if not models_loaded:
+        raise HTTPException(
+            status_code=503,
+            detail="Models not loaded. Check /status endpoint for details."
+        )
+
+    if not request.text or not request.text.strip():
+        raise HTTPException(status_code=400, detail="Text cannot be empty")
+
+    try:
+        logger.info(f"Processing text: {request.text[:50]}...")
+
+        # Transform text
+        vect = vectorizer.transform([request.text])
+        logger.info(f"Vectorization successful, shape: {vect.shape}")
+
+        # Make prediction
+        pred = model.predict(vect)
+        logger.info(f"Prediction successful: {pred}")
+
+        # Transform label
+        label = label_encoder.inverse_transform(pred)
+        logger.info(f"Label transformation successful: {label[0]}")
+
+        return {"prediction": label[0]}
+
+    except Exception as e:
+        logger.error(f"Prediction error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Prediction failed: {str(e)}"
+        )
+
+# Add a manual model reload endpoint for debugging
+@app.post("/reload-models")
+def reload_models():
+    """Manually reload models - useful for debugging"""
+    global model, vectorizer, label_encoder, models_loaded
+
+    try:
+        model = load_model_safe("bangla_model.pkl")
+        vectorizer = load_model_safe("bangla_vectorizer.pkl")
+        label_encoder = load_model_safe("bangla_label_encoder.pkl")
+        models_loaded = True
+        return {"message": "Models reloaded successfully"}
+    except Exception as e:
+        models_loaded = False
+        raise HTTPException(status_code=500, detail=f"Failed to reload models: {str(e)}")
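
The commit only loads bangla_model.pkl, bangla_vectorizer.pkl, and bangla_label_encoder.pkl; the code that produces them is not part of this repository. For orientation only, a minimal sketch of how three such artifacts are commonly created with scikit-learn, assuming a TF-IDF vectorizer, a logistic-regression classifier, a LabelEncoder, and a hypothetical labeled CSV (file path and column names are illustrative, not taken from this commit):

# train_artifacts.py - hypothetical training script, not part of this commit
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

# Assumed dataset layout: a CSV with "text" and "label" columns
df = pd.read_csv("bangla_dataset.csv")

# Encode string labels to integers
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(df["label"])

# Fit the text vectorizer and the classifier
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(df["text"])
model = LogisticRegression(max_iter=1000)
model.fit(X, y)

# Persist the three artifacts under the filenames app.py expects
joblib.dump(model, "bangla_model.pkl")
joblib.dump(vectorizer, "bangla_vectorizer.pkl")
joblib.dump(label_encoder, "bangla_label_encoder.pkl")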
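
With the three .pkl files next to app.py and the server started (for example, uvicorn app:app --port 8000), the endpoints can be exercised from a small client. A sketch using the requests library, assuming the default local address; the Bangla sentence is an arbitrary example input:

# client_example.py - hypothetical client, not part of this commit
import requests

BASE_URL = "http://127.0.0.1:8000"  # assumed local uvicorn address

# Health check and detailed artifact status
print(requests.get(f"{BASE_URL}/").json())
print(requests.get(f"{BASE_URL}/status").json())

# Classification request; the body must match PredictRequest (a single "text" field)
resp = requests.post(f"{BASE_URL}/predict", json={"text": "আজকের আবহাওয়া খুব সুন্দর"})
if resp.ok:
    print(resp.json())  # e.g. {"prediction": "<some label>"}
else:
    print(resp.status_code, resp.json().get("detail"))

If loading fails at startup, / and /status report models_loaded: false, /predict answers with a 503, and POST /reload-models retries the load once the artifact files are fixed.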