Spaces:
Running
Running
File size: 3,882 Bytes
b0dd811 15a293c b0dd811 15a293c f4bb7d1 b0dd811 f4bb7d1 b0dd811 f4bb7d1 b0dd811 f4bb7d1 b0dd811 f4bb7d1 b0dd811 f4bb7d1 b0dd811 f4bb7d1 b0dd811 967b935 b0dd811 f4bb7d1 d4fa0c5 b0dd811 d4fa0c5 f4bb7d1 b0dd811 f4bb7d1 b0dd811 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import uvicorn
import pandas as pd
from typing import Union
from fastapi import FastAPI, Query
import joblib
from enum import Enum
from fastapi.responses import HTMLResponse
# Markdown body passed to FastAPI as the API description (see `app` below).
description = """
Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️
## Machine Learning
This section includes a Machine Learning endpoint that predicts car values based on various features. Here is the endpoint:
* `/predict`: **POST** request that accepts a list of car features and returns a predicted car value.
Check out the documentation below 👇 for more information on each endpoint.
"""

# Tag definitions handed to FastAPI through `openapi_tags`; the name matches
# the tag used by the `/predict` route.
tags_metadata = [
    dict(
        name="Machine Learning",
        description="Endpoint for predicting car values based on provided features.",
    ),
]
# The ASGI application. Title, description, version, contact and tag metadata
# are all forwarded to FastAPI's constructor.
app = FastAPI(
    title="🚗 GetAround Car Value Prediction API",
    description=description,
    version="0.1",
    contact=dict(
        name="Antoine VERDON",
        email="antoineverdon.pro@gmail.com",
    ),
    openapi_tags=tags_metadata,
)
# Car makes accepted for the `brand` parameter of `/predict`; the selected
# value is what ends up in the `model_key` column given to the model.
# Mixing in `str` makes every member compare equal to its plain string value.
class CarBrand(str, Enum):
    citroen = "Citroën"
    peugeot = "Peugeot"
    pgo = "PGO"
    renault = "Renault"
    audi = "Audi"
    bmw = "BMW"
    # presumably a catch-all for makes not listed here — confirm against the
    # categories the model was trained on
    other = "other"
    mercedes = "Mercedes"
    opel = "Opel"
    volkswagen = "Volkswagen"
    ferrari = "Ferrari"
    maserati = "Maserati"
    mitsubishi = "Mitsubishi"
    nissan = "Nissan"
    seat = "SEAT"
    subaru = "Subaru"
    toyota = "Toyota"
# Fuel categories accepted for the `fuel` parameter of `/predict`.
# Each member's value is identical to its name.
class FuelType(str, Enum):
    diesel = "diesel"
    petrol = "petrol"
    hybrid_petrol = "hybrid_petrol"
    electro = "electro"
# Paint colours accepted for the `paint_color` parameter of `/predict`.
# Each member's value is identical to its name.
class PaintColor(str, Enum):
    black = "black"
    grey = "grey"
    white = "white"
    red = "red"
    silver = "silver"
    blue = "blue"
    orange = "orange"
    beige = "beige"
    brown = "brown"
    green = "green"
# Body styles accepted for the `car_type` parameter of `/predict`.
# Each member's value is identical to its name.
class CarType(str, Enum):
    convertible = "convertible"
    coupe = "coupe"
    estate = "estate"
    hatchback = "hatchback"
    sedan = "sedan"
    subcompact = "subcompact"
    suv = "suv"
    van = "van"
# Landing route: answers GET `/` with a short HTML greeting that links to the
# interactive docs, both on this server and on the hosted deployment.
# Documented with comments rather than a docstring because FastAPI publishes
# a handler's docstring as the endpoint description.
@app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
async def index():
    local_docs_link = "<a href='/docs'>/docs</a>"
    hosted_docs_link = (
        "<a href='https://2nzi-getaroundapi.hf.space/docs' target='_blank'>external docs</a>"
    )
    greeting = "Hello world! This `/` is the most simple and default endpoint. "
    pointer = "If you want to learn more, check out documentation of the API at "
    return f"{greeting}{pointer}{local_docs_link} or {hosted_docs_link}."
# Location of the serialized model, resolved against the process working directory.
_MODEL_PATH = 'best_model_XGBoost.pkl'

# Process-wide cache: the artifact is deserialized on the first request only,
# instead of being re-read from disk on every call.
_model_cache = {}


def _get_model():
    """Return the trained model, loading it from disk on first use."""
    if _MODEL_PATH not in _model_cache:
        # NOTE(security): joblib.load unpickles the file, and unpickling can
        # execute arbitrary code. Only deploy an artifact from a trusted source.
        _model_cache[_MODEL_PATH] = joblib.load(_MODEL_PATH)
    return _model_cache[_MODEL_PATH]


@app.post("/predict", tags=["Machine Learning"])
async def predict(
    brand: CarBrand = Query(...),
    mileage: int = Query(...),
    engine_power: int = Query(...),
    fuel: FuelType = Query(...),
    paint_color: PaintColor = Query(...),
    car_type: CarType = Query(...),
    private_parking_available: bool = Query(...),
    has_gps: bool = Query(...),
    has_air_conditioning: bool = Query(...),
    automatic_car: bool = Query(...),
    has_getaround_connect: bool = Query(...),
    has_speed_regulator: bool = Query(...),
    winter_tires: bool = Query(...)
):
    """
    Predict the value of a car from its features.

    Every feature is a required query parameter. The response is a JSON
    object of the form `{"prediction": <number>}`.
    """
    # One-row frame whose column names are the feature names handed to the
    # model. Categorical inputs are passed as plain strings (`.value`) rather
    # than Enum members, so the model receives e.g. "Audi" even if a
    # preprocessing step stringifies the cell (str() of a member is not
    # guaranteed to be its value).
    car_data_dict = {
        'model_key': [brand.value],
        'mileage': [mileage],
        'engine_power': [engine_power],
        'fuel': [fuel.value],
        'paint_color': [paint_color.value],
        'car_type': [car_type.value],
        'private_parking_available': [private_parking_available],
        'has_gps': [has_gps],
        'has_air_conditioning': [has_air_conditioning],
        'automatic_car': [automatic_car],
        'has_getaround_connect': [has_getaround_connect],
        'has_speed_regulator': [has_speed_regulator],
        'winter_tires': [winter_tires]
    }
    car_data = pd.DataFrame(car_data_dict)

    # NOTE(review): the prediction runs synchronously inside this coroutine,
    # so the event loop is occupied for the duration of the call.
    model = _get_model()
    prediction = model.predict(car_data)

    # tolist() turns the result into built-in Python values so the single
    # prediction can be serialized to JSON.
    response = {"prediction": prediction.tolist()[0]}
    return response
if __name__ == "__main__":
    # Direct execution only: serve the app with uvicorn on port 4000, bound to
    # 0.0.0.0 (all network interfaces).
    uvicorn.run(app, port=4000, host="0.0.0.0")
|