from enum import Enum
from functools import lru_cache
from typing import Union

import joblib
import pandas as pd
import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import HTMLResponse

# Markdown rendered at the top of the generated API docs (passed to FastAPI as
# `description=`). Keep the endpoint summary in sync with the `/predict` signature.
description = """
Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️

## Machine Learning

This section includes a Machine Learning endpoint that predicts car values based on various features. Here is the endpoint:

* `/predict`: **POST** request that accepts the features of a single car as query parameters and returns its predicted value.

Check out the documentation below 👇 for more information on each endpoint. 
"""

# OpenAPI tag metadata: one entry per tag used by a route below, so every tag
# group in the docs gets a description and a deterministic display order.
tags_metadata = [
    {
        # Used by the `/` landing route.
        "name": "Introduction Endpoints",
        "description": "Landing endpoint that greets visitors and links to the API documentation."
    },
    {
        # Used by the `/predict` route.
        "name": "Machine Learning",
        "description": "Endpoint for predicting car values based on provided features."
    },
]

# The application object. The metadata below (title, version, description,
# contact, tag descriptions) is what FastAPI publishes in its OpenAPI schema.
app = FastAPI(
    title="🚗 GetAround Car Value Prediction API",
    version="0.1",
    description=description,
    openapi_tags=tags_metadata,
    contact=dict(
        name="Antoine VERDON",
        email="antoineverdon.pro@gmail.com",
    ),
)

class CarBrand(str, Enum):
    """Car brands accepted by the ``brand`` parameter of ``/predict``.

    Clients send the member *values* (e.g. ``"Citroën"``); the selected member
    is fed to the model as its ``model_key`` column. The values presumably have
    to match the brand categories seen at training time — TODO confirm against
    the training data before adding or renaming members.
    """
    citroen = "Citroën"
    peugeot = "Peugeot"
    pgo = "PGO"
    renault = "Renault"
    audi = "Audi"
    bmw = "BMW"
    # Catch-all value; presumably the bucket used for rare brands in training — verify.
    other = "other"
    mercedes = "Mercedes"
    opel = "Opel"
    volkswagen = "Volkswagen"
    ferrari = "Ferrari"
    maserati = "Maserati"
    mitsubishi = "Mitsubishi"
    nissan = "Nissan"
    seat = "SEAT"
    subaru = "Subaru"
    toyota = "Toyota"

class FuelType(str, Enum):
    """Fuel types accepted by the ``fuel`` parameter of ``/predict``.

    The values are passed to the model's ``fuel`` column; the spellings
    (e.g. ``"electro"``) presumably mirror the training data — verify.
    """
    diesel = "diesel"
    petrol = "petrol"
    hybrid_petrol = "hybrid_petrol"
    electro = "electro"

class PaintColor(str, Enum):
    """Paint colours accepted by the ``paint_color`` parameter of ``/predict``.

    The values are passed to the model's ``paint_color`` column and presumably
    must match the training categories — verify before changing.
    """
    black = "black"
    grey = "grey"
    white = "white"
    red = "red"
    silver = "silver"
    blue = "blue"
    orange = "orange"
    beige = "beige"
    brown = "brown"
    green = "green"

class CarType(str, Enum):
    """Body types accepted by the ``car_type`` parameter of ``/predict``.

    The values are passed to the model's ``car_type`` column and presumably
    must match the training categories — verify before changing.
    """
    convertible = "convertible"
    coupe = "coupe"
    estate = "estate"
    hatchback = "hatchback"
    sedan = "sedan"
    subcompact = "subcompact"
    suv = "suv"
    van = "van"

@app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
async def index():
    """Landing page: a short HTML greeting linking to the local and hosted docs."""
    fragments = [
        "Hello world! This `/` is the most simple and default endpoint. ",
        "If you want to learn more, check out documentation of the API at ",
        "<a href='/docs'>/docs</a> or ",
        "<a href='https://2nzi-getaroundapi.hf.space/docs' target='_blank'>external docs</a>.",
    ]
    html = "".join(fragments)
    return html

# Serialized model artifact, resolved relative to the process's working directory.
_MODEL_PATH = "best_model_XGBoost.pkl"


@lru_cache(maxsize=1)
def _load_model(path: str = _MODEL_PATH):
    """Load the trained model from ``path`` on first use and cache it.

    Subsequent calls return the same in-memory object, so a request does not
    re-read and unpickle the file. A failed load is not cached, so the next
    request retries.

    NOTE: joblib is pickle-based; ``path`` must point to a trusted artifact and
    must never be derived from client input.
    """
    return joblib.load(path)


@app.post("/predict", tags=["Machine Learning"])
async def predict(
    brand: CarBrand = Query(...),
    mileage: int = Query(...),
    engine_power: int = Query(...),
    fuel: FuelType = Query(...),
    paint_color: PaintColor = Query(...),
    car_type: CarType = Query(...),
    private_parking_available: bool = Query(...),
    has_gps: bool = Query(...),
    has_air_conditioning: bool = Query(...),
    automatic_car: bool = Query(...),
    has_getaround_connect: bool = Query(...),
    has_speed_regulator: bool = Query(...),
    winter_tires: bool = Query(...),
):
    """Predict the value of a single car from its features.

    Every feature is a required query parameter. The features are assembled
    into a one-row DataFrame (``brand`` becomes the ``model_key`` column) and
    passed to the trained model.

    Returns:
        ``{"prediction": <number>}`` — the model output for that single row.
    """
    # Enum members are converted to their plain string values: a ``(str, Enum)``
    # member stringifies as e.g. "CarBrand.citroen", so handing the Enum object
    # itself to the model pipeline is fragile.
    car_data_dict = {
        'model_key': [brand.value],
        'mileage': [mileage],
        'engine_power': [engine_power],
        'fuel': [fuel.value],
        'paint_color': [paint_color.value],
        'car_type': [car_type.value],
        'private_parking_available': [private_parking_available],
        'has_gps': [has_gps],
        'has_air_conditioning': [has_air_conditioning],
        'automatic_car': [automatic_car],
        'has_getaround_connect': [has_getaround_connect],
        'has_speed_regulator': [has_speed_regulator],
        'winter_tires': [winter_tires],
    }
    car_data = pd.DataFrame(car_data_dict)

    # Cached after the first request instead of being re-loaded every call.
    model = _load_model()
    prediction = model.predict(car_data)
    # ``tolist`` converts the NumPy scalar to a JSON-serializable Python number.
    return {"prediction": prediction.tolist()[0]}

if __name__ == "__main__":
    # Direct execution (python <file>) serves the app on every interface, port 4000.
    uvicorn.run(
        app,
        port=4000,
        host="0.0.0.0",
    )