# main.py | |
from fastapi import FastAPI, Query, Request, HTTPException | |
from fastapi.responses import JSONResponse, HTMLResponse | |
from fastapi.templating import Jinja2Templates | |
#import xgboost as xgb | |
import joblib | |
import pandas as pd | |
from pydantic import BaseModel # Import Pydantic's BaseModel | |
# Relative path of the directory Jinja2 loads templates from.
TEMPLATES_DIR = "templates"

# ASGI application object; path operations get registered on this instance.
app = FastAPI()
# Shared Jinja2 renderer for HTML responses.
templates = Jinja2Templates(directory=TEMPLATES_DIR)
class InputFeatures(BaseModel):
    """Validated request payload: the feature values for one prediction.

    Fields are kept in their original declaration order on purpose: when an
    instance is converted with ``pd.DataFrame([obj.dict()])`` that order
    becomes the DataFrame column order seen by the model.
    """

    # Seven float features. The abbreviated names presumably mirror the
    # training dataset's columns -- confirm against the training code.
    prg: float
    pl: float
    pr: float
    sk: float
    ts: float
    m11: float
    bd2: float
    # The only integer-typed feature.
    age: int
# Serialized estimator, resolved relative to the current working directory.
MODEL_PATH = "model_1.joblib"

# Deserialize the trained model once at import time so every request reuses it.
# NOTE(review): presumably an XGBoost-based classifier/pipeline exposing
# predict_proba -- confirm against the training script.
# joblib.load unpickles arbitrary objects; only ever point it at trusted files.
model_input = joblib.load(MODEL_PATH)
# NOTE: not registered as a route; decorate with e.g. @app.post("/predict/")
# to expose it over HTTP.
async def predicts(input: InputFeatures):
    """Score one feature vector with the model loaded at import time.

    Args:
        input: Validated request features. The name shadows the ``input``
            builtin but is kept so existing keyword callers keep working.

    Returns:
        dict with ``class_0_probability`` and ``class_1_probability`` as plain
        Python floats (numpy scalars are not reliably JSON-serializable).

    Raises:
        HTTPException: status 500 when the model fails to score the input.
    """
    # Pydantic v2 renamed .dict() to .model_dump(); accept either version.
    to_dict = getattr(input, "model_dump", None) or input.dict
    # One-row frame whose columns follow InputFeatures' field order.
    # NOTE(review): column names must match those the model was fitted on -- confirm.
    features = pd.DataFrame([to_dict()])
    try:
        # scikit-learn-style predict_proba returns (n_samples, n_classes);
        # take the single row. Assumes a binary classifier.
        probabilities = model_input.predict_proba(features)[0]
    except Exception as exc:
        # Boundary handler: surface the failure to the client, keep the cause.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {
        "class_0_probability": float(probabilities[0]),
        "class_1_probability": float(probabilities[1]),
    }
# @app.get("/") | |
# async def read_root(): | |
# return {"message": "Welcome to the Sepsis Prediction API"} | |
# @app.get("/form/") | |
# async def show_form(): | |
# @app.post("/predict/") | |
# async def predict_sepsis( | |
# request: Request, | |
# input_data: InputFeatures # Use the Pydantic model for input validation | |
# ): | |
# try: | |
# # Convert Pydantic model to a DataFrame for prediction | |
# input_df = pd.DataFrame([input_data.dict()]) | |
# # Make predictions using the loaded XGBoost model | |
# prediction = xgb_model.predict_proba(xgb.DMatrix(input_df)) | |
# # Create a JSON response | |
# response = { | |
# "input_features": input_data, | |
# "prediction": { | |
# "class_0_probability": prediction[0], | |
# "class_1_probability": prediction[1] | |
# } | |
# } | |
# return templates.TemplateResponse( | |
# "display_params.html", | |
# { | |
# "request": request, | |
# "input_features": response["input_features"], | |
# "prediction": response["prediction"] | |
# } | |
# ) | |
# except Exception as e: | |
# #raise HTTPException(status_code=500, detail="An error occurred while processing the request.") | |
# raise HTTPException(status_code=500, detail=str(e)) | |