File size: 3,424 Bytes
f349339
 
5416df1
f349339
 
d6bc2ab
8c53a41
 
 
 
 
 
 
 
263b291
5416df1
206327a
d6bc2ab
f349339
 
5416df1
f349339
511693e
8c53a41
 
 
 
 
 
 
56271c1
511a161
6b17ed9
 
69c141d
8c53a41
 
206327a
8c53a41
 
206327a
 
 
bfd5973
206327a
 
bfd5973
8c53a41
 
 
 
 
 
56271c1
 
8c53a41
 
56271c1
 
 
bfd5973
 
6b17ed9
 
 
511a161
 
955fee0
 
 
f349339
955fee0
 
f349339
6b17ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from fastapi import FastAPI, Query, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
import joblib
import pandas as pd
import numpy as np
from pydantic import BaseModel
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectKBest
from sklearn.ensemble import BaggingClassifier
from xgboost import XGBClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import f_classif
from sklearn.impute import SimpleImputer 



# ASGI application object: the endpoints below register on it, and the
# __main__ launcher points uvicorn at it ("<module>:app").
app: FastAPI = FastAPI()

# Jinja2 renderer for HTML responses. "templates" is a relative path
# (presumably resolved against the process working directory — confirm).
# NOTE(review): in the visible code it is only used by commented-out drafts.
templates: Jinja2Templates = Jinja2Templates(directory="templates")

class InputFeatures(BaseModel):
    """Request-body schema for one patient record posted to /sepsis_prediction.

    Pydantic validates the incoming JSON against these annotations, coercing
    each value to the declared type and rejecting input where a field is
    missing or cannot be converted.

    NOTE(review): the field names presumably must match the feature column
    names the saved model was trained on — keep names (and declaration order)
    in sync with the training data.
    """
    # Units and clinical meaning of these features are not documented here.
    Plasma_glucose: float
    Blood_Work_Result_1: float
    Blood_Pressure: float
    Blood_Work_Result_2: float
    Blood_Work_Result_3: float
    Body_mass_index: float
    Blood_Work_Result_4: float
    patients_age: int  # the only integer feature; all others are floats

# Serialized model artifact produced at training time. Relative path, so it is
# presumably resolved against the working directory the server starts in.
_MODEL_FILE = "model_1.joblib"

# Deserialize once at import time so every request reuses the same object.
# NOTE(security): joblib.load unpickles arbitrary Python objects — only ever
# point this at a trusted, locally produced file, never at user-supplied data.
model_input = joblib.load(_MODEL_FILE)

@app.post("/sepsis_prediction")
async def predict(input: InputFeatures):
    """Predict sepsis status for a single patient record.

    The validated request body is turned into a one-row DataFrame whose
    columns are the ``InputFeatures`` field names, and that frame is passed to
    the model loaded at import time (``model_input``).

    Args:
        input: Validated patient features from the JSON request body.

    Returns:
        ``{"prediction": label}`` where ``label`` is the model's prediction for
        the record, converted to a native (JSON-serializable) Python value.
    """
    # NOTE(review): assumes model_1.joblib holds the *fitted* training pipeline
    # (imputation, scaling, feature selection and classifier), so it can consume
    # the raw feature columns directly — confirm against the training code.
    # Preprocessing must NOT be rebuilt and re-fitted here: fitting a
    # StandardScaler on a single row maps every feature to 0, throwing away the
    # patient's actual values.
    features = input.model_dump() if hasattr(input, "model_dump") else input.dict()
    df = pd.DataFrame([features])

    prediction = model_input.predict(df)

    # numpy arrays/scalars are not JSON-serializable; tolist() yields native
    # Python values that FastAPI can encode.
    return {"prediction": np.asarray(prediction).tolist()[0]}

if __name__ == '__main__':
    # Imported here because they are only needed when the file is run as a
    # script; the original referenced uvicorn without importing it (NameError).
    from pathlib import Path

    import uvicorn

    # reload=True requires a "module:attribute" import string rather than the
    # app object. Derive the module name from this file instead of hard-coding
    # "Main", so the launcher works whatever the file is named.
    uvicorn.run(f"{Path(__file__).stem}:app", reload=True)





# @app.get("/")
# async def read_root():
#     return {"message": "Welcome to the Sepsis Prediction API"}

# @app.get("/form/")
# async def show_form():

# @app.post("/predict/")
# async def predict_sepsis(
#     request: Request,
#     input_data: InputFeatures  # Use the Pydantic model for input validation
# ):
#     try:
#         # Convert Pydantic model to a DataFrame for prediction
#         input_df = pd.DataFrame([input_data.dict()])

#         # Make predictions using the loaded XGBoost model
#         prediction = xgb_model.predict_proba(xgb.DMatrix(input_df))

#         # Create a JSON response
#         response = {
#             "input_features": input_data,
#             "prediction": {
#                 "class_0_probability": prediction[0],
#                 "class_1_probability": prediction[1]
#             }
#         }

#         return templates.TemplateResponse(
#             "display_params.html",
#             {
#                 "request": request,
#                 "input_features": response["input_features"],
#                 "prediction": response["prediction"]
#             }
#         )
#     except Exception as e:
#         #raise HTTPException(status_code=500, detail="An error occurred while processing the request.")
#         raise HTTPException(status_code=500, detail=str(e))