from fastapi import FastAPI, Query, Request, HTTPException | |
from fastapi.responses import JSONResponse, HTMLResponse | |
from fastapi.templating import Jinja2Templates | |
import joblib | |
import pandas as pd | |
import numpy as np | |
from pydantic import BaseModel | |
from sklearn.pipeline import Pipeline | |
from sklearn.feature_selection import SelectKBest | |
from sklearn.ensemble import BaggingClassifier | |
from xgboost import XGBClassifier | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.compose import ColumnTransformer | |
from sklearn.feature_selection import f_classif | |
from sklearn.impute import SimpleImputer | |
# ASGI application object for this service.
app = FastAPI()

# Jinja2 template loader that looks templates up in the "templates" directory.
templates = Jinja2Templates(directory="templates")
class InputFeatures(BaseModel):
    """Request payload accepted by ``predict``.

    Eight required fields (none has a default): seven floats and one
    integer. The declaration order below is the original order and is kept
    unchanged.
    """

    # Measurements declared as floats.
    Plasma_glucose: float
    Blood_Work_Result_1: float
    Blood_Pressure: float
    Blood_Work_Result_2: float
    Blood_Work_Result_3: float
    Body_mass_index: float
    Blood_Work_Result_4: float

    # The only integer-typed field.
    patients_age: int
# Deserialize the trained model once, at import time. The path is relative,
# so it is resolved against the working directory of the running process.
# NOTE(review): joblib artifacts are pickle-based; only load trusted files.
model_input = joblib.load("model_1.joblib")
@app.post("/predict/")
async def predict(input: InputFeatures):
    """Return the loaded model's prediction for one set of input features.

    Args:
        input: Validated request body. The name shadows the builtin but is
            kept so the handler's signature stays unchanged.

    Returns:
        The model's prediction converted to a plain Python list (one entry
        per input row, i.e. a single entry here) so that it can be
        serialized to JSON.

    Notes:
        The previous implementation built a fresh imputer/scaler and fitted
        it on the single request row. Standardizing one row against its own
        mean maps every feature to 0, so the model received the same input
        for every request. Preprocessing statistics have to come from
        training, not from the request, so the raw feature frame is handed
        to the loaded artifact directly.
        NOTE(review): this assumes ``model_1.joblib`` is the fitted pipeline
        including its preprocessing steps - confirm against the training
        code.
    """
    # Flat list of column names, in the order the features are declared.
    feature_columns = [
        'Plasma_glucose', 'Blood_Work_Result_1', 'Blood_Pressure',
        'Blood_Work_Result_2', 'Blood_Work_Result_3', 'Body_mass_index',
        'Blood_Work_Result_4', 'patients_age',
    ]

    # Build the frame from the model's dict form so that columns carry the
    # feature names; passing the pydantic object itself does not do that.
    df = pd.DataFrame([input.dict()], columns=feature_columns)

    prediction = model_input.predict(df)

    # NumPy arrays are not JSON-serializable; return native Python values.
    return np.asarray(prediction).tolist()
if __name__ == '__main__':
    # uvicorn is only needed when this file is executed as a script, and it
    # was never imported at module level, so import it here to avoid a
    # NameError on startup.
    import uvicorn

    # The import string "Main:app" names the module and the application
    # object to serve; reload=True restarts the server on code changes.
    uvicorn.run("Main:app", reload=True)
# @app.get("/") | |
# async def read_root(): | |
# return {"message": "Welcome to the Sepsis Prediction API."} | |
# @app.get("/form/") | |
# async def show_form(): | |
# @app.post("/predict/") | |
# async def predict_sepsis( | |
# request: Request, | |
# input_data: InputFeatures # Use the Pydantic model for input validation | |
# ): | |
# try: | |
# # Convert Pydantic model to a DataFrame for prediction | |
# input_df = pd.DataFrame([input_data.dict()]) | |
# # Make predictions using the loaded XGBoost model | |
# prediction = xgb_model.predict_proba(xgb.DMatrix(input_df)) | |
# # Create a JSON response | |
# response = { | |
# "input_features": input_data, | |
# "prediction": { | |
# "class_0_probability": prediction[0], | |
# "class_1_probability": prediction[1] | |
# } | |
# } | |
# return templates.TemplateResponse( | |
# "display_params.html", | |
# { | |
# "request": request, | |
# "input_features": response["input_features"], | |
# "prediction": response["prediction"] | |
# } | |
# ) | |
# except Exception as e: | |
# #raise HTTPException(status_code=500, detail="An error occurred while processing the request.") | |
# raise HTTPException(status_code=500, detail=str(e)) | |