File size: 3,017 Bytes
f4bb7d1
15a293c
f4bb7d1
 
 
 
15a293c
 
f4bb7d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15a293c
f4bb7d1
 
 
 
15a293c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4bb7d1
 
15a293c
 
 
 
 
 
 
 
 
 
 
 
 
f4bb7d1
15a293c
f4bb7d1
d4fa0c5
15a293c
 
 
 
d4fa0c5
f4bb7d1
 
 
 
15a293c
f4bb7d1
 
 
 
 
15a293c
f4bb7d1
 
 
 
 
 
 
15a293c
f4bb7d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from enum import Enum
from functools import lru_cache
from typing import List, Union

import joblib
import pandas as pd
import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

# Markdown description of the API, passed to the FastAPI app as
# `description=` and therefore shown in the generated OpenAPI docs.
description = """
Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️

## Machine Learning

This section includes a Machine Learning endpoint that predicts car values based on various features. Here is the endpoint:

* `/predict`: **POST** request that accepts a list of car features and returns a predicted car value.

Check out the documentation below 👇 for more information on each endpoint. 
"""

# OpenAPI tag metadata handed to FastAPI(openapi_tags=...). The tag name
# matches the `tags=["Machine Learning"]` used on the /predict route.
tags_metadata = [
    dict(
        name="Machine Learning",
        description="Endpoint for predicting car values based on provided features.",
    ),
]

# The ASGI application. Title, version, description, tag metadata and
# contact details all feed the auto-generated OpenAPI documentation.
app = FastAPI(
    title="🚗 GetAround Car Value Prediction API",
    version="0.1",
    description=description,
    openapi_tags=tags_metadata,
    # NOTE(review): "[email protected]" looks like an e-mail obfuscation /
    # page-scraping artifact rather than a real address — confirm and
    # replace with the maintainer's actual e-mail.
    contact=dict(
        name="Antoine VERDON",
        email="[email protected]",
    ),
)

class CarBrand(str, Enum):
    """Closed set of car brands accepted by the ``/predict`` endpoint.

    Mixing in ``str`` makes each member also a string equal to its value,
    so request bodies carry the value text (e.g. ``"Citroën"``). Anything
    outside this list is rejected by request validation.
    """
    # NOTE(review): values presumably must match the brand labels the model
    # was trained on (including the accented "Citroën" and lowercase
    # "other"); do not normalise them without retraining — confirm.
    citroen = "Citroën"
    peugeot = "Peugeot"
    pgo = "PGO"
    renault = "Renault"
    audi = "Audi"
    bmw = "BMW"
    # Presumably the catch-all bucket for brands not listed individually.
    other = "other"
    mercedes = "Mercedes"
    opel = "Opel"
    volkswagen = "Volkswagen"
    ferrari = "Ferrari"
    maserati = "Maserati"
    mitsubishi = "Mitsubishi"
    nissan = "Nissan"
    seat = "SEAT"
    subaru = "Subaru"
    toyota = "Toyota"

class PredictionFeatures(BaseModel):
    """Request body for ``/predict``: the features of a single car.

    The ``/predict`` handler reads these attributes by name to build a
    one-row DataFrame, so field names must stay in sync with the column
    names the model expects.
    """
    brand: CarBrand  # validated against the CarBrand enum
    mileage: int  # NOTE(review): unit and valid range not enforced here — confirm
    engine_power: int  # NOTE(review): unit not stated here — confirm
    fuel: str  # free text; not validated against known categories
    paint_color: str  # free text; not validated against known categories
    car_type: str  # free text; not validated against known categories
    private_parking_available: bool
    has_gps: bool
    has_air_conditioning: bool
    automatic_car: bool
    has_getaround_connect: bool
    has_speed_regulator: bool
    winter_tires: bool

@app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
async def index():
    """Return a short HTML greeting pointing to the local and hosted API docs."""
    greeting = (
        "Hello world! This `/` is the most simple and default endpoint. "
        "If you want to learn more, check out documentation of the API at "
    )
    docs_links = (
        "<a href='/docs'>/docs</a> or "
        "<a href='https://2nzi-getaroundapi.hf.space/docs' target='_blank'>external docs</a>."
    )
    return greeting + docs_links

# Serialized model artifact, resolved relative to the process working
# directory (same path the endpoint always used).
_MODEL_PATH = 'best_model_XGBoost.pkl'

# Column order of the one-row DataFrame handed to the model; every name is
# also an attribute of PredictionFeatures.
_FEATURE_COLUMNS = (
    'brand', 'mileage', 'engine_power', 'fuel', 'paint_color',
    'car_type', 'private_parking_available', 'has_gps',
    'has_air_conditioning', 'automatic_car', 'has_getaround_connect',
    'has_speed_regulator', 'winter_tires',
)


@lru_cache(maxsize=1)
def _load_model():
    """Load the model from ``_MODEL_PATH`` once and reuse it across requests.

    Unpickling the model is comparatively expensive, so it is done on first
    use instead of on every request. Replacing the file on disk therefore
    requires a restart (or ``_load_model.cache_clear()``) to take effect.
    If loading raises, nothing is cached and the next call retries.

    ``joblib.load`` unpickles the file, so it must be a trusted artifact;
    never derive the path from request data.
    """
    return joblib.load(_MODEL_PATH)


@app.post("/predict", tags=["Machine Learning"])
async def predict(predictionFeatures: PredictionFeatures):
    """Predict a car value from one validated set of car features.

    Builds a one-row DataFrame (columns in ``_FEATURE_COLUMNS`` order), runs
    the cached model on it, and returns ``{"prediction": value}`` where
    ``value`` is the first element of ``model.predict(...)`` converted to a
    native Python type via ``tolist()``.
    """
    row = {}
    for col in _FEATURE_COLUMNS:
        value = getattr(predictionFeatures, col)
        # `brand` arrives as a CarBrand member; hand the model its plain
        # string value (presumably the form of the training labels).
        if isinstance(value, Enum):
            value = value.value
        row[col] = [value]
    car_data = pd.DataFrame(row)

    # NOTE(review): loading and prediction are synchronous and run on the
    # event loop inside this async handler; consider a plain `def` endpoint
    # (run in FastAPI's threadpool) if prediction latency becomes noticeable.
    model = _load_model()
    prediction = model.predict(car_data)
    return {"prediction": prediction.tolist()[0]}

if __name__ == "__main__":
    import os

    # Local entry point. The port can be overridden through the PORT
    # environment variable (as many hosting platforms set it); it defaults
    # to 4000, the previously hard-coded value. Binds on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "4000")))